Skip to content

Commit d690f7f

Browse files
authored
Merge pull request #26 from Btlmd/main
Add an option to override base model path
2 parents 6264fe8 + 68b5040 commit d690f7f

File tree

4 files changed

+9
-1
lines changed

4 files changed

+9
-1
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ You can use the following command for launching a CLI interface:
9999
```bash
100100
CUDA_VISIBLE_DEVICES=0 python -m medusa.inference.cli --model [path of medusa model]
101101
```
102-
You can also pass `--load-in-8bit` or `--load-in-4bit` to load the base model in quantized format.
102+
You can also pass `--load-in-8bit` or `--load-in-4bit` to load the base model in quantized format. If you have downloaded the base model to a different location, you can override the base model name or path with `--base-model [path of base model]`.
103103

104104
### Training
105105
For training, please install:

medusa/inference/cli.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def main(args):
3636
try:
3737
model = MedusaModel.from_pretrained(
3838
args.model,
39+
args.base_model,
3940
torch_dtype=torch.float16,
4041
low_cpu_mem_usage=True,
4142
device_map="auto",
@@ -185,6 +186,7 @@ def reload_conv(conv):
185186
if __name__ == "__main__":
186187
parser = argparse.ArgumentParser()
187188
parser.add_argument("--model", type=str, required=True, help="Model name or path.")
189+
parser.add_argument("--base-model", type=str, default=None, help="Base model name or path.")
188190
parser.add_argument(
189191
"--load-in-8bit", action="store_true", help="Use 8-bit quantization"
190192
)

medusa/model/medusa_model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def get_tokenizer(self):
125125
def from_pretrained(
126126
cls,
127127
medusa_head_name_or_path,
128+
base_model=None,
128129
**kwargs,
129130
):
130131
"""
@@ -136,6 +137,10 @@ def from_pretrained(
136137
MedusaModel: A MedusaModel instance loaded from the given path.
137138
"""
138139
medusa_config = MedusaConfig.from_pretrained(medusa_head_name_or_path)
140+
if base_model:
141+
print("Overriding base model as:", base_model)
142+
medusa_config.base_model_name_or_path = base_model
143+
139144
base_model = KVLlamaForCausalLM.from_pretrained(
140145
medusa_config.base_model_name_or_path, **kwargs
141146
)

medusa/train/train.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,7 @@ def train():
392392
model,
393393
medusa_num_heads=training_args.medusa_num_heads,
394394
medusa_num_layers=training_args.medusa_num_layers,
395+
base_model_name_or_path=model_args.model_name_or_path,
395396
)
396397

397398
# Format output dir

0 commit comments

Comments
 (0)