@@ -70,6 +70,7 @@ def load_transformers_model_from_dduf(
7070        raise  EnvironmentError (
7171            f"Could not find a config.json file for component { name } { dduf_entries .keys ()}  
7272        )
73+     generation_config  =  dduf_entries .get (f"{ name }  , None )
7374
7475    weight_files  =  [
7576        entry 
@@ -86,13 +87,16 @@ def load_transformers_model_from_dduf(
8687        )
8788
8889    with  tempfile .TemporaryDirectory () as  tmp_dir :
90+         from  transformers  import  AutoConfig , GenerationConfig 
8991        tmp_config_file  =  os .path .join (tmp_dir , "config.json" )
9092        with  open (tmp_config_file , "w" ) as  f :
9193            f .write (config_file .read_text ())
92-         # TODO: I feel like it is easier if we pass the config file directly. Otherwise, if we pass  
93-         # pretrained_model_name_or_path, we will need to do more checks in transformers.  
94-         from  transformers  import  AutoConfig 
9594        config  =  AutoConfig .from_pretrained (tmp_config_file )
95+         if  generation_config  is  not None :
96+             tmp_generation_config_file  =  os .path .join (tmp_generation_config_file , "generation_config.json" )
97+             with  open (tmp_generation_config_file , "w" ) as  f :
98+                 f .write (generation_config .read_text ())
99+             generation_config  =  GenerationConfig .from_pretrained (tmp_config_file )
96100        state_dict  =  {}
97101        with  contextlib .ExitStack () as  stack :
98102            for  entry  in  tqdm (weight_files , desc = "Loading state_dict" ):  # Loop over safetensors files 
@@ -103,5 +107,5 @@ def load_transformers_model_from_dduf(
103107                # Update the state dictionary with tensors 
104108                state_dict .update (tensors )
105109            return  cls .from_pretrained (
106-                 pretrained_model_name_or_path = None , config = config , state_dict = state_dict , ** kwargs 
107-             )
110+                 pretrained_model_name_or_path = None , config = config , generation_config = generation_config ,  state_dict = state_dict , ** kwargs 
111+                  )
0 commit comments