Skip to content

Commit 3e61395

Browse files
authored
update z3_leaf_modules (#6082)
1 parent 35663af commit 3e61395

File tree

1 file changed

+4 -2 lines changed

1 file changed

+4 -2 lines changed

swift/llm/model/register.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -268,6 +268,7 @@ def get_model_tokenizer_from_local(model_dir: str,
268 268   torch_dtype = model_info.torch_dtype
269 269   HfConfigFactory.set_config_attr(model_config, 'torch_dtype', torch_dtype, include_vit=True)
270 270   HfConfigFactory.compat_zero3(model_config)
    271 + leaf_modules = kwargs.get('leaf_modules')
271 272   rope_scaling = kwargs.get('rope_scaling')
272 273   max_model_len = kwargs.get('max_model_len')
273 274   return_dummy_model = kwargs.get('return_dummy_model')
@@ -368,8 +369,9 @@ def get_model_tokenizer_from_local(model_dir: str,
368 369   if model is not None:
369 370   # fix seq classification task
370 371   HfConfigFactory.set_model_config_attr(model, 'pad_token_id', pad_token)
371     - # deepspeed zero3
372     - deepspeed_set_z3_leaf_modules(model)
    372 + if leaf_modules is not None or model_info.is_moe_model:
    373 + # deepspeed zero3
    374 + deepspeed_set_z3_leaf_modules(model)
373 375
374 376   return model, tokenizer
375 377

0 commit comments

Comments
 (0)