Skip to content

Commit 58dfcf9

Browse files
committed
update
1 parent d032408 commit 58dfcf9

File tree

1 file changed

+18
-14
lines changed

1 file changed

+18
-14
lines changed

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1409,7 +1409,7 @@ def set_progress_bar_config(self, **kwargs):
14091409
# YiYi TODO:
14101410
# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess)
14111411
# 2. do we need ConfigSpec? the are basically just key/val kwargs
1412-
# 3. imnprove docstring and potentially add validator for methods where we accpet kwargs to be passed to from_pretrained/save_pretrained/load_default_components(), load_components()
1412+
# 3. imnprove docstring and potentially add validator for methods where we accpet kwargs to be passed to from_pretrained/save_pretrained/load_components()
14131413
class ModularPipeline(ConfigMixin, PushToHubMixin):
14141414
"""
14151415
Base class for all Modular pipelines.
@@ -1478,7 +1478,7 @@ def __init__(
14781478
- Components with default_creation_method="from_config" are created immediately, its specs are not included
14791479
in config dict and will not be saved in `modular_model_index.json`
14801480
- Components with default_creation_method="from_pretrained" are set to None and can be loaded later with
1481-
`load_default_components()`/`load_components()`
1481+
`load_components()` (with or without specific component names)
14821482
- The pipeline's config dict is populated with component specs (only for from_pretrained components) and
14831483
config values, which will be saved as `modular_model_index.json` during `save_pretrained`
14841484
- The pipeline's config dict is also used to store the pipeline blocks's class name, which will be saved as
@@ -1548,12 +1548,9 @@ def load_default_components(self, **kwargs):
15481548
Args:
15491549
**kwargs: Additional arguments passed to `from_pretrained` method, e.g. torch_dtype, cache_dir, etc.
15501550
"""
1551-
names = [
1552-
name
1553-
for name in self._component_specs.keys()
1554-
if self._component_specs[name].default_creation_method == "from_pretrained"
1555-
]
1556-
self.load_components(names=names, **kwargs)
1551+
# Consolidated into load_components - just call it without names parameter
1552+
logger.warning("`load_default_components` is deprecated. Please use `load_components()` instead")
1553+
self.load_components(**kwargs)
15571554

15581555
@classmethod
15591556
@validate_hf_hub_args
@@ -1682,8 +1679,8 @@ def register_components(self, **kwargs):
16821679
- non from_pretrained components are created during __init__ and registered as the object itself
16831680
- Components are updated with the `update_components()` method: e.g. loader.update_components(unet=unet) or
16841681
loader.update_components(guider=guider_spec)
1685-
- (from_pretrained) Components are loaded with the `load_default_components()` method: e.g.
1686-
loader.load_default_components(names=["unet"])
1682+
- (from_pretrained) Components are loaded with the `load_components()` method: e.g.
1683+
loader.load_components(names=["unet"]) or loader.load_components() to load all default components
16871684
16881685
Args:
16891686
**kwargs: Keyword arguments where keys are component names and values are component objects.
@@ -1995,21 +1992,28 @@ def update_components(self, **kwargs):
19951992
self.register_to_config(**config_to_register)
19961993

19971994
# YiYi TODO: support map for additional from_pretrained kwargs
1998-
# YiYi/Dhruv TODO: consolidate load_components and load_default_components?
1999-
def load_components(self, names: Union[List[str], str], **kwargs):
1995+
def load_components(self, names: Optional[Union[List[str], str]] = None, **kwargs):
20001996
"""
20011997
Load selected components from specs.
20021998
20031999
Args:
2004-
names: List of component names to load; by default will not load any components
2000+
names: List of component names to load. If None, will load all components with
2001+
default_creation_method == "from_pretrained". If provided as a list or string,
2002+
will load only the specified components.
20052003
**kwargs: additional kwargs to be passed to `from_pretrained()`.Can be:
20062004
- a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16
20072005
- a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32}
20082006
- if potentially override ComponentSpec if passed a different loading field in kwargs, e.g. `repo`,
20092007
`variant`, `revision`, etc.
20102008
"""
20112009

2012-
if isinstance(names, str):
2010+
if names is None:
2011+
names = [
2012+
name
2013+
for name in self._component_specs.keys()
2014+
if self._component_specs[name].default_creation_method == "from_pretrained"
2015+
]
2016+
elif isinstance(names, str):
20132017
names = [names]
20142018
elif not isinstance(names, list):
20152019
raise ValueError(f"Invalid type for names: {type(names)}")

0 commit comments

Comments
 (0)