Skip to content

Commit a1c72f6

Browse files
committed
up
1 parent b79dcb8 commit a1c72f6

File tree

1 file changed

+73
-31
lines changed

1 file changed

+73
-31
lines changed

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 73 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1825,9 +1825,10 @@ def __init__(
18251825
Args:
18261826
blocks: `ModularPipelineBlocks` instance. If None, will attempt to load
18271827
default blocks based on the pipeline class name.
1828-
pretrained_model_name_or_path: Path to a pretrained pipeline configuration. If provided,
1829-
will load component specs (only for from_pretrained components) and config values from the saved
1830-
modular_model_index.json file.
1828+
pretrained_model_name_or_path: Path to a pretrained pipeline configuration. Can be None if the pipeline
1829+
does not require any additional loading config. If provided, will first try to load component specs
1830+
(only for from_pretrained components) and config values from `modular_model_index.json`, then
1831+
fallback to `model_index.json` for compatibility with standard non-modular repositories.
18311832
components_manager:
18321833
Optional ComponentsManager for managing multiple component cross different pipelines and apply
18331834
offloading strategies.
@@ -1876,12 +1877,29 @@ def __init__(
18761877

18771878
# update component_specs and config_specs from modular_repo
18781879
if pretrained_model_name_or_path is not None:
1880+
cache_dir = kwargs.pop("cache_dir", None)
1881+
force_download = kwargs.pop("force_download", False)
1882+
proxies = kwargs.pop("proxies", None)
1883+
token = kwargs.pop("token", None)
1884+
local_files_only = kwargs.pop("local_files_only", False)
1885+
revision = kwargs.pop("revision", None)
1886+
1887+
load_config_kwargs = {
1888+
"cache_dir": cache_dir,
1889+
"force_download": force_download,
1890+
"proxies": proxies,
1891+
"token": token,
1892+
"local_files_only": local_files_only,
1893+
"revision": revision,
1894+
}
1895+
# try to load modular_model_index.json
18791896
try:
1880-
config_dict = self.load_config(pretrained_model_name_or_path, **kwargs)
1897+
config_dict = self.load_config(pretrained_model_name_or_path, **load_config_kwargs)
18811898
except EnvironmentError as e:
18821899
logger.debug(f"modular_model_index.json not found: {e}")
18831900
config_dict = None
18841901

1902+
# update component_specs and config_specs based on modular_model_index.json
18851903
if config_dict is not None:
18861904
for name, value in config_dict.items():
18871905
# all the components in modular_model_index.json are from_pretrained components
@@ -1894,24 +1912,35 @@ def __init__(
18941912
elif name in self._config_specs:
18951913
self._config_specs[name].default = value
18961914

1915+
# if modular_model_index.json is not found, try to load model_index.json
18971916
else:
18981917
logger.debug(" loading config from model_index.json")
1899-
from diffusers import DiffusionPipeline
1918+
try:
1919+
from diffusers import DiffusionPipeline
1920+
1921+
config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path, **load_config_kwargs)
1922+
except EnvironmentError as e:
1923+
logger.debug(f" model_index.json not found in the repo: {e}")
1924+
config_dict = None
1925+
1926+
# update component_specs and config_specs based on model_index.json
1927+
if config_dict is not None:
1928+
for name, value in config_dict.items():
1929+
if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 2:
1930+
library, class_name = value
1931+
component_spec_dict = {
1932+
"repo": pretrained_model_name_or_path,
1933+
"subfolder": name,
1934+
"type_hint": (library, class_name),
1935+
}
1936+
component_spec = self._dict_to_component_spec(name, component_spec_dict)
1937+
component_spec.default_creation_method = "from_pretrained"
1938+
self._component_specs[name] = component_spec
1939+
elif name in self._config_specs:
1940+
self._config_specs[name].default = value
19001941

1901-
config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path, **kwargs)
1902-
for name, value in config_dict.items():
1903-
if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 2:
1904-
library, class_name = value
1905-
component_spec_dict = {
1906-
"repo": pretrained_model_name_or_path,
1907-
"subfolder": name,
1908-
"type_hint": (library, class_name),
1909-
}
1910-
component_spec = self._dict_to_component_spec(name, component_spec_dict)
1911-
component_spec.default_creation_method = "from_pretrained"
1912-
self._component_specs[name] = component_spec
1913-
elif name in self._config_specs:
1914-
self._config_specs[name].default = value
1942+
if len(kwargs) > 0:
1943+
logger.warning(f"Unexpected input '{kwargs.keys()}' provided. This input will be ignored.")
19151944

19161945
register_components_dict = {}
19171946
for name, component_spec in self._component_specs.items():
@@ -2060,8 +2089,10 @@ def from_pretrained(
20602089
20612090
Args:
20622091
pretrained_model_name_or_path (`str` or `os.PathLike`, optional):
2063-
Path to a pretrained pipeline configuration. If provided, will load component specs (only for
2064-
from_pretrained components) and config values from the modular_model_index.json file.
2092+
Path to a pretrained pipeline configuration. It will first try to load config from
2093+
`modular_model_index.json`, then fallback to `model_index.json` for compatibility with standard
2094+
non-modular repositories. If the repo does not contain any pipeline config, it will be set to None
2095+
during initialization.
20652096
trust_remote_code (`bool`, optional):
20662097
Whether to trust remote code when loading the pipeline, need to be set to True if you want to create
20672098
pipeline blocks based on the custom code in `pretrained_model_name_or_path`
@@ -2097,6 +2128,7 @@ def from_pretrained(
20972128
}
20982129

20992130
try:
2131+
# try to load modular_model_index.json
21002132
config_dict = cls.load_config(pretrained_model_name_or_path, **load_config_kwargs)
21012133
except EnvironmentError as e:
21022134
logger.debug(f" modular_model_index.json not found in the repo: {e}")
@@ -2105,16 +2137,26 @@ def from_pretrained(
21052137
if config_dict is not None:
21062138
pipeline_class = _get_pipeline_class(cls, config=config_dict)
21072139
else:
2108-
logger.debug(" determining the modular pipeline class from model_index.json")
2109-
from diffusers import DiffusionPipeline
2110-
from diffusers.pipelines.auto_pipeline import _get_model
2111-
2112-
config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path, **load_config_kwargs)
2113-
standard_pipeline_class = _get_pipeline_class(cls, config=config_dict)
2114-
model_name = _get_model(standard_pipeline_class.__name__)
2115-
pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(model_name, ModularPipeline.__name__)
2116-
diffusers_module = importlib.import_module("diffusers")
2117-
pipeline_class = getattr(diffusers_module, pipeline_class_name)
2140+
try:
2141+
logger.debug(" try to load model_index.json")
2142+
from diffusers import DiffusionPipeline
2143+
from diffusers.pipelines.auto_pipeline import _get_model
2144+
2145+
config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path, **load_config_kwargs)
2146+
except EnvironmentError as e:
2147+
logger.debug(f" model_index.json not found in the repo: {e}")
2148+
2149+
if config_dict is not None:
2150+
logger.debug(" try to determine the modular pipeline class from model_index.json")
2151+
standard_pipeline_class = _get_pipeline_class(cls, config=config_dict)
2152+
model_name = _get_model(standard_pipeline_class.__name__)
2153+
pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(model_name, ModularPipeline.__name__)
2154+
diffusers_module = importlib.import_module("diffusers")
2155+
pipeline_class = getattr(diffusers_module, pipeline_class_name)
2156+
else:
2157+
# there is no config for modular pipeline, assuming that the pipeline block does not need any from_pretrained components
2158+
pipeline_class = cls
2159+
pretrained_model_name_or_path = None
21182160

21192161
pipeline = pipeline_class(
21202162
blocks=blocks,

0 commit comments

Comments
 (0)