Skip to content

Commit 0fcdd69

Browse files
committed
style
1 parent 5af003a commit 0fcdd69

File tree

4 files changed

+89
-59
lines changed

4 files changed

+89
-59
lines changed

src/diffusers/guiders/guider_utils.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,17 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union, Optional
15+
import os
16+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
1617

1718
import torch
1819
from huggingface_hub.utils import validate_hf_hub_args
1920
from typing_extensions import Self
2021

21-
import os
22-
2322
from ..configuration_utils import ConfigMixin
2423
from ..utils import PushToHubMixin, get_logger
2524

2625

27-
2826
if TYPE_CHECKING:
2927
from ..modular_pipelines.modular_pipeline import BlockState
3028

@@ -221,8 +219,8 @@ def from_pretrained(
221219
222220
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
223221
the Hub.
224-
- A path to a *directory* (for example `./my_model_directory`) containing the guider
225-
configuration saved with [`~BaseGuidance.save_pretrained`].
222+
- A path to a *directory* (for example `./my_model_directory`) containing the guider configuration
223+
saved with [`~BaseGuidance.save_pretrained`].
226224
subfolder (`str`, *optional*):
227225
The subfolder location of a model file within a larger model repository on the Hub or locally.
228226
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
@@ -285,6 +283,7 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub:
285283
"""
286284
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
287285

286+
288287
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
289288
r"""
290289
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 67 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from typing import Any, Dict, List, Optional, Tuple, Union
2323

2424
import torch
25+
from huggingface_hub import create_repo
2526
from huggingface_hub.utils import validate_hf_hub_args
2627
from tqdm.auto import tqdm
2728
from typing_extensions import Self
@@ -34,6 +35,7 @@
3435
logging,
3536
)
3637
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
38+
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
3739
from .components_manager import ComponentsManager
3840
from .modular_pipeline_utils import (
3941
ComponentSpec,
@@ -47,8 +49,7 @@
4749
format_intermediates_short,
4850
make_doc_string,
4951
)
50-
from huggingface_hub import create_repo
51-
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
52+
5253

