1919from collections import OrderedDict
2020from copy import deepcopy
2121from dataclasses import dataclass , field
22- from types import SimpleNamespace
2322from typing import Any , Dict , List , Optional , Tuple , Union
2423
2524import torch
@@ -343,7 +342,7 @@ def init_pipeline(
343342 pipeline_class = getattr (diffusers_module , pipeline_class_name )
344343
345344 modular_pipeline = pipeline_class (
346- blocks = deepcopy (self ),
345+ blocks = deepcopy (self ),
347346 pretrained_model_name_or_path = pretrained_model_name_or_path ,
348347 components_manager = components_manager ,
349348 collection = collection ,
@@ -1686,11 +1685,7 @@ def __init__(
16861685
16871686 for name , value in config_dict .items ():
16881687 # all the components in modular_model_index.json are from_pretrained components
1689- if (
1690- name in self ._component_specs
1691- and isinstance (value , (tuple , list ))
1692- and len (value ) == 3
1693- ):
1688+ if name in self ._component_specs and isinstance (value , (tuple , list )) and len (value ) == 3 :
16941689 library , class_name , component_spec_dict = value
16951690 component_spec = self ._dict_to_component_spec (name , component_spec_dict )
16961691 component_spec .default_creation_method = "from_pretrained"
@@ -1794,15 +1789,16 @@ def from_pretrained(
17941789 components_manager : Optional [ComponentsManager ] = None ,
17951790 collection : Optional [str ] = None ,
17961791 ** kwargs ,
1797- ):
1792+ ):
17981793 from ..pipelines .pipeline_loading_utils import _get_pipeline_class
1794+
17991795 try :
18001796 blocks = ModularPipelineBlocks .from_pretrained (
18011797 pretrained_model_name_or_path , trust_remote_code = trust_remote_code , ** kwargs
18021798 )
18031799 except EnvironmentError :
18041800 blocks = None
1805-
1801+
18061802 cache_dir = kwargs .pop ("cache_dir" , None )
18071803 force_download = kwargs .pop ("force_download" , False )
18081804 proxies = kwargs .pop ("proxies" , None )
@@ -1818,33 +1814,29 @@ def from_pretrained(
18181814 "local_files_only" : local_files_only ,
18191815 "revision" : revision ,
18201816 }
1821-
1822- config_dict = cls .load_config (pretrained_model_name_or_path , ** kwargs )
1817+
1818+ config_dict = cls .load_config (pretrained_model_name_or_path , ** load_config_kwargs )
18231819 pipeline_class = _get_pipeline_class (cls , config = config_dict )
18241820
18251821 pipeline = pipeline_class (
1826- blocks = blocks ,
1827- pretrained_model_name_or_path = pretrained_model_name_or_path ,
1828- components_manager = components_manager ,
1829- collection = collection ,
1830- ** kwargs
1822+ blocks = blocks ,
1823+ pretrained_model_name_or_path = pretrained_model_name_or_path ,
1824+ components_manager = components_manager ,
1825+ collection = collection ,
1826+ ** kwargs ,
18311827 )
18321828 return pipeline
18331829
18341830 # YiYi TODO:
18351831 # 1. should support save some components too! currently only modular_model_index.json is saved
18361832 # 2. maybe order the json file to make it more readable: configs first, then components
1837- def save_pretrained (
1838- self , save_directory : Union [str , os .PathLike ], push_to_hub : bool = False , ** kwargs
1839- ):
1840-
1833+ def save_pretrained (self , save_directory : Union [str , os .PathLike ], push_to_hub : bool = False , ** kwargs ):
18411834 self .save_config (save_directory = save_directory , push_to_hub = push_to_hub , ** kwargs )
18421835
18431836 @property
18441837 def doc (self ):
18451838 return self .blocks .doc
18461839
1847-
18481840 def register_components (self , ** kwargs ):
18491841 """
18501842 Register components with their corresponding specifications.
@@ -1868,7 +1860,8 @@ def register_components(self, **kwargs):
18681860
18691861 Notes:
18701862 - Components must be created from ComponentSpec (have _diffusers_load_id attribute)
1871- - When registering None for a component, it sets attribute to None but still syncs specs with the modular_model_index.json config
1863+ - When registering None for a component, it sets attribute to None but still syncs specs with the
1864+ modular_model_index.json config
18721865 """
18731866 for name , module in kwargs .items ():
18741867 # current component spec
@@ -1884,12 +1877,12 @@ def register_components(self, **kwargs):
18841877 # make sure the component is created from ComponentSpec
18851878 if module is not None and not hasattr (module , "_diffusers_load_id" ):
18861879 raise ValueError ("`ModularPipeline` only supports components created from `ComponentSpec`." )
1887-
1880+
18881881 if module is not None :
18891882 # actual library and class name of the module
18901883 library , class_name = _fetch_class_library_tuple (module ) # e.g. ("diffusers", "UNet2DConditionModel")
18911884 else :
1892- # if module is None, e.g. self.register_components(unet=None) during __init__
1885+ # if module is None, e.g. self.register_components(unet=None) during __init__
18931886 # we do not update the spec,
18941887 # but we still need to update the modular_model_index.json config based on component spec
18951888 library , class_name = None , None
@@ -1949,7 +1942,6 @@ def register_components(self, **kwargs):
19491942 if module is not None and module ._diffusers_load_id != "null" and self ._components_manager is not None :
19501943 self ._components_manager .add (name , module , self ._collection )
19511944
1952-
19531945 @property
19541946 def device (self ) -> torch .device :
19551947 r"""
@@ -2394,12 +2386,11 @@ def module_is_offloaded(module):
23942386 )
23952387 return self
23962388
2397-
23982389 @staticmethod
23992390 def _component_spec_to_dict (component_spec : ComponentSpec ) -> Any :
24002391 """
2401- Convert a ComponentSpec into a JSON‐serializable dict for saving in `modular_model_index.json`.
2402- If the default_creation_method is not from_pretrained, return None.
2392+ Convert a ComponentSpec into a JSON‐serializable dict for saving in `modular_model_index.json`. If the
2393+ default_creation_method is not from_pretrained, return None.
24032394
24042395 This dict contains:
24052396 - "type_hint": Tuple[str, str]
@@ -2423,30 +2414,19 @@ def _component_spec_to_dict(component_spec: ComponentSpec) -> Any:
24232414 Dict[str, Any]: A mapping suitable for JSON serialization.
24242415
24252416 Example:
2426- >>> from diffusers.pipelines.modular_pipeline_utils import ComponentSpec
2427- >>> from diffusers import UNet2DConditionModel
2428- >>> spec = ComponentSpec(
2429- ... name="unet",
2430- ... type_hint=UNet2DConditionModel,
2431- ... config=None,
2432- ... repo="path/to/repo",
2433- ... subfolder="subfolder",
2434- ... variant=None,
2435- ... revision=None,
2436- ... default_creation_method="from_pretrained",
2437- ... )
2438- >>> ModularPipeline._component_spec_to_dict(spec)
2439- {
2440- "type_hint": ("diffusers", "UNet2DConditionModel"),
2441- "repo": "path/to/repo",
2442- "subfolder": "subfolder",
2443- "variant": None,
2444- "revision": None,
2417+ >>> from diffusers.pipelines.modular_pipeline_utils import ComponentSpec >>> from diffusers import
2418+ UNet2DConditionModel >>> spec = ComponentSpec(
2419+ ... name="unet", ... type_hint=UNet2DConditionModel, ... config=None, ... repo="path/to/repo", ...
2420+ subfolder="subfolder", ... variant=None, ... revision=None, ...
2421+ default_creation_method="from_pretrained",
2422+ ... ) >>> ModularPipeline._component_spec_to_dict(spec) {
2423+ "type_hint": ("diffusers", "UNet2DConditionModel"), "repo": "path/to/repo", "subfolder": "subfolder",
2424+ "variant": None, "revision": None,
24452425 }
24462426 """
24472427 if component_spec .default_creation_method != "from_pretrained" :
24482428 return None
2449-
2429+
24502430 if component_spec .type_hint is not None :
24512431 lib_name , cls_name = _fetch_class_library_tuple (component_spec .type_hint )
24522432 else :
@@ -2466,8 +2446,7 @@ def _dict_to_component_spec(
24662446 """
24672447 Reconstruct a ComponentSpec from a loading specdict.
24682448
2469- This method converts a dictionary representation back into a ComponentSpec object.
2470- The dict should contain:
2449+ This method converts a dictionary representation back into a ComponentSpec object. The dict should contain:
24712450 - "type_hint": Tuple[str, str]
24722451 Library name and class name of the component. (e.g. ("diffusers", "UNet2DConditionModel"))
24732452 - All loading fields defined by `component_spec.loading_fields()`, typically:
@@ -2491,23 +2470,11 @@ def _dict_to_component_spec(
24912470 ComponentSpec: A reconstructed ComponentSpec object.
24922471
24932472 Example:
2494- >>> spec_dict = {
2495- ... "type_hint": ("diffusers", "UNet2DConditionModel"),
2496- ... "repo": "stabilityai/stable-diffusion-xl",
2497- ... "subfolder": "unet",
2498- ... "variant": None,
2499- ... "revision": None,
2500- ... }
2501- >>> ModularPipeline._dict_to_component_spec("unet", spec_dict)
2502- ComponentSpec(
2503- name="unet",
2504- type_hint=UNet2DConditionModel,
2505- config=None,
2506- repo="stabilityai/stable-diffusion-xl",
2507- subfolder="unet",
2508- variant=None,
2509- revision=None,
2510- default_creation_method="from_pretrained"
2473+ >>> spec_dict = { ... "type_hint": ("diffusers", "UNet2DConditionModel"), ... "repo":
2474+ "stabilityai/stable-diffusion-xl", ... "subfolder": "unet", ... "variant": None, ... "revision": None, ...
2475+ } >>> ModularPipeline._dict_to_component_spec("unet", spec_dict) ComponentSpec(
2476+ name="unet", type_hint=UNet2DConditionModel, config=None, repo="stabilityai/stable-diffusion-xl",
2477+ subfolder="unet", variant=None, revision=None, default_creation_method="from_pretrained"
25112478 )
25122479 """
25132480 # make a shallow copy so we can pop() safely
@@ -2524,4 +2491,4 @@ def _dict_to_component_spec(
25242491 name = name ,
25252492 type_hint = type_hint ,
25262493 ** spec_dict ,
2527- )
2494+ )
0 commit comments