Commit d5f20d0

fix issue in scripts (#3456)
1 parent 5b94f4f commit d5f20d0

4 files changed: +9 −7 lines changed

examples/cpu/llm/inference/distributed/run_generation_with_deepspeed.py

Lines changed: 2 additions & 2 deletions
@@ -328,8 +328,8 @@ def get_checkpoint_files(model_name_or_path):
 model_type = next((x for x in MODEL_CLASSES.keys() if x in model_name.lower()), "auto")
 if model_type == "llama" and args.vision_text_model:
     model_type = "mllama"
-if model_type == "maira-2":
-    model_type = "maira2"
+if model_type in ["maira-2", "deepseek-v2", "deepseek-v3"]:
+    model_type = model_type.replace("-", "")
 model_class = MODEL_CLASSES[model_type]
 tokenizer = model_class[1].from_pretrained(model_name, trust_remote_code=True)
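For context, this commit replaces the single maira-2 special case with a general normalization of hyphenated model types before the MODEL_CLASSES lookup. Below is a minimal, self-contained sketch of that detection-plus-normalization flow; the stub table and sample model names are illustrative assumptions, not the script's real MODEL_CLASSES.

```python
# Sketch of the detection + normalization pattern used by the scripts in this commit.
# The stub table and the sample names below are illustrative assumptions.
MODEL_CLASSES = {
    "llama": None,
    "maira-2": None, "maira2": None,
    "deepseek-v2": None, "deepseekv2": None,
    "deepseek-v3": None, "deepseekv3": None,
    "auto": None,
}

for model_name in ["deepseek-ai/DeepSeek-V2-Lite", "microsoft/maira-2", "some-org/other-model"]:
    # Detection: first table key that appears in the lowered model name, else "auto".
    model_type = next((x for x in MODEL_CLASSES.keys() if x in model_name.lower()), "auto")
    # The fix: strip the hyphen so MODEL_CLASSES[model_type] hits the
    # de-hyphenated entries ("maira2", "deepseekv2", "deepseekv3").
    if model_type in ["maira-2", "deepseek-v2", "deepseek-v3"]:
        model_type = model_type.replace("-", "")
    print(f"{model_name} -> {model_type}")
```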

examples/cpu/llm/inference/single_instance/run_generation.py

Lines changed: 2 additions & 2 deletions
@@ -140,8 +140,8 @@
 )
 if model_type == "llama" and args.vision_text_model:
     model_type = "mllama"
-if model_type == "maira-2":
-    model_type = "maira2"
+if model_type in ["maira-2", "deepseek-v2", "deepseek-v3"]:
+    model_type = model_type.replace("-", "")
 model_class = MODEL_CLASSES[model_type]
 if args.config_file is None:
     if model_type == "chatglm":

examples/cpu/llm/inference/utils/create_shard_model.py

Lines changed: 3 additions & 3 deletions
@@ -53,8 +53,8 @@
 )
 if model_type == "llama" and args.vision_text_model:
     model_type = "mllama"
-if model_type == "maira-2":
-    model_type = "maira2"
+if model_type in ["maira-2", "deepseek-v2", "deepseek-v3"]:
+    model_type = model_type.replace("-", "")
 model_class = MODEL_CLASSES[model_type]
 load_dtype = torch.float32
 if args.dtype == "float16":
@@ -83,7 +83,7 @@
 tokenizer.save_pretrained(save_directory=args.save_path)
 if model_type == "llava":
     image_processor.save_pretrained(save_directory=args.save_path)
-if model_type == "maira2":
+if model_type in ["maira2", "deepseekv2", "deepseekv3"]:
     import inspect
     import shutil
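The second hunk widens the branch that handles models shipping custom (trust_remote_code) Python files from maira2 alone to the de-hyphenated DeepSeek types. Only the two imports are visible in the hunk, so the helper below is a purely hypothetical illustration of what such a copy step can look like; it is not the script's actual body, and the helper name is invented.

```python
import inspect
import shutil

def copy_remote_code(model, save_path):
    """Hypothetical helper (not the real create_shard_model.py code): copy the
    .py file that defines the loaded model's class, i.e. its remote-code module,
    next to the sharded checkpoint so the shards can be re-loaded from save_path."""
    model_file = inspect.getfile(type(model))  # path of the defining source file
    return shutil.copy(model_file, save_path)  # returns the destination path
```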

examples/cpu/llm/inference/utils/supported_models.py

Lines changed: 2 additions & 0 deletions
@@ -37,6 +37,8 @@
     "jamba": (AutoModelForCausalLM, AutoTokenizer),
     "deepseek-v2": (AutoModelForCausalLM, AutoTokenizer),
     "deepseek-v3": (AutoModelForCausalLM, AutoTokenizer),
+    "deepseekv2": (AutoModelForCausalLM, AutoTokenizer),
+    "deepseekv3": (AutoModelForCausalLM, AutoTokenizer),
     "auto": (AutoModelForCausalLM, AutoTokenizer),
 }
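The two added entries are the de-hyphenated aliases produced by the scripts' model_type.replace("-", ""); without them, MODEL_CLASSES[model_type] would raise a KeyError for DeepSeek models. A short sketch with placeholder values standing in for (AutoModelForCausalLM, AutoTokenizer):

```python
# Placeholder strings stand in for the real (AutoModelForCausalLM, AutoTokenizer) tuples.
MODEL_CLASSES = {
    "deepseek-v2": "hyphenated key, matched against model names during detection",
    "deepseekv2": "de-hyphenated alias, looked up after replace('-', '')",
}

model_type = "deepseek-v2".replace("-", "")  # -> "deepseekv2"
print(MODEL_CLASSES[model_type])             # KeyError here without the alias entry
```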