5354
if is_accelerate_available():
5455
import accelerate
@@ -1670,16 +1671,21 @@ def __init__(
16701671
16711672
This method sets up the pipeline by:
16721673
1. creating default pipeline blocks if not provided
1673-
2. gather component and config specifications based on the pipeline blocks's requirement (e.g. expected_components, expected_configs)
1674-
3. update the loading specs of from_pretrained components based on the modular_model_index.json file from huggingface hub if `pretrained_model_name_or_path` is provided
1674+
2. gather component and config specifications based on the pipeline blocks's requirement (e.g.
1675+
expected_components, expected_configs)
1676+
3. update the loading specs of from_pretrained components based on the modular_model_index.json file from
1677+
huggingface hub if `pretrained_model_name_or_path` is provided
16751678
4. create defaultfrom_config components and register everything
16761679
16771680
Args:
16781681
blocks: `ModularPipelineBlocks` instance. If None, will attempt to load
16791682
default blocks based on the pipeline class name.
16801683
pretrained_model_name_or_path: Path to a pretrained pipeline configuration. If provided,
1681-
will load component specs (only for from_pretrained components) and config values from the saved modular_model_index.json file.
1682-
components_manager: Optional ComponentsManager for managing multiple component cross different pipelines and apply offloading strategies.
1684+
will load component specs (only for from_pretrained components) and config values from the saved
1685+
modular_model_index.json file.
1686+
components_manager:
1687+
Optional ComponentsManager for managing multiple component cross different pipelines and apply
1688+
offloading strategies.
16831689
collection: Optional collection name for organizing components in the ComponentsManager.
16841690
**kwargs: Additional arguments passed to `load_config()` when loading pretrained configuration.
16851691
@@ -1693,18 +1699,20 @@ def __init__(
16931699
16941700
# Initialize with components manager
16951701
pipeline = ModularPipeline(
1696-
blocks=my_blocks,
1697-
components_manager=ComponentsManager(),
1698-
collection="my_collection"
1702+
blocks=my_blocks, components_manager=ComponentsManager(), collection="my_collection"
16991703
)
17001704
```
17011705
17021706
Notes:
17031707
- If blocks is None, the method will try to find default blocks based on the pipeline class name
1704-
- Components with default_creation_method="from_config" are created immediately, its specs are not included in config dict and will not be saved in `modular_model_index.json`
1705-
- Components with default_creation_method="from_pretrained" are set to None and can be loaded later with `load_default_components()`/`load_components()`
1706-
- The pipeline's config dict is populated with component specs (only for from_pretrained components) and config values, which will be saved as `modular_model_index.json` during `save_pretrained`
1707-
- The pipeline's config dict is also used to store the pipeline blocks's class name, which will be saved as `_blocks_class_name` in the config dict
1708+
- Components with default_creation_method="from_config" are created immediately, its specs are not included
1709+
in config dict and will not be saved in `modular_model_index.json`
1710+
- Components with default_creation_method="from_pretrained" are set to None and can be loaded later with
1711+
`load_default_components()`/`load_components()`
1712+
- The pipeline's config dict is populated with component specs (only for from_pretrained components) and
1713+
config values, which will be saved as `modular_model_index.json` during `save_pretrained`
1714+
- The pipeline's config dict is also used to store the pipeline blocks's class name, which will be saved as
1715+
`_blocks_class_name` in the config dict
17081716
"""
17091717
if blocks is None:
17101718
blocks_class_name = MODULAR_PIPELINE_BLOCKS_MAPPING.get(self.__class__.__name__)
@@ -1769,12 +1777,14 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] =
17691777
17701778
Args:
17711779
state (`PipelineState`, optional):
1772-
PipelineState instance contains inputs and intermediate values. If None, a new `PipelineState` will be created based on the user inputs and the pipeline blocks's requirement.
1780+
PipelineState instance contains inputs and intermediate values. If None, a new `PipelineState` will be
1781+
created based on the user inputs and the pipeline blocks's requirement.
17731782
output (`str` or `List[str]`, optional):
17741783
Optional specification of what to return:
17751784
- None: Returns the complete `PipelineState` with all inputs and intermediates (default)
17761785
- str: Returns a specific intermediate value from the state (e.g. `output="image"`)
1777-
- List[str]: Returns a dictionary of specific intermediate values (e.g. `output=["image", "latents"]`)
1786+
- List[str]: Returns a dictionary of specific intermediate values (e.g. `output=["image",
1787+
"latents"]`)
17781788
17791789
17801790
Examples:
@@ -1794,11 +1804,12 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] =
17941804
state = pipeline(prompt="A beautiful sunset")
17951805
new_state = pipeline(state=state, output="image") # Continue processing
17961806
```
1797-
1807+
17981808
Returns:
17991809
- If `output` is None: Complete `PipelineState` containing all inputs and intermediates
18001810
- If `output` is str: The specific intermediate value from the state (e.g. `output="image"`)
1801-
- If `output` is List[str]: Dictionary mapping output names to their values from the state (e.g. `output=["image", "latents"]`)
1811+
- If `output` is List[str]: Dictionary mapping output names to their values from the state (e.g.
1812+
`output=["image", "latents"]`)
18021813
"""
18031814
if state is None:
18041815
state = PipelineState()
@@ -1880,11 +1891,14 @@ def from_pretrained(
18801891
18811892
Args:
18821893
pretrained_model_name_or_path (`str` or `os.PathLike`, optional):
1883-
Path to a pretrained pipeline configuration. If provided, will load component specs (only for from_pretrained components) and config values from the modular_model_index.json file.
1894+
Path to a pretrained pipeline configuration. If provided, will load component specs (only for
1895+
from_pretrained components) and config values from the modular_model_index.json file.
18841896
trust_remote_code (`bool`, optional):
1885-
Whether to trust remote code when loading the pipeline, need to be set to True if you want to create pipeline blocks based on the custom code in `pretrained_model_name_or_path`
1897+
Whether to trust remote code when loading the pipeline, need to be set to True if you want to create
1898+
pipeline blocks based on the custom code in `pretrained_model_name_or_path`
18861899
components_manager (`ComponentsManager`, optional):
1887-
ComponentsManager instance for managing multiple component cross different pipelines and apply offloading strategies.
1900+
ComponentsManager instance for managing multiple component cross different pipelines and apply
1901+
offloading strategies.
18881902
collection (`str`, optional):`
18891903
Collection name for organizing components in the ComponentsManager.
18901904
"""
@@ -1935,8 +1949,6 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub:
19351949
push_to_hub (`bool`, optional):
19361950
Whether to push the pipeline to the huggingface hub.
19371951
**kwargs: Additional arguments passed to `save_config()` method
1938-
1939-
19401952
"""
19411953
if push_to_hub:
19421954
commit_message = kwargs.pop("commit_message", None)
@@ -1945,12 +1957,12 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub:
19451957
token = kwargs.pop("token", None)
19461958
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
19471959
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
1948-
1960+
19491961
# Create a new empty model card and eventually tag it
19501962
model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True)
19511963
model_card = populate_model_card(model_card)
19521964
model_card.save(os.path.join(save_directory, "README.md"))
1953-
1965+
19541966
# YiYi TODO: maybe order the json file to make it more readable: configs first, then components
19551967
self.save_config(save_directory=save_directory)
19561968

@@ -1977,7 +1989,8 @@ def register_components(self, **kwargs):
19771989
19781990
This method is responsible for:
19791991
1. Sets component objects as attributes on the loader (e.g., self.unet = unet)
1980-
2. Updates the config dict, which will be saved as `modular_model_index.json` during `save_pretrained` (only for from_pretrained components)
1992+
2. Updates the config dict, which will be saved as `modular_model_index.json` during `save_pretrained` (only
1993+
for from_pretrained components)
19811994
3. Adds components to the component manager if one is attached (only for from_pretrained components)
19821995
19831996
This method is called when:
@@ -1986,15 +1999,18 @@ def register_components(self, **kwargs):
19861999
- non from_pretrained components are created during __init__ and registered as the object itself
19872000
- Components are updated with the `update_components()` method: e.g. loader.update_components(unet=unet) or
19882001
loader.update_components(guider=guider_spec)
1989-
- (from_pretrained) Components are loaded with the `load_default_components()` method: e.g. loader.load_default_components(names=["unet"])
2002+
- (from_pretrained) Components are loaded with the `load_default_components()` method: e.g.
2003+
loader.load_default_components(names=["unet"])
19902004
19912005
Args:
19922006
**kwargs: Keyword arguments where keys are component names and values are component objects.
19932007
E.g., register_components(unet=unet_model, text_encoder=encoder_model)
19942008
19952009
Notes:
1996-
- When registering None for a component, it sets attribute to None but still syncs specs with the config dict, which will be saved as `modular_model_index.json` during `save_pretrained`
1997-
- component_specs are updated to match the new component outside of this method, e.g. in `update_components()` method
2010+
- When registering None for a component, it sets attribute to None but still syncs specs with the config
2011+
dict, which will be saved as `modular_model_index.json` during `save_pretrained`
2012+
- component_specs are updated to match the new component outside of this method, e.g. in
2013+
`update_components()` method
19982014
"""
19992015
for name, module in kwargs.items():
20002016
# current component spec
@@ -2166,7 +2182,8 @@ def config_component_names(self) -> List[str]:
21662182
def components(self) -> Dict[str, Any]:
21672183
"""
21682184
Returns:
2169-
- Dictionary mapping component names to their objects (include both from_pretrained and from_config components)
2185+
- Dictionary mapping component names to their objects (include both from_pretrained and from_config
2186+
components)
21702187
"""
21712188
# return only components we've actually set as attributes on self
21722189
return {name: getattr(self, name) for name in self._component_specs.keys() if hasattr(self, name)}
@@ -2186,19 +2203,21 @@ def update_components(self, **kwargs):
21862203
1. Replace existing components with new ones (e.g., updating `self.unet` or `self.text_encoder`)
21872204
2. Update configuration values (e.g., changing `self.requires_safety_checker` flag)
21882205
2189-
In addition to updating the components and configuration values as pipeline attributes, the method also updates:
2206+
In addition to updating the components and configuration values as pipeline attributes, the method also
2207+
updates:
21902208
- the corresponding specs in `_component_specs` and `_config_specs`
21912209
- the `config` dict, which will be saved as `modular_model_index.json` during `save_pretrained`
21922210
21932211
Args:
21942212
**kwargs: Component objects, ComponentSpec objects, or configuration values to update:
2195-
- Component objects: Only supports components we can extract specs using `ComponentSpec.from_component()` method
2196-
i.e. components created with ComponentSpec.load() or ConfigMixin subclasses that aren't nn.Modules
2197-
(e.g., `unet=new_unet, text_encoder=new_encoder`)
2198-
- ComponentSpec objects: Only supports default_creation_method == "from_config", will call create() method to create a new component
2199-
(e.g., `guider=ComponentSpec(name="guider", type_hint=ClassifierFreeGuidance, config={...}, default_creation_method="from_config")`)
2200-
- Configuration values: Simple values to update configuration settings
2201-
(e.g., `requires_safety_checker=False`)
2213+
- Component objects: Only supports components we can extract specs using
2214+
`ComponentSpec.from_component()` method i.e. components created with ComponentSpec.load() or
2215+
ConfigMixin subclasses that aren't nn.Modules (e.g., `unet=new_unet, text_encoder=new_encoder`)
2216+
- ComponentSpec objects: Only supports default_creation_method == "from_config", will call create()
2217+
method to create a new component (e.g., `guider=ComponentSpec(name="guider",
2218+
type_hint=ClassifierFreeGuidance, config={...}, default_creation_method="from_config")`)
2219+
- Configuration values: Simple values to update configuration settings (e.g.,
2220+
`requires_safety_checker=False`)
22022221
22032222
Raises:
22042223
ValueError: If a component object is not supported in ComponentSpec.from_component() method:
@@ -2228,9 +2247,11 @@ def update_components(self, **kwargs):
22282247
```
22292248
22302249
Notes:
2231-
- Components with trained weights must be created using ComponentSpec.load(). If the component has not been shared in huggingface hub and you don't have loading specs, you can upload it using `push_to_hub()`
2250+
- Components with trained weights must be created using ComponentSpec.load(). If the component has not been
2251+
shared in huggingface hub and you don't have loading specs, you can upload it using `push_to_hub()`
22322252
- ConfigMixin objects without weights (e.g., schedulers, guiders) can be passed directly
2233-
- ComponentSpec objects with default_creation_method="from_pretrained" are not supported in update_components()
2253+
- ComponentSpec objects with default_creation_method="from_pretrained" are not supported in
2254+
update_components()
22342255
"""
22352256

22362257
# extract component_specs_updates & config_specs_updates from `specs`
@@ -2244,7 +2265,7 @@ def update_components(self, **kwargs):
22442265

22452266
for name, component in passed_components.items():
22462267
current_component_spec = self._component_specs[name]
2247-
2268+
22482269
# warn if type changed
22492270
if current_component_spec.type_hint is not None and not isinstance(
22502271
component, current_component_spec.type_hint
@@ -2255,10 +2276,11 @@ def update_components(self, **kwargs):
22552276
# update _component_specs based on the new component
22562277
new_component_spec = ComponentSpec.from_component(name, component)
22572278
if new_component_spec.default_creation_method != current_component_spec.default_creation_method:
2258-
logger.warning(f"ModularPipeline.update_components: changing the default_creation_method of {name} from {current_component_spec.default_creation_method} to {new_component_spec.default_creation_method}.")
2259-
2260-
self._component_specs[name] = new_component_spec
2279+
logger.warning(
2280+
f"ModularPipeline.update_components: changing the default_creation_method of {name} from {current_component_spec.default_creation_method} to {new_component_spec.default_creation_method}."
2281+
)
22612282

2283+
self._component_specs[name] = new_component_spec
22622284

22632285
if len(kwargs) > 0:
22642286
logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}")
@@ -2551,8 +2573,8 @@ def module_is_offloaded(module):
25512573
@staticmethod
25522574
def _component_spec_to_dict(component_spec: ComponentSpec) -> Any:
25532575
"""
2554-
Convert a ComponentSpec into a JSON‐serializable dict for saving as an entry in `modular_model_index.json`.
2555-
If the `default_creation_method` is not `from_pretrained`, return None.
2576+
Convert a ComponentSpec into a JSON‐serializable dict for saving as an entry in `modular_model_index.json`. If
2577+
the `default_creation_method` is not `from_pretrained`, return None.
25562578
25572579
This dict contains:
25582580
- "type_hint": Tuple[str, str]

0 commit comments

Comments
 (0)