
Commit e5cdff3

add base model override

1 parent c4e0c63

3 files changed: 8 additions & 1 deletion

README.md

Lines changed: 1 addition & 1 deletion
@@ -99,7 +99,7 @@ You can use the following command for launching a CLI interface:
 ```bash
 CUDA_VISIBLE_DEVICES=0 python -m medusa.inference.cli --model [path of medusa model]
 ```
-You can also pass `--load-in-8bit` or `--load-in-4bit` to load the base model in quantized format.
+You can also pass `--load-in-8bit` or `--load-in-4bit` to load the base model in quantized format. If you download the base model elsewhere, you may override base model name or path with `--base-model [path of base model]`.
 
 ### Training
 For training, please install:
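As a quick illustration of the new flag, the CLI launch command from this README section can be combined with the override (bracketed paths are placeholders, as in the README):

```bash
CUDA_VISIBLE_DEVICES=0 python -m medusa.inference.cli --model [path of medusa model] --base-model [path of base model]
```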

medusa/inference/cli.py

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,7 @@ def main(args):
     try:
         model = MedusaModel.from_pretrained(
             args.model,
+            args.base_model,
             torch_dtype=torch.float16,
             low_cpu_mem_usage=True,
             device_map="auto",
@@ -185,6 +186,7 @@ def reload_conv(conv):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", type=str, required=True, help="Model name or path.")
+    parser.add_argument("--base-model", type=str, default=None, help="Base model name or path.")
     parser.add_argument(
         "--load-in-8bit", action="store_true", help="Use 8-bit quantization"
     )

medusa/model/medusa_model.py

Lines changed: 5 additions & 0 deletions
@@ -110,6 +110,7 @@ def get_tokenizer(self):
     def from_pretrained(
         cls,
         medusa_head_name_or_path,
+        base_model=None,
         **kwargs,
     ):
         """
@@ -121,6 +122,10 @@
             MedusaModel: A MedusaModel instance loaded from the given path.
         """
         medusa_config = MedusaConfig.from_pretrained(medusa_head_name_or_path)
+        if base_model:
+            print("Overriding base model as:", base_model)
+            medusa_config.base_model_name_or_path = base_model
+
         base_model = KVLlamaForCausalLM.from_pretrained(
             medusa_config.base_model_name_or_path, **kwargs
         )
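Downstream, this change lets a Medusa head checkpoint be loaded against a base model that was downloaded separately. A minimal sketch of such a call, mirroring how `cli.py` invokes it (the paths below are placeholders, not part of this commit):

```python
import torch

from medusa.model.medusa_model import MedusaModel

# The second positional argument overrides medusa_config.base_model_name_or_path,
# so the base weights are read from the given local path instead of the path
# stored in the Medusa head's config.
model = MedusaModel.from_pretrained(
    "path/to/medusa-head-checkpoint",  # placeholder Medusa model name or path
    "path/to/local/base-model",        # placeholder base model override
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
)
tokenizer = model.get_tokenizer()
```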
