@@ -1089,13 +1089,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
            state_dict = load_state_dict(
                resolved_archive_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries
            )
+            # We only fix it for non-sharded checkpoints, as we don't need it yet for sharded ones.
+            model._fix_state_dict_keys_on_load(state_dict)

        if is_sharded:
            loaded_keys = sharded_metadata["all_checkpoint_keys"]
        else:
            loaded_keys = list(state_dict.keys())
-        # TODO: hacky solution
-        loaded_keys = list(model._fix_state_dict_keys_on_load({key: "" for key in loaded_keys}))

        if hf_quantizer is not None:
            hf_quantizer.preprocess_model(
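For context, this hunk replaces the earlier hack of calling `_fix_state_dict_keys_on_load` on a dummy `{key: ""}` mapping just to recover the remapped key names: the state dict itself is now fixed up front (non-sharded path only) before `loaded_keys` is derived from it. A minimal sketch of the resulting control flow, where `load_state_dict` and `_fix_state_dict_keys_on_load` are simplified stand-ins rather than the actual diffusers implementations:

```python
def load_state_dict(path):
    # Stand-in loader: returns a checkpoint containing one deprecated key name.
    return {"query.weight": 0.0, "norm.weight": 1.0}

def _fix_state_dict_keys_on_load(state_dict):
    # Stand-in fix-up: rename deprecated keys in place (the mapping is illustrative).
    for old_key in list(state_dict):
        new_key = old_key.replace("query.", "to_q.")
        if new_key != old_key:
            state_dict[new_key] = state_dict.pop(old_key)
    return state_dict

is_sharded = False
state_dict = load_state_dict("model.safetensors")
if not is_sharded:
    # New behavior: fix the real state dict once, up front...
    _fix_state_dict_keys_on_load(state_dict)

# ...so the key list is derived from already-fixed names, replacing the old
# trick of calling the fix-up on a dummy {key: ""} dict just for its keys.
loaded_keys = list(state_dict.keys())
assert "to_q.weight" in loaded_keys
```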
@@ -1305,7 +1305,6 @@ def _load_pretrained_model(

        for shard_file in resolved_archive_file:
            state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
-            model._fix_state_dict_keys_on_load(state_dict)

            def _find_mismatched_keys(
                state_dict,
@@ -1578,7 +1577,8 @@ def _fix_state_dict_keys_on_load(self, state_dict: OrderedDict) -> None:
        """
        This function fixes the state dict of the model to take into account some changes that were made in the model
        architecture:
-        - depretated attention blocks
+        - deprecated attention blocks (this predates the introduction of sharded checkpoints,
+          which is why we apply this method only when loading non-sharded checkpoints for now)
        """
        deprecated_attention_block_paths = []

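For readers unfamiliar with the deprecated attention block fix-up this docstring refers to, here is a hedged sketch of the kind of key renaming involved. The mapping and the helper name `fix_deprecated_attention_keys` are illustrative, loosely based on diffusers' pre-refactor attention naming (`query`/`key`/`value`/`proj_attn` versus `to_q`/`to_k`/`to_v`/`to_out.0`), not a verbatim copy of the method, which first discovers the affected module paths on the model:

```python
from collections import OrderedDict

# Illustrative mapping from deprecated attention block parameter names to the
# current ones; the real method collects deprecated_attention_block_paths first.
_DEPRECATED_TO_CURRENT = {
    "query": "to_q",
    "key": "to_k",
    "value": "to_v",
    "proj_attn": "to_out.0",
}

def fix_deprecated_attention_keys(state_dict: OrderedDict, paths: list) -> OrderedDict:
    """Rename e.g. '{path}.query.weight' to '{path}.to_q.weight' in place."""
    for path in paths:
        for old, new in _DEPRECATED_TO_CURRENT.items():
            for suffix in ("weight", "bias"):
                old_key = f"{path}.{old}.{suffix}"
                if old_key in state_dict:
                    state_dict[f"{path}.{new}.{suffix}"] = state_dict.pop(old_key)
    return state_dict

# Example: a checkpoint saved before the attention refactor.
sd = OrderedDict({"mid_block.attentions.0.query.weight": 0.0})
fix_deprecated_attention_keys(sd, ["mid_block.attentions.0"])
assert "mid_block.attentions.0.to_q.weight" in sd
```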