@@ -606,14 +606,22 @@ def _maybe_build_consolidated_index(
             fqn_to_file_index_mapping = get_fqn_to_file_index_mapping(
                 index_path, getattr(model, "_checkpoint_conversion_mapping", None)
             )
-            # some HF models like Moonlight-16B have non-persistent buffers in the base checkpoint
-            # however, HF initializes buffers with persistent=False, so we need to make sure these
-            # buffer keys are not saved during checkpointing
-            keys_to_remove = list(set(fqn_to_file_index_mapping.keys()) - set(self.config.model_state_dict_keys))
-            if model_state.is_tied_lm_head:
-                keys_to_remove.append(model_state.lm_head_param_name)
-            for key in keys_to_remove:
-                fqn_to_file_index_mapping.pop(key, None)
+            model_part = model_state.model[0]
+            config = getattr(model_part, "config", None)
+            model_type = getattr(config, "model_type", None)
+            if model_type and requires_tensor_merging(model_type) and not hasattr(model_part, "state_dict_adapter"):
+                # in this case, Transformers performed weight conversion so we will save the converted format in the checkpoint
+                num_shards = max(fqn_to_file_index_mapping.values()) if fqn_to_file_index_mapping else 1
+                fqn_to_file_index_mapping = _equally_divide_layers(num_shards, self.config.model_state_dict_keys)
+            else:
+                # some HF models like Moonlight-16B have non-persistent buffers in the base checkpoint
+                # however, HF initializes buffers with persistent=False, so we need to make sure these
+                # buffer keys are not saved during checkpointing
+                keys_to_remove = list(set(fqn_to_file_index_mapping.keys()) - set(self.config.model_state_dict_keys))
+                if model_state.is_tied_lm_head:
+                    keys_to_remove.append(model_state.lm_head_param_name)
+                for key in keys_to_remove:
+                    fqn_to_file_index_mapping.pop(key, None)
         else:
             fqn_to_file_index_mapping = {k: 1 for k in state_dict.keys()}
 
@@ -1055,6 +1063,29 @@ def _maybe_adapt_state_dict_to_hf(
     return state_dict
 
 
+def _equally_divide_layers(num_shards: int, keys: list[str]) -> dict[str, int]:
+    """
+    Equally divide the state dict keys into num_shards shards.
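+
+    Example (illustrative, derived from the logic below): with num_shards=2 and
+    keys=["a", "b", "c"], this returns {"a": 1, "b": 1, "c": 2}; the first
+    `remainder` shards each take one extra key.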
+    """
+    if num_shards <= 0:
+        raise ValueError(f"num_shards must be > 0, got {num_shards}")
+
+    num_layers = len(keys)
+    if num_layers == 0:
+        return {}
+
+    layers_per_shard, remainder = divmod(num_layers, num_shards)
+    fqn_to_index_mapping: dict[str, int] = {}
+    start = 0
+    for shard_index in range(1, num_shards + 1):
+        extra = 1 if shard_index <= remainder else 0
+        end = start + layers_per_shard + extra
+        for key in keys[start:end]:
+            fqn_to_index_mapping[key] = shard_index
+        start = end
+    return fqn_to_index_mapping
+
+
 def _maybe_adapt_state_dict_from_hf(
     model_part: nn.Module, state_dict: dict[str, torch.Tensor], moe_mesh: Optional[DeviceMesh] = None
 ) -> dict[str, torch.Tensor]: