@@ -1029,7 +1029,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         keep_in_fp32_modules = []
 
         is_sharded = False
-        resolved_archive_file = None
+        resolved_model_file = None
 
         # Determine if we're loading from a directory of sharded checkpoints.
         sharded_metadata = None
@@ -1064,7 +1064,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
 
         # load model
         if from_flax:
-            resolved_archive_file = _get_model_file(
+            resolved_model_file = _get_model_file(
                 pretrained_model_name_or_path,
                 weights_name=FLAX_WEIGHTS_NAME,
                 cache_dir=cache_dir,
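For context, the `from_flax` branch is taken when a repository only ships Flax weights. A minimal usage sketch of the caller-facing side; the model class and repo id here are hypothetical placeholders, not part of this PR:

from diffusers import UNet2DConditionModel

# from_flax=True makes from_pretrained resolve FLAX_WEIGHTS_NAME and
# convert the checkpoint to PyTorch weights.
model = UNet2DConditionModel.from_pretrained(
    "some-user/flax-only-checkpoint",  # hypothetical repo id
    from_flax=True,
)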
@@ -1082,11 +1082,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             # Convert the weights
             from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
 
-            model = load_flax_checkpoint_in_pytorch_model(model, resolved_archive_file)
+            model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file)
         else:
             # in the case it is sharded, we have already the index
             if is_sharded:
-                resolved_archive_file, sharded_metadata = _get_checkpoint_shard_files(
+                resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
                     pretrained_model_name_or_path,
                     index_file,
                     cache_dir=cache_dir,
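Aside: in the sharded branch, `resolved_model_file` becomes a list of shard paths and `sharded_metadata` carries the parsed index. A rough sketch of the core of that resolution, assuming the standard `*.index.json` layout with a `weight_map`; downloading and caching, which `_get_checkpoint_shard_files` also handles, are omitted:

import json

# index_file would be e.g. "diffusion_pytorch_model.safetensors.index.json"
def read_shard_index(index_file):
    with open(index_file) as f:
        index = json.load(f)
    # weight_map maps each parameter name to the shard file that holds it
    shard_files = sorted(set(index["weight_map"].values()))
    return shard_files, index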
@@ -1100,7 +1100,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 )
             elif use_safetensors:
                 try:
-                    resolved_archive_file = _get_model_file(
+                    resolved_model_file = _get_model_file(
                         pretrained_model_name_or_path,
                         weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
                         cache_dir=cache_dir,
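For reference, `_add_variant` tags the weights filename with the requested variant. A sketch of the behavior implied by its use here (not a copy of the actual helper):

# ("diffusion_pytorch_model.safetensors", "fp16")
#     -> "diffusion_pytorch_model.fp16.safetensors"
def add_variant_sketch(weights_name, variant=None):
    if variant is None:
        return weights_name
    stem, ext = weights_name.rsplit(".", 1)
    return f"{stem}.{variant}.{ext}"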
@@ -1123,8 +1123,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
11231123 "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
11241124 )
11251125
1126- if resolved_archive_file is None and not is_sharded :
1127- resolved_archive_file = _get_model_file (
1126+ if resolved_model_file is None and not is_sharded :
1127+ resolved_model_file = _get_model_file (
11281128 pretrained_model_name_or_path ,
11291129 weights_name = _add_variant (WEIGHTS_NAME , variant ),
11301130 cache_dir = cache_dir ,
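The two hunks above implement the resolution order: try safetensors first, then fall back to the pickled `.bin` weights, warning about unsafe serialization unless `allow_pickle=False`. A condensed local-directory sketch of that order; the filenames follow the diffusers defaults, everything else is simplified:

import os

def pick_weights_file(model_dir, allow_pickle=True):
    st = os.path.join(model_dir, "diffusion_pytorch_model.safetensors")
    if os.path.isfile(st):
        return st
    if not allow_pickle:
        raise OSError(f"no safetensors weights found in {model_dir}")
    # unsafe, pickle-based fallback
    return os.path.join(model_dir, "diffusion_pytorch_model.bin")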
@@ -1139,8 +1139,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                     dduf_entries=dduf_entries,
                 )
 
-        if not isinstance(resolved_archive_file, list):
-            resolved_archive_file = [resolved_archive_file]
+        if not isinstance(resolved_model_file, list):
+            resolved_model_file = [resolved_model_file]
 
         # set dtype to instantiate the model under:
         # 1. If torch_dtype is not None, we use that dtype
@@ -1168,7 +1168,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             if not is_sharded:
                 # Time to load the checkpoint
                 state_dict = load_state_dict(
-                    resolved_archive_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries
+                    resolved_model_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries
                 )
                 # We only fix it for non sharded checkpoints as we don't need it yet for sharded one.
                 model._fix_state_dict_keys_on_load(state_dict)
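`load_state_dict` dispatches on the checkpoint format. A stripped-down sketch of the single-file case; mmap control and DDUF entries, which the real helper also handles, are omitted:

import torch
from safetensors.torch import load_file

def load_checkpoint(path):
    if path.endswith(".safetensors"):
        return load_file(path)  # safe, zero-copy format
    # pickle-based format; weights_only limits the unpickling surface
    return torch.load(path, map_location="cpu", weights_only=True)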
@@ -1200,7 +1200,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             ) = cls._load_pretrained_model(
                 model,
                 state_dict,
-                resolved_archive_file,
+                resolved_model_file,
                 pretrained_model_name_or_path,
                 loaded_keys,
                 ignore_mismatched_sizes=ignore_mismatched_sizes,
@@ -1361,7 +1361,7 @@ def _load_pretrained_model(
         cls,
         model,
         state_dict: OrderedDict,
-        resolved_archive_file: List[str],
+        resolved_model_file: List[str],
         pretrained_model_name_or_path: Union[str, os.PathLike],
         loaded_keys: List[str],
         ignore_mismatched_sizes: bool = False,
@@ -1415,13 +1415,13 @@ def _load_pretrained_model(
 
         if state_dict is not None:
             # load_state_dict will manage the case where we pass a dict instead of a file
-            # if state dict is not None, it means that we don't need to read the files from resolved_archive_file also
-            resolved_archive_file = [state_dict]
+            # if state dict is not None, it means that we don't need to read the files from resolved_model_file also
+            resolved_model_file = [state_dict]
 
-        if len(resolved_archive_file) > 1:
-            resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
+        if len(resolved_model_file) > 1:
+            resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
 
-        for shard_file in resolved_archive_file:
+        for shard_file in resolved_model_file:
             state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
 
             def _find_mismatched_keys(
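This last hunk shows the loop the rename feeds into: a progress bar is shown only when there is more than one shard, and each shard is loaded in turn. A self-contained sketch of that pattern, where load_fn stands in for load_state_dict:

from tqdm.auto import tqdm

def iter_shard_state_dicts(model_files, load_fn):
    if len(model_files) > 1:
        model_files = tqdm(model_files, desc="Loading checkpoint shards")
    for shard_file in model_files:
        yield load_fn(shard_file)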