2020from  huggingface_hub  import  DDUFEntry 
2121from  tqdm  import  tqdm 
2222
23- from  ..utils  import  is_safetensors_available , is_transformers_version 
23+ from  ..utils  import  is_safetensors_available , is_transformers_available ,  is_transformers_version 
2424
2525
2626if  TYPE_CHECKING :
@@ -93,15 +93,16 @@ def load_transformers_model_from_dduf(
9393
9494    with  tempfile .TemporaryDirectory () as  tmp_dir :
9595        from  transformers  import  AutoConfig , GenerationConfig 
96+ 
9697        tmp_config_file  =  os .path .join (tmp_dir , "config.json" )
9798        with  open (tmp_config_file , "w" ) as  f :
9899            f .write (config_file .read_text ())
99100        config  =  AutoConfig .from_pretrained (tmp_config_file )
100101        if  generation_config  is  not None :
101-             tmp_generation_config_file  =  os .path .join (tmp_generation_config_file , "generation_config.json" )
102+             tmp_generation_config_file  =  os .path .join (tmp_dir , "generation_config.json" )
102103            with  open (tmp_generation_config_file , "w" ) as  f :
103104                f .write (generation_config .read_text ())
104-             generation_config  =  GenerationConfig .from_pretrained (tmp_config_file )
105+             generation_config  =  GenerationConfig .from_pretrained (tmp_generation_config_file )
105106        state_dict  =  {}
106107        with  contextlib .ExitStack () as  stack :
107108            for  entry  in  tqdm (weight_files , desc = "Loading state_dict" ):  # Loop over safetensors files 
@@ -112,5 +113,9 @@ def load_transformers_model_from_dduf(
112113                # Update the state dictionary with tensors 
113114                state_dict .update (tensors )
114115            return  cls .from_pretrained (
115-                 pretrained_model_name_or_path = None , config = config , generation_config = generation_config , state_dict = state_dict , ** kwargs 
116-                 )
116+                 pretrained_model_name_or_path = None ,
117+                 config = config ,
118+                 generation_config = generation_config ,
119+                 state_dict = state_dict ,
120+                 ** kwargs ,
121+             )
0 commit comments