Skip to content

Commit 6521f59

Browse files
committed
make sure modularpipeline from_pretrained works without modular_model_index
1 parent ceeb3c1 commit 6521f59

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -792,7 +792,7 @@ def fn_recursive_get_trigger(blocks):
792792
trigger_values.update(t for t in block.block_trigger_inputs if t is not None)
793793

794794
# If block has sub_blocks, recursively check them
795-
if hasattr(block, "sub_blocks"):
795+
if block.sub_blocks:
796796
nested_triggers = fn_recursive_get_trigger(block.sub_blocks)
797797
trigger_values.update(nested_triggers)
798798

@@ -1077,7 +1077,7 @@ def fn_recursive_get_trigger(blocks):
10771077
trigger_values.update(t for t in block.block_trigger_inputs if t is not None)
10781078

10791079
# If block has sub_blocks, recursively check them
1080-
if hasattr(block, "sub_blocks"):
1080+
if block.sub_blocks:
10811081
nested_triggers = fn_recursive_get_trigger(block.sub_blocks)
10821082
trigger_values.update(nested_triggers)
10831083

@@ -1098,7 +1098,7 @@ def fn_recursive_traverse(block, block_name, active_triggers):
10981098

10991099
# sequential(include loopsequential) or PipelineBlock
11001100
if not hasattr(block, "block_trigger_inputs"):
1101-
if hasattr(block, "sub_blocks"):
1101+
if block.sub_blocks:
11021102
# sequential or LoopSequentialPipelineBlocks (keep traversing)
11031103
for sub_block_name, sub_block in block.sub_blocks.items():
11041104
blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers)
@@ -1128,7 +1128,7 @@ def fn_recursive_traverse(block, block_name, active_triggers):
11281128

11291129
if this_block is not None:
11301130
# sequential/auto (keep traversing)
1131-
if hasattr(this_block, "sub_blocks"):
1131+
if this_block.sub_blocks:
11321132
result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers))
11331133
else:
11341134
# PipelineBlock
@@ -1642,9 +1642,8 @@ def set_progress_bar_config(self, **kwargs):
16421642

16431643

16441644
# YiYi TODO:
1645-
# 1. move the modular_repo arg and the logic to fetch info from repo out of __init__ so that __init__ alwasy create an default modular_model_index config
1646-
# 2. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess)
1647-
# 3. do we need ConfigSpec? seems pretty unnecessrary for loader, can just add and kwargs to the loader
1645+
# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess)
1646+
# 2. do we need ConfigSpec? the are basically just key/val kwargs
16481647
# 4. add validator for methods where we accpet kwargs to be passed to from_pretrained()
16491648
class ModularPipeline(ConfigMixin, PushToHubMixin):
16501649
"""
@@ -1927,8 +1926,12 @@ def from_pretrained(
19271926
"revision": revision,
19281927
}
19291928

1930-
config_dict = cls.load_config(pretrained_model_name_or_path, **load_config_kwargs)
1931-
pipeline_class = _get_pipeline_class(cls, config=config_dict)
1929+
try:
1930+
config_dict = cls.load_config(pretrained_model_name_or_path, **load_config_kwargs)
1931+
pipeline_class = _get_pipeline_class(cls, config=config_dict)
1932+
except EnvironmentError:
1933+
pipeline_class = cls
1934+
pretrained_model_name_or_path = None
19321935

19331936
pipeline = pipeline_class(
19341937
blocks=blocks,

0 commit comments

Comments
 (0)