@@ -47,6 +47,7 @@

 if TYPE_CHECKING:
     from peft import PeftConfig
+
     from transformers.tokenization_utils_base import PreTrainedTokenizerBase


@@ -313,9 +314,7 @@ def load_model(

     # For models that need tensor merging and don't have an adapter, try using transformers' conversion
     if is_init_step and model_type and requires_tensor_merging(model_type) and not has_state_dict_adapter:
-        converted_state_dict = _convert_checkpoint_with_transformers(
-            model_state.model[0], model_path, key_mapping
-        )
+        converted_state_dict = _convert_checkpoint_with_transformers(model_state.model[0], model_path, key_mapping)
         if converted_state_dict is not None:
             # Load using full_state_dict=True to properly convert tensors to DTensors for FSDP
             _load_full_state_dict_into_model(model_state.model, converted_state_dict)
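
For context, a minimal sketch of the decision this hunk implements, assuming the caller falls back to the regular sharded load when the transformers-based conversion is unavailable or returns None. The helper names come from this file; the wrapper function and its boolean return convention are illustrative only.

```python
# Hypothetical wrapper; requires_tensor_merging, _convert_checkpoint_with_transformers,
# and _load_full_state_dict_into_model are the helpers used in this module.
def _try_transformers_conversion(model_state, model_path, key_mapping, model_type,
                                 is_init_step, has_state_dict_adapter):
    if is_init_step and model_type and requires_tensor_merging(model_type) and not has_state_dict_adapter:
        converted = _convert_checkpoint_with_transformers(model_state.model[0], model_path, key_mapping)
        if converted is not None:
            # Full, non-sharded dict: loading it converts tensors to DTensors for FSDP.
            _load_full_state_dict_into_model(model_state.model, converted)
            return True
    # Conversion not attempted or failed; caller should use the normal load path.
    return False
```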
@@ -887,7 +886,8 @@ def _load_full_state_dict_into_model(
         state_dict: Full state dict with regular tensors
     """
     from functools import partial
-    from torch.distributed.checkpoint.state_dict import set_model_state_dict, StateDictOptions
+
+    from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict

     # Use full_state_dict=True to tell PyTorch this is a complete, non-sharded state dict
     # It will properly shard the tensors to match the model's DTensor layout
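
The two comments above describe the behavior of torch.distributed.checkpoint's state-dict API. A minimal sketch of that load path, assuming an FSDP/DTensor-sharded module; the function and variable names are placeholders, while StateDictOptions and set_model_state_dict are the actual torch APIs imported here.

```python
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict


def load_full_state_dict(model: nn.Module, full_state_dict: dict) -> None:
    # full_state_dict=True declares that the input is a complete, unsharded
    # state dict; set_model_state_dict then reshards each tensor to match the
    # model's DTensor layout before copying it in.
    options = StateDictOptions(full_state_dict=True)
    set_model_state_dict(model, full_state_dict, options=options)
```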
@@ -925,16 +925,17 @@ def _convert_checkpoint_with_transformers(
         Converted state dict ready for loading, or None if conversion failed.
     """
     try:
+        from copy import deepcopy
+
+        from safetensors import safe_open
+
         from transformers.conversion_mapping import get_model_conversion_mapping
         from transformers.core_model_loading import (
             WeightConverter,
             WeightRenaming,
-            rename_source_key,
             dot_natural_key,
+            rename_source_key,
         )
-        from safetensors import safe_open
-        from copy import deepcopy
-        from collections import defaultdict
     except ImportError:
         logging.warning(
             "transformers library with conversion_mapping not available. "
@@ -946,7 +947,9 @@ def _convert_checkpoint_with_transformers(
         # Get the weight conversion mapping from transformers
         weight_mapping = get_model_conversion_mapping(model, key_mapping=key_mapping, add_legacy=True)
         if not weight_mapping:
-            logging.warning(f"No conversion mapping found for model type {getattr(model.config, 'model_type', 'unknown')}")
+            logging.warning(
+                f"No conversion mapping found for model type {getattr(model.config, 'model_type', 'unknown')}"
+            )
             return None

         # Load the safetensors files
@@ -962,9 +965,6 @@ def _convert_checkpoint_with_transformers(
                 for key in f.keys():
                     checkpoint_state_dict[key] = f.get_tensor(key)

-        # Get model's expected keys
-        model_state_dict = model.state_dict()
-
         # Separate renamings and converters
         renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)]
         converters = [entry for entry in weight_mapping if isinstance(entry, WeightConverter)]
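
A minimal sketch of the shard-reading step shown above, assuming model_path is a directory containing one or more *.safetensors files; the glob-based discovery is illustrative, while safe_open, keys, and get_tensor are the actual safetensors API.

```python
import glob
import os

from safetensors import safe_open


def read_safetensors_shards(model_path: str) -> dict:
    checkpoint_state_dict = {}
    for shard_file in sorted(glob.glob(os.path.join(model_path, "*.safetensors"))):
        # framework="pt" yields torch tensors; tensors are read lazily per key
        with safe_open(shard_file, framework="pt", device="cpu") as f:
            for key in f.keys():
                checkpoint_state_dict[key] = f.get_tensor(key)
    return checkpoint_state_dict
```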
@@ -1010,6 +1010,7 @@ def _convert_checkpoint_with_transformers(
     except Exception as e:
         logging.warning(f"Failed to convert checkpoint with transformers: {e}")
         import traceback
+
         traceback.print_exc()
         return None
