Skip to content

Commit b79dcb8

Browse files
committed
style
1 parent 2c45daf commit b79dcb8

File tree

1 file changed

+24
-7
lines changed

1 file changed

+24
-7
lines changed

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1878,6 +1878,11 @@ def __init__(
18781878
if pretrained_model_name_or_path is not None:
18791879
try:
18801880
config_dict = self.load_config(pretrained_model_name_or_path, **kwargs)
1881+
except EnvironmentError as e:
1882+
logger.debug(f"modular_model_index.json not found: {e}")
1883+
config_dict = None
1884+
1885+
if config_dict is not None:
18811886
for name, value in config_dict.items():
18821887
# all the components in modular_model_index.json are from_pretrained components
18831888
if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 3:
@@ -1889,12 +1894,11 @@ def __init__(
18891894
elif name in self._config_specs:
18901895
self._config_specs[name].default = value
18911896

1892-
except EnvironmentError as e:
1893-
logger.debug(e)
1894-
logger.debug(" modular_model_index.json not found in the repo, trying to load from model_index.json")
1897+
else:
1898+
logger.debug(" loading config from model_index.json")
18951899
from diffusers import DiffusionPipeline
18961900

1897-
config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path)
1901+
config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path, **kwargs)
18981902
for name, value in config_dict.items():
18991903
if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 2:
19001904
library, class_name = value
@@ -2094,10 +2098,23 @@ def from_pretrained(
20942098

20952099
try:
20962100
config_dict = cls.load_config(pretrained_model_name_or_path, **load_config_kwargs)
2101+
except EnvironmentError as e:
2102+
logger.debug(f" modular_model_index.json not found in the repo: {e}")
2103+
config_dict = None
2104+
2105+
if config_dict is not None:
20972106
pipeline_class = _get_pipeline_class(cls, config=config_dict)
2098-
except EnvironmentError:
2099-
pipeline_class = cls
2100-
pretrained_model_name_or_path = None
2107+
else:
2108+
logger.debug(" determining the modular pipeline class from model_index.json")
2109+
from diffusers import DiffusionPipeline
2110+
from diffusers.pipelines.auto_pipeline import _get_model
2111+
2112+
config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path, **load_config_kwargs)
2113+
standard_pipeline_class = _get_pipeline_class(cls, config=config_dict)
2114+
model_name = _get_model(standard_pipeline_class.__name__)
2115+
pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(model_name, ModularPipeline.__name__)
2116+
diffusers_module = importlib.import_module("diffusers")
2117+
pipeline_class = getattr(diffusers_module, pipeline_class_name)
21012118

21022119
pipeline = pipeline_class(
21032120
blocks=blocks,

0 commit comments

Comments
 (0)