Skip to content

Commit 59929a5

Browse files
committed
add generation_config case
1 parent 17d50d1 commit 59929a5

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

src/diffusers/pipelines/transformers_loading_utils.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def load_transformers_model_from_dduf(
7070
raise EnvironmentError(
7171
f"Could not find a config.json file for component {name} in DDUF file (contains {dduf_entries.keys()})."
7272
)
73+
generation_config = dduf_entries.get(f"{name}/generation_config.json", None)
7374

7475
weight_files = [
7576
entry
@@ -86,13 +87,16 @@ def load_transformers_model_from_dduf(
8687
)
8788

8889
with tempfile.TemporaryDirectory() as tmp_dir:
90+
from transformers import AutoConfig, GenerationConfig
8991
tmp_config_file = os.path.join(tmp_dir, "config.json")
9092
with open(tmp_config_file, "w") as f:
9193
f.write(config_file.read_text())
92-
# TODO: I feel like it is easier if we pass the config file directly. Otherwise, if we pass
93-
# pretrained_model_name_or_path, we will need to do more checks in transformers.
94-
from transformers import AutoConfig
9594
config = AutoConfig.from_pretrained(tmp_config_file)
95+
if generation_config is not None:
96+
tmp_generation_config_file = os.path.join(tmp_generation_config_file, "generation_config.json")
97+
with open(tmp_generation_config_file, "w") as f:
98+
f.write(generation_config.read_text())
99+
generation_config = GenerationConfig.from_pretrained(tmp_config_file)
96100
state_dict = {}
97101
with contextlib.ExitStack() as stack:
98102
for entry in tqdm(weight_files, desc="Loading state_dict"): # Loop over safetensors files
@@ -103,5 +107,5 @@ def load_transformers_model_from_dduf(
103107
# Update the state dictionary with tensors
104108
state_dict.update(tensors)
105109
return cls.from_pretrained(
106-
pretrained_model_name_or_path=None, config=config, state_dict=state_dict, **kwargs
107-
)
110+
pretrained_model_name_or_path=None, config=config, generation_config=generation_config, state_dict=state_dict, **kwargs
111+
)

0 commit comments

Comments
 (0)