diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index fa3e88d999b5..7a3de0b95747 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -39,6 +39,7 @@ "loaders": ["FromOriginalModelMixin"], "models": [], "pipelines": [], + "modular_pipelines": [], "quantizers.quantization_config": [], "schedulers": [], "utils": [ @@ -254,13 +255,19 @@ "KarrasVePipeline", "LDMPipeline", "LDMSuperResolutionPipeline", - "ModularLoader", "PNDMPipeline", "RePaintPipeline", "ScoreSdeVePipeline", "StableDiffusionMixin", ] ) + _import_structure["modular_pipelines"].extend( + [ + "ModularLoader", + "ComponentSpec", + "ComponentsManager", + ] + ) _import_structure["quantizers"] = ["DiffusersQuantizer"] _import_structure["schedulers"].extend( [ @@ -509,12 +516,10 @@ "StableDiffusionXLImg2ImgPipeline", "StableDiffusionXLInpaintPipeline", "StableDiffusionXLInstructPix2PixPipeline", - "StableDiffusionXLModularLoader", "StableDiffusionXLPAGImg2ImgPipeline", "StableDiffusionXLPAGInpaintPipeline", "StableDiffusionXLPAGPipeline", "StableDiffusionXLPipeline", - "StableDiffusionXLAutoPipeline", "StableUnCLIPImg2ImgPipeline", "StableUnCLIPPipeline", "StableVideoDiffusionPipeline", @@ -541,6 +546,24 @@ ] ) + +try: + if not (is_torch_available() and is_transformers_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_torch_and_transformers_objects # noqa F403 + + _import_structure["utils.dummy_torch_and_transformers_objects"] = [ + name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_") + ] + +else: + _import_structure["modular_pipelines"].extend( + [ + "StableDiffusionXLAutoPipeline", + "StableDiffusionXLModularLoader", + ] + ) try: if not (is_torch_available() and is_transformers_available() and is_opencv_available()): raise OptionalDependencyNotAvailable() @@ -864,12 +887,16 @@ KarrasVePipeline, LDMPipeline, LDMSuperResolutionPipeline, - ModularLoader, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline, StableDiffusionMixin, ) + from .modular_pipelines import ( + ModularLoader, + ComponentSpec, + ComponentsManager, + ) from .quantizers import DiffusersQuantizer from .schedulers import ( AmusedScheduler, @@ -1097,12 +1124,10 @@ StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLInstructPix2PixPipeline, - StableDiffusionXLModularLoader, StableDiffusionXLPAGImg2ImgPipeline, StableDiffusionXLPAGInpaintPipeline, StableDiffusionXLPAGPipeline, StableDiffusionXLPipeline, - StableDiffusionXLAutoPipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, StableVideoDiffusionPipeline, @@ -1127,7 +1152,16 @@ WuerstchenDecoderPipeline, WuerstchenPriorPipeline, ) - + try: + if not (is_torch_available() and is_transformers_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modular_pipelines import ( + StableDiffusionXLAutoPipeline, + StableDiffusionXLModularLoader, + ) try: if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): raise OptionalDependencyNotAvailable() diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py index 7da1cc59a365..ef2f3f2c8420 100644 --- a/src/diffusers/guiders/adaptive_projected_guidance.py +++ b/src/diffusers/guiders/adaptive_projected_guidance.py @@ -13,14 +13,14 @@ # limitations under the License. 
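# The __init__.py changes above move the modular classes out of diffusers.pipelines and
# re-export them from the new diffusers.modular_pipelines package, with the SDXL modular
# classes gated behind the torch + transformers guard. A minimal import sketch, assuming
# this branch is installed:
from diffusers import ComponentsManager, ComponentSpec, ModularLoader
from diffusers.modular_pipelines import (
    StableDiffusionXLAutoPipeline,   # these two additionally require transformers,
    StableDiffusionXLModularLoader,  # mirroring the new OptionalDependencyNotAvailable guard
)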
import math -from typing import Optional, List, TYPE_CHECKING +from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple import torch from .guider_utils import BaseGuidance, rescale_noise_cfg if TYPE_CHECKING: - from ..pipelines.modular_pipeline import BlockState + from ..modular_pipelines.modular_pipeline import BlockState class AdaptiveProjectedGuidance(BaseGuidance): @@ -73,14 +73,18 @@ def __init__( self.use_original_formulation = use_original_formulation self.momentum_buffer = None - def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + if self._step == 0: if self.adaptive_projected_guidance_momentum is not None: self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum) tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for i in range(self.num_conditions): - data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) data_batches.append(data_batch) return data_batches diff --git a/src/diffusers/guiders/auto_guidance.py b/src/diffusers/guiders/auto_guidance.py index bfffb9f39cd2..791cc582add2 100644 --- a/src/diffusers/guiders/auto_guidance.py +++ b/src/diffusers/guiders/auto_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import List, Optional, Union, TYPE_CHECKING +from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple import torch @@ -22,7 +22,7 @@ from .guider_utils import BaseGuidance, rescale_noise_cfg if TYPE_CHECKING: - from ..pipelines.modular_pipeline import BlockState + from ..modular_pipelines.modular_pipeline import BlockState class AutoGuidance(BaseGuidance): @@ -120,11 +120,15 @@ def cleanup_models(self, denoiser: torch.nn.Module) -> None: registry = HookRegistry.check_if_exists_or_initialize(denoiser) registry.remove_hook(name, recurse=True) - def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for i in range(self.num_conditions): - data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) data_batches.append(data_batch) return data_batches diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py index 429f8450410a..a459e51cd083 100644 --- a/src/diffusers/guiders/classifier_free_guidance.py +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -13,14 +13,14 @@ # limitations under the License. 
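# Each guider's prepare_inputs() now takes an optional input_fields mapping and only falls
# back to self._input_fields when it is omitted, so a block can pass the fields per call
# instead of registering them up front via set_input_fields(). A hedged sketch of that call:
# `guider` is any BaseGuidance subclass from this diff, `block_state` is a BlockState, and
# the field names are illustrative assumptions rather than values taken from this diff.
def run_guider_prepare(guider, block_state):
    input_fields = {
        # a (cond, uncond) tuple is split across the conditional/unconditional batches,
        "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
        # while a plain string is copied into every batch unchanged
        "timestep": "timestep",
    }
    return guider.prepare_inputs(block_state, input_fields=input_fields)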
import math -from typing import Optional, List, TYPE_CHECKING +from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple import torch from .guider_utils import BaseGuidance, rescale_noise_cfg if TYPE_CHECKING: - from ..pipelines.modular_pipeline import BlockState + from ..modular_pipelines.modular_pipeline import BlockState class ClassifierFreeGuidance(BaseGuidance): @@ -75,11 +75,15 @@ def __init__( self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation - def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for i in range(self.num_conditions): - data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) data_batches.append(data_batch) return data_batches diff --git a/src/diffusers/guiders/classifier_free_zero_star_guidance.py b/src/diffusers/guiders/classifier_free_zero_star_guidance.py index 4c9839ee78f3..a722f2605036 100644 --- a/src/diffusers/guiders/classifier_free_zero_star_guidance.py +++ b/src/diffusers/guiders/classifier_free_zero_star_guidance.py @@ -13,14 +13,14 @@ # limitations under the License. import math -from typing import Optional, List, TYPE_CHECKING +from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple import torch from .guider_utils import BaseGuidance, rescale_noise_cfg if TYPE_CHECKING: - from ..pipelines.modular_pipeline import BlockState + from ..modular_pipelines.modular_pipeline import BlockState class ClassifierFreeZeroStarGuidance(BaseGuidance): @@ -73,11 +73,15 @@ def __init__( self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation - def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for i in range(self.num_conditions): - data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) data_batches.append(data_batch) return data_batches diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index 7d005442e89c..e8e873f5c88f 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -20,7 +20,7 @@ if TYPE_CHECKING: - from ..pipelines.modular_pipeline import BlockState + from ..modular_pipelines.modular_pipeline import BlockState logger = get_logger(__name__) # pylint: disable=invalid-name @@ -171,10 +171,10 @@ def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], da Returns: `BlockState`: The prepared batch of data. """ - from ..pipelines.modular_pipeline import BlockState + from ..modular_pipelines.modular_pipeline import BlockState if input_fields is None: - raise ValueError("Input fields have not been set. 
Please call `set_input_fields` before preparing inputs.") + raise ValueError("Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs.") data_batch = {} for key, value in input_fields.items(): try: @@ -186,7 +186,7 @@ def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], da # We've already checked that value is a string or a tuple of strings with length 2 pass except AttributeError: - raise ValueError(f"Expected `data` to have attribute(s) {value}, but it does not. Please check the input data.") + logger.debug(f"`data` does not have attribute(s) {value}, skipping.") data_batch[cls._identifier_key] = identifier return BlockState(**data_batch) diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index bdd9e4af81b6..7c19f6391f41 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import List, Optional, Union, TYPE_CHECKING +from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple import torch @@ -22,7 +22,7 @@ from .guider_utils import BaseGuidance, rescale_noise_cfg if TYPE_CHECKING: - from ..pipelines.modular_pipeline import BlockState + from ..modular_pipelines.modular_pipeline import BlockState class SkipLayerGuidance(BaseGuidance): @@ -156,7 +156,11 @@ def cleanup_models(self, denoiser: torch.nn.Module) -> None: for hook_name in self._skip_layer_hook_names: registry.remove_hook(hook_name, recurse=True) - def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + if self.num_conditions == 1: tuple_indices = [0] input_predictions = ["pred_cond"] @@ -168,7 +172,7 @@ def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"] data_batches = [] for i in range(self.num_conditions): - data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], input_predictions[i]) + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i]) data_batches.append(data_batch) return data_batches diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py index 1c7ee45dc3db..3986da913f82 100644 --- a/src/diffusers/guiders/smoothed_energy_guidance.py +++ b/src/diffusers/guiders/smoothed_energy_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. 
import math -from typing import List, Optional, Union, TYPE_CHECKING +from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple import torch @@ -22,7 +22,7 @@ from .guider_utils import BaseGuidance, rescale_noise_cfg if TYPE_CHECKING: - from ..pipelines.modular_pipeline import BlockState + from ..modular_pipelines.modular_pipeline import BlockState class SmoothedEnergyGuidance(BaseGuidance): @@ -149,7 +149,11 @@ def cleanup_models(self, denoiser: torch.nn.Module): for hook_name in self._seg_layer_hook_names: registry.remove_hook(hook_name, recurse=True) - def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + if self.num_conditions == 1: tuple_indices = [0] input_predictions = ["pred_cond"] @@ -161,7 +165,7 @@ def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"] data_batches = [] for i in range(self.num_conditions): - data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], input_predictions[i]) + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i]) data_batches.append(data_batch) return data_batches diff --git a/src/diffusers/guiders/tangential_classifier_free_guidance.py b/src/diffusers/guiders/tangential_classifier_free_guidance.py index 631f9a5f33b2..017693fd9f07 100644 --- a/src/diffusers/guiders/tangential_classifier_free_guidance.py +++ b/src/diffusers/guiders/tangential_classifier_free_guidance.py @@ -13,14 +13,14 @@ # limitations under the License. import math -from typing import Optional, List, TYPE_CHECKING +from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple import torch from .guider_utils import BaseGuidance, rescale_noise_cfg if TYPE_CHECKING: - from ..pipelines.modular_pipeline import BlockState + from ..modular_pipelines.modular_pipeline import BlockState class TangentialClassifierFreeGuidance(BaseGuidance): @@ -62,11 +62,15 @@ def __init__( self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation - def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for i in range(self.num_conditions): - data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) data_batches.append(data_batch) return data_batches diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py index c50d2b7471e4..65a99464ba2f 100644 --- a/src/diffusers/hooks/layer_skip.py +++ b/src/diffusers/hooks/layer_skip.py @@ -30,6 +30,8 @@ _LAYER_SKIP_HOOK = "layer_skip_hook" +# Aryan/YiYi TODO: we need to make guider class a config mixin so I think this is not needed +# either remove or make it serializable @dataclass class LayerSkipConfig: r""" diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py new file mode 100644 index 000000000000..cb2ed78ce360 --- /dev/null +++ 
b/src/diffusers/modular_pipelines/__init__.py @@ -0,0 +1,82 @@ +from typing import TYPE_CHECKING + +from ..utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +# These modules contain pipelines from multiple libraries/frameworks +_dummy_objects = {} +_import_structure = {} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_pt_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_pt_objects)) +else: + _import_structure["modular_pipeline"] = [ + "ModularPipelineMixin", + "PipelineBlock", + "AutoPipelineBlocks", + "SequentialPipelineBlocks", + "LoopSequentialPipelineBlocks", + "ModularLoader", + "PipelineState", + "BlockState", + ] + _import_structure["modular_pipeline_utils"] = [ + "ComponentSpec", + "ConfigSpec", + "InputParam", + "OutputParam", + ] + _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoPipeline", "StableDiffusionXLModularLoader"] + _import_structure["components_manager"] = ["ComponentsManager"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_pt_objects import * # noqa F403 + else: + from .modular_pipeline import ( + AutoPipelineBlocks, + BlockState, + LoopSequentialPipelineBlocks, + ModularLoader, + ModularPipelineMixin, + PipelineBlock, + PipelineState, + SequentialPipelineBlocks, + ) + from .modular_pipeline_utils import ( + ComponentSpec, + ConfigSpec, + InputParam, + OutputParam, + ) + from .stable_diffusion_xl import ( + StableDiffusionXLAutoPipeline, + StableDiffusionXLModularLoader, + ) + from .components_manager import ComponentsManager +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py similarity index 82% rename from src/diffusers/pipelines/components_manager.py rename to src/diffusers/modular_pipelines/components_manager.py index bdff133e22d9..992353389b95 100644 --- a/src/diffusers/pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -29,6 +29,9 @@ from .modular_pipeline_utils import ComponentSpec +import uuid + + if is_accelerate_available(): from accelerate.hooks import ModelHook, add_hook_to_module, remove_hook_from_module from accelerate.state import PartialState @@ -231,8 +234,6 @@ def search_best_candidate(module_sizes, min_memory_offload): -from .modular_pipeline_utils import ComponentSpec -import uuid class ComponentsManager: def __init__(self): self.components = OrderedDict() @@ -242,78 +243,122 @@ def __init__(self): self._auto_offload_enabled = False - def _get_by_collection(self, collection: str): + def _lookup_ids(self, name=None, collection=None, load_id=None, components: OrderedDict = None): """ - Select components by collection name. + Lookup component_ids by name, collection, or load_id. 
""" - selected_components = {} - if collection in self.collections: - component_ids = self.collections[collection] - for component_id in component_ids: - selected_components[component_id] = self.components[component_id] - return selected_components + if components is None: + components = self.components + + if name: + ids_by_name = set() + for component_id, component in components.items(): + comp_name = self._id_to_name(component_id) + if comp_name == name: + ids_by_name.add(component_id) + else: + ids_by_name = set(components.keys()) + if collection: + ids_by_collection = set() + for component_id, component in components.items(): + if component_id in self.collections[collection]: + ids_by_collection.add(component_id) + else: + ids_by_collection = set(components.keys()) + if load_id: + ids_by_load_id = set() + for name, component in components.items(): + if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id: + ids_by_load_id.add(name) + else: + ids_by_load_id = set(components.keys()) + ids = ids_by_name.intersection(ids_by_collection).intersection(ids_by_load_id) + return ids - def _get_by_load_id(self, load_id: str): - """ - Select components by its load_id. - """ - selected_components = {} - for name, component in self.components.items(): - if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id: - selected_components[name] = component - return selected_components - + @staticmethod + def _id_to_name(component_id: str): + return "_".join(component_id.split("_")[:-1]) def add(self, name, component, collection: Optional[str] = None): + + component_id = f"{name}_{uuid.uuid4()}" + # check for duplicated components for comp_id, comp in self.components.items(): if comp == component: - logger.warning(f"Component '{name}' already exists in ComponentsManager") - return comp_id + comp_name = self._id_to_name(comp_id) + if comp_name == name: + logger.warning( + f"component '{name}' already exists as '{comp_id}'" + ) + component_id = comp_id + break + else: + logger.warning( + f"Adding component '{name}' as '{component_id}', but it is duplicate of '{comp_id}'" + f"To remove a duplicate, call `components_manager.remove('')`." + ) - component_id = f"{name}_{uuid.uuid4()}" + # check for duplicated load_id and warn (we do not delete for you) if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null": - components_with_same_load_id = self._get_by_load_id(component._diffusers_load_id) + components_with_same_load_id = self._lookup_ids(load_id=component._diffusers_load_id) + components_with_same_load_id = [id for id in components_with_same_load_id if id != component_id] + if components_with_same_load_id: - existing = ", ".join(components_with_same_load_id.keys()) + existing = ", ".join(components_with_same_load_id) logger.warning( - f"Component '{name}' has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. " - f"To remove a duplicate, call `components_manager.remove('')`." + f"Adding component '{component_id}', but it has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. " + f"To remove a duplicate, call `components_manager.remove('')`." 
) - # add component to components manager self.components[component_id] = component self.added_time[component_id] = time.time() + if collection: if collection not in self.collections: self.collections[collection] = set() - self.collections[collection].add(component_id) + if not component_id in self.collections[collection]: + comp_ids_in_collection = self._lookup_ids(name=name, collection=collection) + for comp_id in comp_ids_in_collection: + logger.info(f"Removing existing {name} from collection '{collection}': {comp_id}") + self.remove(comp_id) + self.collections[collection].add(component_id) + logger.info(f"Added component '{name}' in collection '{collection}': {component_id}") + else: + logger.info(f"Added component '{name}' as '{component_id}'") if self._auto_offload_enabled: self.enable_auto_cpu_offload(self._auto_offload_device) - logger.info(f"Added component '{name}' to ComponentsManager as '{component_id}'") return component_id - def remove(self, name: Union[str, List[str]]): + def remove(self, component_id: str = None): - if name not in self.components: - logger.warning(f"Component '{name}' not found in ComponentsManager") + if component_id not in self.components: + logger.warning(f"Component '{component_id}' not found in ComponentsManager") return - - self.components.pop(name) - self.added_time.pop(name) + + component = self.components.pop(component_id) + self.added_time.pop(component_id) for collection in self.collections: - if name in self.collections[collection]: - self.collections[collection].remove(name) + if component_id in self.collections[collection]: + self.collections[collection].remove(component_id) if self._auto_offload_enabled: self.enable_auto_cpu_offload(self._auto_offload_device) + else: + if isinstance(component, torch.nn.Module): + component.to("cpu") + del component + import gc + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = None, load_id: Optional[str] = None, as_name_component_tuples: bool = False): @@ -342,16 +387,8 @@ def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = N or list of (base_name, component) tuples if as_name_component_tuples=True """ - if collection: - if collection not in self.collections: - logger.warning(f"Collection '{collection}' not found in ComponentsManager") - return [] if as_name_component_tuples else {} - components = self._get_by_collection(collection) - else: - components = self.components - - if load_id: - components = self._get_by_load_id(load_id) + selected_ids = self._lookup_ids(collection=collection, load_id=load_id) + components = {k: self.components[k] for k in selected_ids} # Helper to extract base name from component_id def get_base_name(component_id): @@ -541,11 +578,11 @@ def disable_auto_cpu_offload(self): self._auto_offload_enabled = False # YiYi TODO: add quantization info - def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]: + def get_model_info(self, component_id: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]: """Get comprehensive information about a component. Args: - name: Name of the component to get info for + component_id: Name of the component to get info for fields: Optional field(s) to return. Can be a string for single field or list of fields. If None, returns all fields. 
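# With these changes add() generates a "{name}_{uuid}" component_id, and remove() /
# get_model_info() key off that id rather than the plain name. A small usage sketch,
# assuming this branch plus torch are installed; the nn.Linear below is only a stand-in
# for a real pipeline component:
import torch
from diffusers import ComponentsManager

manager = ComponentsManager()
unet_stub = torch.nn.Linear(4, 4)

# add() returns the generated id; keep it around for later id-based lookups
unet_id = manager.add("unet", unet_stub, collection="sdxl")

# re-adding the same object under the same name reuses the existing id (a warning is logged)
assert manager.add("unet", unet_stub, collection="sdxl") == unet_id

print(manager.get_model_info(unet_id, fields="collection"))  # -> "sdxl"
manager.remove(unet_id)  # id-based removal; without auto offload the module is moved to CPU and garbage collected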
@@ -554,16 +591,16 @@ def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = No If fields is specified, returns only those fields. If a single field is requested as string, returns just that field's value. """ - if name not in self.components: - raise ValueError(f"Component '{name}' not found in ComponentsManager") + if component_id not in self.components: + raise ValueError(f"Component '{component_id}' not found in ComponentsManager") - component = self.components[name] + component = self.components[component_id] # Build complete info dict first info = { - "model_id": name, - "added_time": self.added_time[name], - "collection": next((coll for coll, comps in self.collections.items() if name in comps), None), + "model_id": component_id, + "added_time": self.added_time[component_id], + "collection": ", ".join([coll for coll, comps in self.collections.items() if component_id in comps]) or None, } # Additional info for torch.nn.Module components @@ -649,11 +686,19 @@ def format_device(component, info): ] max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15 - # Collection names - collection_names = [ - next((coll for coll, comps in self.collections.items() if name in comps), "N/A") - for name in self.components.keys() - ] + # Get all collections for each component + component_collections = {} + for name in self.components.keys(): + component_collections[name] = [] + for coll, comps in self.collections.items(): + if name in comps: + component_collections[name].append(coll) + if not component_collections[name]: + component_collections[name] = ["N/A"] + + # Find the maximum collection name length + all_collections = [coll for colls in component_collections.values() for coll in colls] + max_collection_len = max(10, max(len(str(c)) for c in all_collections)) if all_collections else 10 col_widths = { "name": max(15, max(len(name) for name in simple_names)), @@ -662,7 +707,7 @@ def format_device(component, info): "dtype": 15, "size": 10, "load_id": max_load_id_len, - "collection": max(10, max(len(str(c)) for c in collection_names)) + "collection": max_collection_len } # Create the header lines @@ -691,11 +736,21 @@ def format_device(component, info): device_str = format_device(component, info) dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A" load_id = get_load_id(component) - collection = info["collection"] or "N/A" + + # Print first collection on the main line + first_collection = component_collections[name][0] if component_collections[name] else "N/A" output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | " output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | " - output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {collection}\n" + output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {first_collection}\n" + + # Print additional collections on separate lines if they exist + for i in range(1, len(component_collections[name])): + collection = component_collections[name][i] + output += f"{'':<{col_widths['name']}} | {'':<{col_widths['class']}} | " + output += f"{'':<{col_widths['device']}} | {'':<{col_widths['dtype']}} | " + output += f"{'':<{col_widths['size']}} | {'':<{col_widths['load_id']}} | {collection}\n" + output += dash_line # Other components section @@ -711,9 +766,17 @@ def format_device(component, info): for name, component in others.items(): info = 
self.get_model_info(name) simple_name = get_simple_name(name) - collection = info["collection"] or "N/A" - output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {collection}\n" + # Print first collection on the main line + first_collection = component_collections[name][0] if component_collections[name] else "N/A" + + output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {first_collection}\n" + + # Print additional collections on separate lines if they exist + for i in range(1, len(component_collections[name])): + collection = component_collections[name][i] + output += f"{'':<{col_widths['name']}} | {'':<{col_widths['class']}} | {collection}\n" + output += dash_line # Add additional component info @@ -775,7 +838,7 @@ def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')" ) - def get_one(self, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None) -> Any: + def get_one(self, component_id: Optional[str] = None, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None) -> Any: """ Get a single component by name. Raises an error if multiple components match or none are found. @@ -790,6 +853,15 @@ def get_one(self, name: Optional[str] = None, collection: Optional[str] = None, Raises: ValueError: If no components match or multiple components match """ + + # if component_id is provided, return the component + if component_id is not None and (name is not None or collection is not None or load_id is not None): + raise ValueError(" if component_id is provided, name, collection, and load_id must be None") + elif component_id is not None: + if component_id not in self.components: + raise ValueError(f"Component '{component_id}' not found in ComponentsManager") + return self.components[component_id] + results = self.get(name, collection, load_id) if not results: diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py similarity index 61% rename from src/diffusers/pipelines/modular_pipeline.py rename to src/diffusers/modular_pipelines/modular_pipeline.py index 636b543395df..ef725c32f4f9 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -17,6 +17,7 @@ from collections import OrderedDict from dataclasses import dataclass, field from typing import Any, Dict, List, Tuple, Union, Optional, Type +from copy import deepcopy import torch @@ -34,7 +35,7 @@ logging, PushToHubMixin, ) -from .pipeline_loading_utils import _get_pipeline_class, simple_get_class_obj,_fetch_class_library_tuple +from ..pipelines.pipeline_loading_utils import _get_pipeline_class, simple_get_class_obj, _fetch_class_library_tuple from .modular_pipeline_utils import ( ComponentSpec, ConfigSpec, @@ -73,18 +74,74 @@ class PipelineState: inputs: Dict[str, Any] = field(default_factory=dict) intermediates: Dict[str, Any] = field(default_factory=dict) + input_kwargs: Dict[str, list[str, Any]] = field(default_factory=dict) + intermediate_kwargs: Dict[str, list[str, Any]] = field(default_factory=dict) - def add_input(self, key: str, value: Any): + def add_input(self, key: str, value: Any, kwargs_type: str = None): + """ + Add an input to the pipeline state with optional metadata. 
+ + Args: + key (str): The key for the input + value (Any): The input value + kwargs_type (str): The kwargs_type to store with the input + """ self.inputs[key] = value + if kwargs_type is not None: + if kwargs_type not in self.input_kwargs: + self.input_kwargs[kwargs_type] = [key] + else: + self.input_kwargs[kwargs_type].append(key) - def add_intermediate(self, key: str, value: Any): + def add_intermediate(self, key: str, value: Any, kwargs_type: str = None): + """ + Add an intermediate value to the pipeline state with optional metadata. + + Args: + key (str): The key for the intermediate value + value (Any): The intermediate value + kwargs_type (str): The kwargs_type to store with the intermediate value + """ self.intermediates[key] = value + if kwargs_type is not None: + if kwargs_type not in self.intermediate_kwargs: + self.intermediate_kwargs[kwargs_type] = [key] + else: + self.intermediate_kwargs[kwargs_type].append(key) def get_input(self, key: str, default: Any = None) -> Any: - return self.inputs.get(key, default) + value = self.inputs.get(key, default) + if value is not None: + return deepcopy(value) def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]: return {key: self.inputs.get(key, default) for key in keys} + + def get_inputs_kwargs(self, kwargs_type: str) -> Dict[str, Any]: + """ + Get all inputs with matching kwargs_type. + + Args: + kwargs_type (str): The kwargs_type to filter by + + Returns: + Dict[str, Any]: Dictionary of inputs with matching kwargs_type + """ + input_names = self.input_kwargs.get(kwargs_type, []) + return self.get_inputs(input_names) + + def get_intermediates_kwargs(self, kwargs_type: str) -> Dict[str, Any]: + """ + Get all intermediates with matching kwargs_type. + + Args: + kwargs_type (str): The kwargs_type to filter by + + Returns: + Dict[str, Any]: Dictionary of intermediates with matching kwargs_type + """ + intermediate_names = self.intermediate_kwargs.get(kwargs_type, []) + return self.get_intermediates(intermediate_names) def get_intermediate(self, key: str, default: Any = None) -> Any: return self.intermediates.get(key, default) @@ -106,11 +163,17 @@ def format_value(v): inputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.inputs.items()) intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items()) + + # Format input_kwargs and intermediate_kwargs + input_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.input_kwargs.items()) + intermediate_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.intermediate_kwargs.items()) return ( f"PipelineState(\n" f" inputs={{\n{inputs}\n }},\n" - f" intermediates={{\n{intermediates}\n }}\n" + f" intermediates={{\n{intermediates}\n }},\n" + f" input_kwargs={{\n{input_kwargs_str}\n }},\n" + f" intermediate_kwargs={{\n{intermediate_kwargs_str}\n }}\n" f")" ) @@ -124,6 +187,23 @@ def __init__(self, **kwargs): for key, value in kwargs.items(): setattr(self, key, value) + def __getitem__(self, key: str): + # allows block_state["foo"] + return getattr(self, key, None) + + def __setitem__(self, key: str, value: Any): + # allows block_state["foo"] = "bar" + setattr(self, key, value) + + def as_dict(self): + """ + Convert BlockState to a dictionary. 
+ + Returns: + Dict[str, Any]: Dictionary containing all attributes of the BlockState + """ + return {key: value for key, value in self.__dict__.items()} + def __repr__(self): def format_value(v): # Handle tensors directly @@ -146,10 +226,16 @@ def format_value(v): # Handle dicts with tensor values elif isinstance(v, dict): - if any(hasattr(val, "shape") and hasattr(val, "dtype") for val in v.values()): - shapes = {k: val.shape for k, val in v.items() if hasattr(val, "shape")} - return f"Dict of Tensors with shapes {shapes}" - return repr(v) + formatted_dict = {} + for k, val in v.items(): + if hasattr(val, "shape") and hasattr(val, "dtype"): + formatted_dict[k] = f"Tensor(shape={val.shape}, dtype={val.dtype})" + elif isinstance(val, list) and len(val) > 0 and hasattr(val[0], "shape") and hasattr(val[0], "dtype"): + shapes = [t.shape for t in val] + formatted_dict[k] = f"List[{len(val)}] of Tensors with shapes {shapes}" + else: + formatted_dict[k] = repr(val) + return formatted_dict # Default case return repr(v) @@ -199,33 +285,38 @@ def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, state = PipelineState() if not hasattr(self, "loader"): - raise ValueError("Loader is not set, please call `setup_loader()` first.") + logger.info("Loader is not set, please call `setup_loader()` if you need to load checkpoints for your pipeline.") + self.loader = None # Make a copy of the input kwargs - input_params = kwargs.copy() + passed_kwargs = kwargs.copy() - default_params = self.default_call_parameters # Add inputs to state, using defaults if not provided in the kwargs or the state # if same input already in the state, will override it if provided in the kwargs intermediates_inputs = [inp.name for inp in self.intermediates_inputs] - for name, default in default_params.items(): - if name in input_params: + for expected_input_param in self.inputs: + name = expected_input_param.name + default = expected_input_param.default + kwargs_type = expected_input_param.kwargs_type + if name in passed_kwargs: if name not in intermediates_inputs: - state.add_input(name, input_params.pop(name)) + state.add_input(name, passed_kwargs.pop(name), kwargs_type) else: - state.add_input(name, input_params[name]) + state.add_input(name, passed_kwargs[name], kwargs_type) elif name not in state.inputs: - state.add_input(name, default) + state.add_input(name, default, kwargs_type) - for name in intermediates_inputs: - if name in input_params: - state.add_intermediate(name, input_params.pop(name)) + for expected_intermediate_param in self.intermediates_inputs: + name = expected_intermediate_param.name + kwargs_type = expected_intermediate_param.kwargs_type + if name in passed_kwargs: + state.add_intermediate(name, passed_kwargs.pop(name), kwargs_type) # Warn about unexpected inputs - if len(input_params) > 0: - logger.warning(f"Unexpected input '{input_params.keys()}' provided. This input will be ignored.") + if len(passed_kwargs) > 0: + warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.") # Run the pipeline with torch.no_grad(): try: @@ -285,7 +376,6 @@ def expected_configs(self) -> List[ConfigSpec]: return [] - # YiYi TODO: can we combine inputs and intermediates_inputs? the difference is inputs are immutable @property def inputs(self) -> List[InputParam]: """List of input parameters. Must be implemented by subclasses.""" @@ -301,13 +391,16 @@ def intermediates_outputs(self) -> List[OutputParam]: """List of intermediate output parameters. 
Must be implemented by subclasses.""" return [] + def _get_outputs(self): + return self.intermediates_outputs + + # YiYi TODO: is it too easy for user to unintentionally override these properties? # Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks @property def outputs(self) -> List[OutputParam]: - return self.intermediates_outputs + return self._get_outputs() - @property - def required_inputs(self) -> List[str]: + def _get_required_inputs(self): input_names = [] for input_param in self.inputs: if input_param.required: @@ -315,13 +408,23 @@ def required_inputs(self) -> List[str]: return input_names @property - def required_intermediates_inputs(self) -> List[str]: + def required_inputs(self) -> List[str]: + return self._get_required_inputs() + + + def _get_required_intermediates_inputs(self): input_names = [] for input_param in self.intermediates_inputs: if input_param.required: input_names.append(input_param.name) return input_names + # YiYi TODO: maybe we do not need this, it is only used in docstring, + # intermediate_inputs is by default required, unless you manually handle it inside the block + @property + def required_intermediates_inputs(self) -> List[str]: + return self._get_required_intermediates_inputs() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: raise NotImplementedError("__call__ method must be implemented in subclasses") @@ -383,31 +486,81 @@ def doc(self): ) + # YiYi TODO: input and inteermediate inputs with same name? should warn? def get_block_state(self, state: PipelineState) -> dict: """Get all inputs and intermediates in one dictionary""" data = {} # Check inputs for input_param in self.inputs: - value = state.get_input(input_param.name) - if input_param.required and value is None: - raise ValueError(f"Required input '{input_param.name}' is missing") - data[input_param.name] = value + if input_param.name: + value = state.get_input(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all inputs with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) + if inputs_kwargs: + for k, v in inputs_kwargs.items(): + if v is not None: + data[k] = v + data[input_param.kwargs_type][k] = v # Check intermediates for input_param in self.intermediates_inputs: - value = state.get_intermediate(input_param.name) - if input_param.required and value is None: - raise ValueError(f"Required intermediate input '{input_param.name}' is missing") - data[input_param.name] = value - + if input_param.name: + value = state.get_intermediate(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required intermediate input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all intermediates with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) + if intermediates_kwargs: + for k, v in intermediates_kwargs.items(): + if v is not 
None: + if k not in data: + data[k] = v + data[input_param.kwargs_type][k] = v return BlockState(**data) def add_block_state(self, state: PipelineState, block_state: BlockState): for output_param in self.intermediates_outputs: if not hasattr(block_state, output_param.name): raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") - state.add_intermediate(output_param.name, getattr(block_state, output_param.name)) + param = getattr(block_state, output_param.name) + state.add_intermediate(output_param.name, param, output_param.kwargs_type) + + for input_param in self.intermediates_inputs: + if hasattr(block_state, input_param.name): + param = getattr(block_state, input_param.name) + # Only add if the value is different from what's in the state + current_value = state.get_intermediate(input_param.name) + if current_value is not param: # Using identity comparison to check if object was modified + state.add_intermediate(input_param.name, param, input_param.kwargs_type) + + for input_param in self.intermediates_inputs: + if input_param.name and hasattr(block_state, input_param.name): + param = getattr(block_state, input_param.name) + # Only add if the value is different from what's in the state + current_value = state.get_intermediate(input_param.name) + if current_value is not param: # Using identity comparison to check if object was modified + state.add_intermediate(input_param.name, param, input_param.kwargs_type) + elif input_param.kwargs_type: + # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters + # we need to first find out which inputs are and loop through them. + intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) + for param_name, current_value in intermediates_kwargs.items(): + param = getattr(block_state, param_name) + if current_value is not param: # Using identity comparison to check if object was modified + state.add_intermediate(param_name, param, input_param.kwargs_type) def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: @@ -427,22 +580,26 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> Li for block_name, inputs in named_input_lists: for input_param in inputs: - if input_param.name in combined_dict: - current_param = combined_dict[input_param.name] + if input_param.name is None and input_param.kwargs_type is not None: + input_name = "*_" + input_param.kwargs_type + else: + input_name = input_param.name + if input_name in combined_dict: + current_param = combined_dict[input_name] if (current_param.default is not None and input_param.default is not None and current_param.default != input_param.default): warnings.warn( - f"Multiple different default values found for input '{input_param.name}': " - f"{current_param.default} (from block '{value_sources[input_param.name]}') and " + f"Multiple different default values found for input '{input_name}': " + f"{current_param.default} (from block '{value_sources[input_name]}') and " f"{input_param.default} (from block '{block_name}'). Using {current_param.default}." 
) if current_param.default is None and input_param.default is not None: - combined_dict[input_param.name] = input_param - value_sources[input_param.name] = block_name + combined_dict[input_name] = input_param + value_sources[input_name] = block_name else: - combined_dict[input_param.name] = input_param - value_sources[input_param.name] = block_name + combined_dict[input_name] = input_param + value_sources[input_name] = block_name return list(combined_dict.values()) @@ -461,7 +618,7 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> for block_name, outputs in named_output_lists: for output_param in outputs: - if output_param.name not in combined_dict: + if (output_param.name not in combined_dict) or (combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None): combined_dict[output_param.name] = output_param return list(combined_dict.values()) @@ -544,7 +701,9 @@ def required_inputs(self) -> List[str]: required_by_all.intersection_update(block_required) return list(required_by_all) - + + # YiYi TODO: maybe we do not need this, it is only used in docstring, + # intermediate_inputs is by default required, unless you manually handle it inside the block @property def required_intermediates_inputs(self) -> List[str]: first_block = next(iter(self.blocks.values())) @@ -721,14 +880,21 @@ def __repr__(self): indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) blocks_str += f" Description: {indented_desc}\n\n" - return ( - f"{header}\n" - f"{desc}\n\n" - f"{components_str}\n\n" - f"{configs_str}\n\n" - f"{blocks_str}" - f")" - ) + # Build the representation with conditional sections + result = f"{header}\n{desc}" + + # Only add components section if it has content + if components_str.strip(): + result += f"\n\n{components_str}" + + # Only add configs section if it has content + if configs_str.strip(): + result += f"\n\n{configs_str}" + + # Always add blocks section + result += f"\n\n{blocks_str})" + + return result @property @@ -750,13 +916,15 @@ class SequentialPipelineBlocks(ModularPipelineMixin): block_classes = [] block_names = [] - @property - def model_name(self): - return next(iter(self.blocks.values())).model_name @property def description(self): return "" + + @property + def model_name(self): + return next(iter(self.blocks.values())).model_name + @property def expected_components(self): @@ -812,6 +980,8 @@ def required_inputs(self) -> List[str]: return list(required_by_any) + # YiYi TODO: maybe we do not need this, it is only used in docstring, + # intermediate_inputs is by default required, unless you manually handle it inside the block @property def required_intermediates_inputs(self) -> List[str]: required_intermediates_inputs = [] @@ -823,6 +993,9 @@ def required_intermediates_inputs(self) -> List[str]: # YiYi TODO: add test for this @property def inputs(self) -> List[Tuple[str, Any]]: + return self.get_inputs() + + def get_inputs(self): named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] combined_inputs = combine_inputs(*named_inputs) # mark Required inputs only if that input is required any of the blocks @@ -835,13 +1008,20 @@ def inputs(self) -> List[Tuple[str, Any]]: @property def intermediates_inputs(self) -> List[str]: + return self.get_intermediates_inputs() + + def get_intermediates_inputs(self): inputs = [] outputs = set() + added_inputs = set() # Go through all blocks in order for block in self.blocks.values(): # Add inputs that aren't in outputs yet - 
inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs) + for inp in block.intermediates_inputs: + if inp.name not in outputs and inp.name not in added_inputs: + inputs.append(inp) + added_inputs.add(inp.name) # Only add outputs if the block cannot be skipped should_add_outputs = True @@ -856,14 +1036,21 @@ def intermediates_inputs(self) -> List[str]: @property def intermediates_outputs(self) -> List[str]: - named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] + named_outputs = [] + for name, block in self.blocks.items(): + inp_names = set([inp.name for inp in block.intermediates_inputs]) + # so we only need to list new variables as intermediates_outputs, but if user wants to list these they modified it's still fine (a.k.a we don't enforce) + # filter out them here so they do not end up as intermediates_outputs + if name not in inp_names: + named_outputs.append((name, block.intermediates_outputs)) combined_outputs = combine_outputs(*named_outputs) return combined_outputs + # YiYi TODO: I think we can remove the outputs property @property def outputs(self) -> List[str]: - return next(reversed(self.blocks.values())).intermediates_outputs - + # return next(reversed(self.blocks.values())).intermediates_outputs + return self.intermediates_outputs @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: for block_name, block in self.blocks.items(): @@ -910,16 +1097,17 @@ def trigger_inputs(self): def _traverse_trigger_blocks(self, trigger_inputs): # Convert trigger_inputs to a set for easier manipulation active_triggers = set(trigger_inputs) - def fn_recursive_traverse(block, block_name, active_triggers): result_blocks = OrderedDict() - # sequential or PipelineBlock + # sequential(include loopsequential) or PipelineBlock if not hasattr(block, 'block_trigger_inputs'): if hasattr(block, 'blocks'): - # sequential - for block_name, block in block.blocks.items(): - blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers) + # sequential or LoopSequentialPipelineBlocks (keep traversing) + for sub_block_name, sub_block in block.blocks.items(): + blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers) + blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers) + blocks_to_update = {f"{block_name}.{k}": v for k,v in blocks_to_update.items()} result_blocks.update(blocks_to_update) else: # PipelineBlock @@ -946,13 +1134,14 @@ def fn_recursive_traverse(block, block_name, active_triggers): matching_trigger = None if this_block is not None: - # sequential/auto + # sequential/auto (keep traversing) if hasattr(this_block, 'blocks'): result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers)) else: # PipelineBlock result_blocks[block_name] = this_block # Add this block's output names to active triggers if defined + # YiYi TODO: do we need outputs here? can it just be intermediate_outputs? can we get rid of outputs attribute? 
if hasattr(this_block, 'outputs'): active_triggers.update(out.name for out in this_block.outputs) @@ -1051,15 +1240,327 @@ def __repr__(self): indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) blocks_str += f" Description: {indented_desc}\n\n" - return ( - f"{header}\n" - f"{desc}\n\n" - f"{components_str}\n\n" - f"{configs_str}\n\n" - f"{blocks_str}" - f")" + # Build the representation with conditional sections + result = f"{header}\n{desc}" + + # Only add components section if it has content + if components_str.strip(): + result += f"\n\n{components_str}" + + # Only add configs section if it has content + if configs_str.strip(): + result += f"\n\n{configs_str}" + + # Always add blocks section + result += f"\n\n{blocks_str})" + + return result + + + @property + def doc(self): + return make_doc_string( + self.inputs, + self.intermediates_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs ) +#YiYi TODO: __repr__ +class LoopSequentialPipelineBlocks(ModularPipelineMixin): + """ + A class that combines multiple pipeline block classes into a For Loop. When called, it will call each block in sequence. + """ + + model_name = None + block_classes = [] + block_names = [] + + @property + def description(self) -> str: + """Description of the block. Must be implemented by subclasses.""" + raise NotImplementedError("description method must be implemented in subclasses") + + @property + def loop_expected_components(self) -> List[ComponentSpec]: + return [] + + @property + def loop_expected_configs(self) -> List[ConfigSpec]: + return [] + + @property + def loop_inputs(self) -> List[InputParam]: + """List of input parameters. Must be implemented by subclasses.""" + return [] + + @property + def loop_intermediates_inputs(self) -> List[InputParam]: + """List of intermediate input parameters. Must be implemented by subclasses.""" + return [] + + @property + def loop_intermediates_outputs(self) -> List[OutputParam]: + """List of intermediate output parameters. 
Must be implemented by subclasses.""" + return [] + + + @property + def loop_required_inputs(self) -> List[str]: + input_names = [] + for input_param in self.loop_inputs: + if input_param.required: + input_names.append(input_param.name) + return input_names + + @property + def loop_required_intermediates_inputs(self) -> List[str]: + input_names = [] + for input_param in self.loop_intermediates_inputs: + if input_param.required: + input_names.append(input_param.name) + return input_names + + # modified from SequentialPipelineBlocks to include loop_expected_components + @property + def expected_components(self): + expected_components = [] + for block in self.blocks.values(): + for component in block.expected_components: + if component not in expected_components: + expected_components.append(component) + for component in self.loop_expected_components: + if component not in expected_components: + expected_components.append(component) + return expected_components + + # modified from SequentialPipelineBlocks to include loop_expected_configs + @property + def expected_configs(self): + expected_configs = [] + for block in self.blocks.values(): + for config in block.expected_configs: + if config not in expected_configs: + expected_configs.append(config) + for config in self.loop_expected_configs: + if config not in expected_configs: + expected_configs.append(config) + return expected_configs + + # modified from SequentialPipelineBlocks to include loop_inputs + def get_inputs(self): + named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] + named_inputs.append(("loop", self.loop_inputs)) + combined_inputs = combine_inputs(*named_inputs) + # mark Required inputs only if that input is required any of the blocks + for input_param in combined_inputs: + if input_param.name in self.required_inputs: + input_param.required = True + else: + input_param.required = False + return combined_inputs + + # Copied from SequentialPipelineBlocks + @property + def inputs(self): + return self.get_inputs() + + + # modified from SequentialPipelineBlocks to include loop_intermediates_inputs + @property + def intermediates_inputs(self): + intermediates = self.get_intermediates_inputs() + intermediate_names = [input.name for input in intermediates] + for loop_intermediate_input in self.loop_intermediates_inputs: + if loop_intermediate_input.name not in intermediate_names: + intermediates.append(loop_intermediate_input) + return intermediates + + + # Copied from SequentialPipelineBlocks + def get_intermediates_inputs(self): + inputs = [] + outputs = set() + + # Go through all blocks in order + for block in self.blocks.values(): + # Add inputs that aren't in outputs yet + inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs) + + # Only add outputs if the block cannot be skipped + should_add_outputs = True + if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: + should_add_outputs = False + + if should_add_outputs: + # Add this block's outputs + block_intermediates_outputs = [out.name for out in block.intermediates_outputs] + outputs.update(block_intermediates_outputs) + return inputs + + + # modified from SequentialPipelineBlocks, if any additionan input required by the loop is required by the block + @property + def required_inputs(self) -> List[str]: + # Get the first block from the dictionary + first_block = next(iter(self.blocks.values())) + required_by_any = set(getattr(first_block, "required_inputs", set())) + + 
required_by_loop = set(getattr(self, "loop_required_inputs", set())) + required_by_any.update(required_by_loop) + + # Union with required inputs from all other blocks + for block in list(self.blocks.values())[1:]: + block_required = set(getattr(block, "required_inputs", set())) + required_by_any.update(block_required) + + return list(required_by_any) + + # YiYi TODO: maybe we do not need this, it is only used in docstring, + # intermediate_inputs is by default required, unless you manually handle it inside the block + @property + def required_intermediates_inputs(self) -> List[str]: + required_intermediates_inputs = [] + for input_param in self.intermediates_inputs: + if input_param.required: + required_intermediates_inputs.append(input_param.name) + for input_param in self.loop_intermediates_inputs: + if input_param.required: + required_intermediates_inputs.append(input_param.name) + return required_intermediates_inputs + + + # YiYi TODO: this need to be thought about more + # modified from SequentialPipelineBlocks to include loop_intermediates_outputs + @property + def intermediates_outputs(self) -> List[str]: + named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] + combined_outputs = combine_outputs(*named_outputs) + for output in self.loop_intermediates_outputs: + if output.name not in set([output.name for output in combined_outputs]): + combined_outputs.append(output) + return combined_outputs + + # YiYi TODO: this need to be thought about more + # copied from SequentialPipelineBlocks + @property + def outputs(self) -> List[str]: + return next(reversed(self.blocks.values())).intermediates_outputs + + + def __init__(self): + blocks = OrderedDict() + for block_name, block_cls in zip(self.block_names, self.block_classes): + blocks[block_name] = block_cls() + self.blocks = blocks + + @classmethod + def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "LoopSequentialPipelineBlocks": + """Creates a LoopSequentialPipelineBlocks instance from a dictionary of blocks. 
+ + Args: + blocks_dict: Dictionary mapping block names to block instances + + Returns: + A new LoopSequentialPipelineBlocks instance + """ + instance = cls() + instance.block_classes = [block.__class__ for block in blocks_dict.values()] + instance.block_names = list(blocks_dict.keys()) + instance.blocks = blocks_dict + return instance + + def loop_step(self, components, state: PipelineState, **kwargs): + + for block_name, block in self.blocks.items(): + try: + components, state = block(components, state, **kwargs) + except Exception as e: + error_msg = ( + f"\nError in block: ({block_name}, {block.__class__.__name__})\n" + f"Error details: {str(e)}\n" + f"Traceback:\n{traceback.format_exc()}" + ) + logger.error(error_msg) + raise + return components, state + + def __call__(self, components, state: PipelineState) -> PipelineState: + raise NotImplementedError("`__call__` method needs to be implemented by the subclass") + + + def get_block_state(self, state: PipelineState) -> dict: + """Get all inputs and intermediates in one dictionary""" + data = {} + + # Check inputs + for input_param in self.inputs: + if input_param.name: + value = state.get_input(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all inputs with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) + if inputs_kwargs: + for k, v in inputs_kwargs.items(): + if v is not None: + data[k] = v + data[input_param.kwargs_type][k] = v + + # Check intermediates + for input_param in self.intermediates_inputs: + if input_param.name: + value = state.get_intermediate(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required intermediate input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all intermediates with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) + if intermediates_kwargs: + for k, v in intermediates_kwargs.items(): + if v is not None: + if k not in data: + data[k] = v + data[input_param.kwargs_type][k] = v + return BlockState(**data) + + def add_block_state(self, state: PipelineState, block_state: BlockState): + for output_param in self.intermediates_outputs: + if not hasattr(block_state, output_param.name): + raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") + param = getattr(block_state, output_param.name) + state.add_intermediate(output_param.name, param, output_param.kwargs_type) + + for input_param in self.intermediates_inputs: + if input_param.name and hasattr(block_state, input_param.name): + param = getattr(block_state, input_param.name) + # Only add if the value is different from what's in the state + current_value = state.get_intermediate(input_param.name) + if current_value is not param: # Using identity comparison to check if object was modified + state.add_intermediate(input_param.name, param, input_param.kwargs_type) + elif input_param.kwargs_type: + # if it is a kwargs 
type, e.g. "guider_input_fields", it is likely to be a list of parameters + # we need to first find out which inputs are and loop through them. + intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) + for param_name, current_value in intermediates_kwargs.items(): + if not hasattr(block_state, param_name): + continue + param = getattr(block_state, param_name) + if current_value is not param: # Using identity comparison to check if object was modified + state.add_intermediate(param_name, param, input_param.kwargs_type) + @property def doc(self): @@ -1073,12 +1574,74 @@ def doc(self): expected_configs=self.expected_configs ) + # modified from SequentialPipelineBlocks, + #(does not need trigger_inputs related part so removed them, + # do not need to support auto block for loop blocks) + def __repr__(self): + class_name = self.__class__.__name__ + base_class = self.__class__.__bases__[0].__name__ + header = ( + f"{class_name}(\n Class: {base_class}\n" + if base_class and base_class != "object" + else f"{class_name}(\n" + ) + + # Format description with proper indentation + desc_lines = self.description.split('\n') + desc = [] + # First line with "Description:" label + desc.append(f" Description: {desc_lines[0]}") + # Subsequent lines with proper indentation + if len(desc_lines) > 1: + desc.extend(f" {line}" for line in desc_lines[1:]) + desc = '\n'.join(desc) + '\n' + + # Components section - focus only on expected components + expected_components = getattr(self, "expected_components", []) + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) + + # Configs section - use format_configs with add_empty_lines=False + expected_configs = getattr(self, "expected_configs", []) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) + + # Blocks section - moved to the end with simplified format + blocks_str = " Blocks:\n" + for i, (name, block) in enumerate(self.blocks.items()): + + # For SequentialPipelineBlocks, show execution order + blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" + + # Add block description + desc_lines = block.description.split('\n') + indented_desc = desc_lines[0] + if len(desc_lines) > 1: + indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) + blocks_str += f" Description: {indented_desc}\n\n" + + # Build the representation with conditional sections + result = f"{header}\n{desc}" + + # Only add components section if it has content + if components_str.strip(): + result += f"\n\n{components_str}" + + # Only add configs section if it has content + if configs_str.strip(): + result += f"\n\n{configs_str}" + + # Always add blocks section + result += f"\n\n{blocks_str})" + + return result + + # YiYi TODO: -# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess) -# 2. do we need ConfigSpec? seems pretty unnecessrary for loader, can just add and kwargs to the loader -# 3. add validator for methods where we accpet kwargs to be passed to from_pretrained() +# 1. move the modular_repo arg and the logic to fetch info from repo out of __init__ so that __init__ alwasy create an default modular_model_index config +# 2. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess) +# 3. do we need ConfigSpec? seems pretty unnecessrary for loader, can just add and kwargs to the loader +# 4. 
add validator for methods where we accpet kwargs to be passed to from_pretrained() class ModularLoader(ConfigMixin, PushToHubMixin): """ Base class for all Modular pipelines loaders. @@ -1089,54 +1652,68 @@ class ModularLoader(ConfigMixin, PushToHubMixin): def register_components(self, **kwargs): """ - Register components with their corresponding specs. - This method is called when component changed or __init__ is called. - + Register components with their corresponding specifications. + + This method is responsible for: + 1. Sets component objects as attributes on the loader (e.g., self.unet = unet) + 2. Updates the modular_model_index.json configuration for serialization + 4. Adds components to the component manager if one is attached + + This method is called when: + - Components are first initialized in __init__: + - from_pretrained components not loaded during __init__ so they are registered as None; + - non from_pretrained components are created during __init__ and registered as the object itself + - Components are updated with the `update()` method: e.g. loader.update(unet=unet) or loader.update(guider=guider_spec) + - (from_pretrained) Components are loaded with the `load()` method: e.g. loader.load(component_names=["unet"]) + Args: **kwargs: Keyword arguments where keys are component names and values are component objects. + E.g., register_components(unet=unet_model, text_encoder=encoder_model) + Notes: + - Components must be created from ComponentSpec (have _diffusers_load_id attribute) + - When registering None for a component, it updates the modular_model_index.json config but sets attribute to None """ for name, module in kwargs.items(): - # current component spec component_spec = self._component_specs.get(name) if component_spec is None: logger.warning(f"ModularLoader.register_components: skipping unknown component '{name}'") continue + # check if it is the first time registration, i.e. calling from __init__ is_registered = hasattr(self, name) + # make sure the component is created from ComponentSpec if module is not None and not hasattr(module, "_diffusers_load_id"): raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") - # actual library and class name of the module - if module is not None: - library, class_name = _fetch_class_library_tuple(module) - new_component_spec = ComponentSpec.from_component(name, module) - component_spec_dict = self._component_spec_to_dict(new_component_spec) + # actual library and class name of the module + library, class_name = _fetch_class_library_tuple(module) # e.g. ("diffusers", "UNet2DConditionModel") + + # extract the loading spec from the updated component spec that'll be used as part of modular_model_index.json config + # e.g. {"repo": "stabilityai/stable-diffusion-2-1", + # "type_hint": ("diffusers", "UNet2DConditionModel"), + # "subfolder": "unet", + # "variant": None, + # "revision": None} + component_spec_dict = self._component_spec_to_dict(component_spec) else: + # if module is None, e.g. 
self.register_components(unet=None) during __init__ + # we do not update the spec, + # but we still need to update the modular_model_index.json config based oncomponent spec library, class_name = None, None - # if module is None, we do not update the spec, - # but we still need to update the config to make sure it's synced with the component spec - # (in the case of the first time registration, we initilize the object with component spec, and then we call register_components() to register it to config) - new_component_spec = component_spec component_spec_dict = self._component_spec_to_dict(component_spec) - - # do not register if component is not to be loaded from pretrained - if new_component_spec.default_creation_method == "from_pretrained": - register_dict = {name: (library, class_name, component_spec_dict)} - else: - register_dict = {} + register_dict = {name: (library, class_name, component_spec_dict)} # set the component as attribute # if it is not set yet, just set it and skip the process to check and warn below if not is_registered: self.register_to_config(**register_dict) - self._component_specs[name] = new_component_spec setattr(self, name, module) - if module is not None and self._component_manager is not None: + if module is not None and module._diffusers_load_id != "null" and self._component_manager is not None: self._component_manager.add(name, module, self._collection) continue @@ -1145,10 +1722,6 @@ def register_components(self, **kwargs): if current_module is module: logger.info(f"ModularLoader.register_components: {name} is already registered with same object, skipping") continue - - # it module is not an instance of the expected type, still register it but with a warning - if module is not None and component_spec.type_hint is not None and not isinstance(module, component_spec.type_hint): - logger.warning(f"ModularLoader.register_components: adding {name} with new type: {module.__class__.__name__}, previous type: {component_spec.type_hint.__name__}") # warn if unregister if current_module is not None and module is None: @@ -1156,7 +1729,7 @@ def register_components(self, **kwargs): f"ModularLoader.register_components: setting '{name}' to None " f"(was {current_module.__class__.__name__})" ) - # same type, new instance → debug + # same type, new instance → replace but send debug log elif current_module is not None \ and module is not None \ and isinstance(module, current_module.__class__) \ @@ -1166,13 +1739,12 @@ def register_components(self, **kwargs): f"(same type {type(current_module).__name__}, new instance)" ) - # save modular_model_index.json config + # update modular_model_index.json config self.register_to_config(**register_dict) - # update component spec - self._component_specs[name] = new_component_spec # finally set models setattr(self, name, module) - if module is not None and self._component_manager is not None: + # add to component manager if one is attached + if module is not None and module._diffusers_load_id != "null" and self._component_manager is not None: self._component_manager.add(name, module, self._collection) @@ -1196,6 +1768,7 @@ def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], modular_repo: config_dict = self.load_config(modular_repo, **kwargs) for name, value in config_dict.items(): + # only update component_spec for from_pretrained components if name in self._component_specs and self._component_specs[name].default_creation_method == "from_pretrained" and isinstance(value, (tuple, list)) and len(value) == 3: library, 
class_name, component_spec_dict = value component_spec = self._dict_to_component_spec(name, component_spec_dict) @@ -1206,7 +1779,11 @@ def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], modular_repo: register_components_dict = {} for name, component_spec in self._component_specs.items(): - register_components_dict[name] = None + if component_spec.default_creation_method == "from_config": + component = component_spec.create() + else: + component = None + register_components_dict[name] = component self.register_components(**register_components_dict) default_configs = {} @@ -1308,6 +1885,7 @@ def update(self, **kwargs): **kwargs: Component objects or configuration values to update: - Component objects: Must be created using ComponentSpec (e.g., `unet=new_unet, text_encoder=new_encoder`) - Configuration values: Simple values to update configuration settings (e.g., `requires_safety_checker=False`) + - ComponentSpec objects: if passed a ComponentSpec object, only support from_config spec, will call create() method to create it Raises: ValueError: If a component wasn't created using ComponentSpec (doesn't have `_diffusers_load_id` attribute) @@ -1331,22 +1909,52 @@ def update(self, **kwargs): unet=new_unet_model, requires_safety_checker=False ) + # update with ComponentSpec objects + loader.update( + guider=ComponentSpec(name="guider", type_hint=ClassifierFreeGuidance, config={"guidance_scale": 5.0}, default_creation_method="from_config") + ) ``` """ # extract component_specs_updates & config_specs_updates from `specs` - passed_components = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs} + passed_component_specs = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs and isinstance(kwargs[k], ComponentSpec)} + passed_components = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs and not isinstance(kwargs[k], ComponentSpec)} passed_config_values = {k: kwargs.pop(k) for k in self._config_specs if k in kwargs} for name, component in passed_components.items(): if not hasattr(component, "_diffusers_load_id"): - raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") + raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") + + # YiYi TODO: remove this if we remove support for non config mixin components in `create()` method + if component._diffusers_load_id == "null" and not isinstance(component, ConfigMixin): + raise ValueError( + f"The passed component '{name}' is not supported in update() method " + f"because it is not supported in `ComponentSpec.from_component()`. " + f"Please pass a ComponentSpec object instead." 
+ ) + current_component_spec = self._component_specs[name] + # warn if type changed + if current_component_spec.type_hint is not None and not isinstance(component, current_component_spec.type_hint): + logger.warning(f"ModularLoader.update: adding {name} with new type: {component.__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}") + # update _component_specs based on the new component + new_component_spec = ComponentSpec.from_component(name, component) + self._component_specs[name] = new_component_spec if len(kwargs) > 0: logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}") - - self.register_components(**passed_components) + created_components = {} + for name, component_spec in passed_component_specs.items(): + if component_spec.default_creation_method == "from_pretrained": + raise ValueError(f"ComponentSpec object with default_creation_method == 'from_pretrained' is not supported in update() method") + created_components[name] = component_spec.create() + current_component_spec = self._component_specs[name] + # warn if type changed + if current_component_spec.type_hint is not None and not isinstance(created_components[name], current_component_spec.type_hint): + logger.warning(f"ModularLoader.update: adding {name} with new type: {created_components[name].__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}") + # update _component_specs based on the user passed component_spec + self._component_specs[name] = component_spec + self.register_components(**passed_components, **created_components) config_to_register = {} @@ -1370,8 +1978,9 @@ def load(self, component_names: Optional[List[str]] = None, **kwargs): - a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32} - if potentially override ComponentSpec if passed a different loading field in kwargs, e.g. `repo`, `variant`, `revision`, etc. 
""" + # if not specific name, load all the components with default_creation_method == "from_pretrained" if component_names is None: - component_names = list(self._component_specs.keys()) + component_names = [name for name in self._component_specs.keys() if self._component_specs[name].default_creation_method == "from_pretrained"] elif not isinstance(component_names, list): component_names = [component_names] @@ -1396,7 +2005,7 @@ def load(self, component_names: Optional[List[str]] = None, **kwargs): # check if the default is specified component_load_kwargs[key] = value["default"] try: - components_to_register[name] = spec.create(**component_load_kwargs) + components_to_register[name] = spec.load(**component_load_kwargs) except Exception as e: logger.warning(f"Failed to create component '{name}': {e}") @@ -1424,7 +2033,7 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: @classmethod @validate_hf_hub_args - def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, **kwargs): + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs): config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs) expected_component = set(config_dict.pop("_components_names")) @@ -1435,20 +2044,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P for name, value in config_dict.items(): if name in expected_component and isinstance(value, (tuple, list)) and len(value) == 3: library, class_name, component_spec_dict = value - component_spec = cls._dict_to_component_spec(name, component_spec_dict) - component_specs.append(component_spec) + # only pick up pretrained components from the repo + if component_spec_dict.get("repo", None) is not None: + component_spec = cls._dict_to_component_spec(name, component_spec_dict) + component_specs.append(component_spec) elif name in expected_config: config_specs.append(ConfigSpec(name=name, default=value)) - - for name in expected_component: - for spec in component_specs: - if spec.name == name: - break - else: - # append a empty component spec for these not in modular_model_index - component_specs.append(ComponentSpec(name=name, default_creation_method="from_config")) - return cls(component_specs + config_specs) + + return cls(component_specs + config_specs, component_manager=component_manager, collection=collection) @staticmethod diff --git a/src/diffusers/pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py similarity index 87% rename from src/diffusers/pipelines/modular_pipeline_utils.py rename to src/diffusers/modular_pipelines/modular_pipeline_utils.py index c8064a5215aa..0c6d1b585589 100644 --- a/src/diffusers/pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -71,34 +71,31 @@ def __eq__(self, other): self.default_creation_method == other.default_creation_method) @classmethod - def from_component(cls, name: str, component: torch.nn.Module) -> Any: - """Create a ComponentSpec from a Component created by `create` method.""" + def from_component(cls, name: str, component: Any) -> Any: + """Create a ComponentSpec from a Component created by `create` or `load` method.""" if not hasattr(component, "_diffusers_load_id"): - raise ValueError("Component is not created by `create` method") + raise 
ValueError("Component is not created by `create` or `load` method") + # throw a error if component is created with `create` method but not a subclass of ConfigMixin + # YiYi TODO: remove this check if we remove support for non configmixin in `create()` method + if component._diffusers_load_id == "null" and not isinstance(component, ConfigMixin): + raise ValueError( + "We currently only support creating ComponentSpec from a component with " + "created with `ComponentSpec.load` method" + "or created with `ComponentSpec.create` and a subclass of ConfigMixin" + ) type_hint = component.__class__ + default_creation_method = "from_config" if component._diffusers_load_id == "null" else "from_pretrained" - if component._diffusers_load_id == "null" and isinstance(component, ConfigMixin): + if isinstance(component, ConfigMixin): config = component.config else: config = None load_spec = cls.decode_load_id(component._diffusers_load_id) - return cls(name=name, type_hint=type_hint, config=config, **load_spec) - - @classmethod - def from_load_id(cls, load_id: str, name: Optional[str] = None) -> Any: - """Create a ComponentSpec from a load_id string.""" - if load_id == "null": - raise ValueError("Cannot create ComponentSpec from null load_id") - - # Decode the load_id into a dictionary of loading fields - load_fields = cls.decode_load_id(load_id) - - # Create a new ComponentSpec instance with the decoded fields - return cls(name=name, **load_fields) + return cls(name=name, type_hint=type_hint, config=config, default_creation_method=default_creation_method, **load_spec) @classmethod def loading_fields(cls) -> List[str]: @@ -137,7 +134,7 @@ def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]: "revision": "revision" } If a segment value is "null", it's replaced with None. - Returns None if load_id is "null" (indicating component not loaded from pretrained). + Returns None if load_id is "null" (indicating component not created with `load` method). 
""" # Get all loading fields in order @@ -158,20 +155,12 @@ def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]: return result - # YiYi TODO: add validator - def create(self, **kwargs) -> Any: - """Create the component using the preferred creation method.""" - - # from_pretrained creation - if self.default_creation_method == "from_pretrained": - return self.create_from_pretrained(**kwargs) - elif self.default_creation_method == "from_config": - # from_config creation - return self.create_from_config(**kwargs) - else: - raise ValueError(f"Invalid creation method: {self.default_creation_method}") - def create_from_config(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any: + # YiYi TODO: I think we should only support ConfigMixin for this method (after we make guider and image_processors config mixin) + # otherwise we cannot do spec -> spec.create() -> component -> ComponentSpec.from_component(component) + # the config info is lost in the process + # remove error check in from_component spec and ModularLoader.update() if we remove support for non configmixin in `create()` method + def create(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any: """Create component using from_config with config.""" if self.type_hint is None or not isinstance(self.type_hint, type): @@ -201,34 +190,35 @@ def create_from_config(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] return component # YiYi TODO: add guard for type of model, if it is supported by from_pretrained - def create_from_pretrained(self, **kwargs) -> Any: - """Create component using from_pretrained.""" + def load(self, **kwargs) -> Any: + """Load component using from_pretrained.""" + # select loading fields from kwargs passed from user: e.g. repo, subfolder, variant, revision, note the list could change passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs} + # merge loading field value in the spec with user passed values to create load_kwargs load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()} # repo is a required argument for from_pretrained, a.k.a. 
pretrained_model_name_or_path repo = load_kwargs.pop("repo", None) if repo is None: - raise ValueError(f"`repo` info is required when using from_pretrained creation method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)") + raise ValueError(f"`repo` info is required when using `load` method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)") if self.type_hint is None: try: from diffusers import AutoModel component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs) except Exception as e: - raise ValueError(f"Error creating {self.name} without `type_hint` from pretrained: {e}") + raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}") + # update type_hint if AutoModel load successfully self.type_hint = component.__class__ else: try: component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs) except Exception as e: - raise ValueError(f"Error creating {self.name}[{self.type_hint.__name__}] from pretrained: {e}") + raise ValueError(f"Unable to load {self.name} using load method: {e}") - if repo != self.repo: - self.repo = repo - for k, v in passed_loading_kwargs.items(): - if v is not None: - setattr(self, k, v) + self.repo = repo + for k, v in load_kwargs.items(): + setattr(self, k, v) component._diffusers_load_id = self.load_id return component @@ -241,14 +231,22 @@ class ConfigSpec: name: str default: Any description: Optional[str] = None + + +# YiYi Notes: both inputs and intermediates_inputs are InputParam objects +# however some fields are not relevant for intermediates_inputs +# e.g. unlike inputs, required only used in docstring for intermediate_inputs, we do not check if a required intermediate inputs is passed +# default is not used for intermediates_inputs, we only use default from inputs, so it is ignored if it is set for intermediates_inputs +# -> should we use different class for inputs and intermediates_inputs? 
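# Illustrative sketch (not part of this patch): intended usage of the two ComponentSpec
# creation paths above. `load()` wraps `from_pretrained` using the loading fields stored
# on the spec (repo, subfolder, variant, revision), while `create()` wraps `from_config`.
# Import paths follow the new layout in this diff; the repo/subfolder values below are
# placeholders, not something the patch pins down.
import torch
from diffusers import UNet2DConditionModel
from diffusers.modular_pipelines import ComponentSpec

# from_pretrained path: loading fields can live on the spec or be passed to load()
unet_spec = ComponentSpec(
    name="unet",
    type_hint=UNet2DConditionModel,
    repo="stabilityai/stable-diffusion-xl-base-1.0",  # placeholder repo
    subfolder="unet",
)
unet = unet_spec.load(torch_dtype=torch.float16)  # extra kwargs are forwarded to from_pretrained

# the loaded component carries _diffusers_load_id, so an equivalent spec can be rebuilt from it
unet_spec_roundtrip = ComponentSpec.from_component("unet", unet)

# from_config path: create() instantiates a ConfigMixin subclass from its config, e.g. the
# guider spec shown in the ModularLoader.update() docstring above.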
@dataclass class InputParam: """Specification for an input parameter.""" - name: str + name: str = None type_hint: Any = None default: Any = None required: bool = False description: str = "" + kwargs_type: str = None # YiYi Notes: experimenting with this, not sure if we should keep it def __repr__(self): return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" @@ -260,6 +258,7 @@ class OutputParam: name: str type_hint: Any = None description: str = "" + kwargs_type: str = None def __repr__(self): return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>" @@ -320,7 +319,11 @@ def format_intermediates_short(intermediates_inputs, required_intermediates_inpu if inp.name in required_intermediates_inputs: input_parts.append(f"Required({inp.name})") else: - input_parts.append(inp.name) + if inp.name is None and inp.kwargs_type is not None: + inp_name = "*_" + inp.kwargs_type + else: + inp_name = inp.name + input_parts.append(inp_name) # Handle modified variables (appear in both inputs and outputs) inputs_set = {inp.name for inp in intermediates_inputs} diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py new file mode 100644 index 000000000000..f3f961d61a13 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py @@ -0,0 +1,51 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modular_pipeline_presets"] = ["StableDiffusionXLAutoPipeline"] + _import_structure["modular_loader"] = ["StableDiffusionXLModularLoader"] + _import_structure["encoders"] = ["StableDiffusionXLAutoIPAdapterStep", "StableDiffusionXLTextEncoderStep", "StableDiffusionXLAutoVaeEncoderStep"] + _import_structure["decoders"] = ["StableDiffusionXLAutoDecodeStep"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modular_pipeline_presets import StableDiffusionXLAutoPipeline + from .modular_loader import StableDiffusionXLModularLoader + from .encoders import StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoVaeEncoderStep + from .decoders import StableDiffusionXLAutoDecodeStep +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py new file mode 100644 index 000000000000..07f096249c0d --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py 
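# Illustrative sketch (not part of this patch): how the new `kwargs_type` field on
# InputParam/OutputParam is used by the SDXL blocks in the file below. A block tags
# related intermediate outputs with kwargs_type="guider_input_fields"; a later block
# declares a single name-less InputParam with the same kwargs_type, and
# get_block_state() then collects every intermediate registered under that tag
# (rendered as "*_guider_input_fields" in the short repr). The import path follows the
# file locations in this diff; the two lists below are a hypothetical example.
from typing import List

import torch

from diffusers.modular_pipelines.modular_pipeline_utils import InputParam, OutputParam

# producer block: group its text-embedding outputs under "guider_input_fields"
producer_intermediates_outputs: List[OutputParam] = [
    OutputParam("prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields"),
    OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields"),
]

# consumer block (e.g. a denoise step feeding the guider): one name-less InputParam
# pulls in every intermediate tagged with that kwargs_type
consumer_intermediates_inputs: List[InputParam] = [
    InputParam(kwargs_type="guider_input_fields"),
]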
@@ -0,0 +1,1764 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, List, Optional, Tuple, Union, Dict + +import PIL +import torch +from collections import OrderedDict + +from ...image_processor import VaeImageProcessor, PipelineImageInput +from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin +from ...models import ControlNetModel, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel +from ...utils import logging +from ...utils.torch_utils import randn_tensor, unwrap_module + +from ...pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel +from ...schedulers import EulerDiscreteScheduler +from ...configuration_utils import FrozenDict + +from .modular_loader import StableDiffusionXLModularLoader +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from ..modular_pipeline import ( + AutoPipelineBlocks, + ModularLoader, + PipelineBlock, + PipelineState, + SequentialPipelineBlocks, +) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + + + + +# TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that +# things like when to do guidance and how many conditions to be prepared can be determined. Currently, this is done by +# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the +# configuration of guider is. + + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. 
+ + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def prepare_latents_img2img(vae, scheduler, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True): + + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + latents_mean = latents_std = None + if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None: + latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None: + latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1) + # make sure the VAE is in float32 mode, as it overflows in float16 + if vae.config.force_upcast: + image = image.float() + vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(vae.encode(image), generator=generator) + + if vae.config.force_upcast: + vae.to(dtype) + + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + init_latents = (init_latents - latents_mean) * vae.config.scaling_factor / latents_std + else: + init_latents = vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + +class StableDiffusionXLInputStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "Input processing step that:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n" + "All input tensors are expected to have either batch_size=1 or match the batch_size\n" + "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" + "have a final batch_size of batch_size * num_images_per_prompt." + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam("prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated text embeddings. Can be generated from text_encoder step."), + InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Pre-generated negative text embeddings. Can be generated from text_encoder step."), + InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated pooled text embeddings. Can be generated from text_encoder step."), + InputParam("negative_pooled_prompt_embeds", description="Pre-generated negative pooled text embeddings. Can be generated from text_encoder step."), + InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated image embeddings for IP-Adapter. 
Can be generated from ip_adapter step."), + InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step."), + ] + + @property + def intermediates_outputs(self) -> List[str]: + return [ + OutputParam("batch_size", type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), + OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs (determined by `prompt_embeds`)"), + OutputParam("prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="text embeddings used to guide the image generation"), + OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative text embeddings used to guide the image generation"), + OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="pooled text embeddings used to guide the image generation"), + OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative pooled text embeddings used to guide the image generation"), + OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="image embeddings for IP-Adapter"), + OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="negative image embeddings for IP-Adapter"), + ] + + def check_inputs(self, components, block_state): + + if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None: + if block_state.prompt_embeds.shape != block_state.negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `negative_prompt_embeds`" + f" {block_state.negative_prompt_embeds.shape}." + ) + + if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if block_state.negative_prompt_embeds is not None and block_state.negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." 
+ ) + + if block_state.ip_adapter_embeds is not None and not isinstance(block_state.ip_adapter_embeds, list): + raise ValueError("`ip_adapter_embeds` must be a list") + + if block_state.negative_ip_adapter_embeds is not None and not isinstance(block_state.negative_ip_adapter_embeds, list): + raise ValueError("`negative_ip_adapter_embeds` must be a list") + + if block_state.ip_adapter_embeds is not None and block_state.negative_ip_adapter_embeds is not None: + for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds): + if ip_adapter_embed.shape != block_state.negative_ip_adapter_embeds[i].shape: + raise ValueError( + "`ip_adapter_embeds` and `negative_ip_adapter_embeds` must have the same shape when passed directly, but" + f" got: `ip_adapter_embeds` {ip_adapter_embed.shape} != `negative_ip_adapter_embeds`" + f" {block_state.negative_ip_adapter_embeds[i].shape}." + ) + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1) + + if block_state.negative_prompt_embeds is not None: + _, seq_len, _ = block_state.negative_prompt_embeds.shape + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1) + + block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) + + if block_state.negative_pooled_prompt_embeds is not None: + block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) + + if block_state.ip_adapter_embeds is not None: + for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds): + block_state.ip_adapter_embeds[i] = torch.cat([ip_adapter_embed] * block_state.num_images_per_prompt, dim=0) + + if block_state.negative_ip_adapter_embeds is not None: + for i, negative_ip_adapter_embed in enumerate(block_state.negative_ip_adapter_embeds): + block_state.negative_ip_adapter_embeds[i] = torch.cat([negative_ip_adapter_embed] * block_state.num_images_per_prompt, dim=0) + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "Step that sets the timesteps for the scheduler and determines the 
initial noise level (latent_timestep) for image-to-image/inpainting generation.\n" + \ + "The latent_timestep is calculated from the `strength` parameter - higher strength means starting from a noisier version of the input image." + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("denoising_end"), + InputParam("strength", default=0.3), + InputParam("denoising_start"), + # YiYi TODO: do we need num_images_per_prompt here? + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), + ] + + @property + def intermediates_outputs(self) -> List[str]: + return [ + OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time"), + OutputParam("latent_timestep", type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image generation") + ] + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps with self -> components + def get_timesteps(self, components, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + + timesteps = components.scheduler.timesteps[t_start * components.scheduler.order :] + if hasattr(components.scheduler, "set_begin_index"): + components.scheduler.set_begin_index(t_start * components.scheduler.order) + + return timesteps, num_inference_steps - t_start + + else: + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + discrete_timestep_cutoff = int( + round( + components.scheduler.config.num_train_timesteps + - (denoising_start * components.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (components.scheduler.timesteps < discrete_timestep_cutoff).sum().item() + if components.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. 
By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + t_start = len(components.scheduler.timesteps) - num_inference_steps + timesteps = components.scheduler.timesteps[t_start:] + if hasattr(components.scheduler, "set_begin_index"): + components.scheduler.set_begin_index(t_start) + return timesteps, num_inference_steps + + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.device = components._execution_device + + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, block_state.num_inference_steps, block_state.device, block_state.timesteps, block_state.sigmas + ) + + def denoising_value_valid(dnv): + return isinstance(dnv, float) and 0 < dnv < 1 + + block_state.timesteps, block_state.num_inference_steps = self.get_timesteps( + components, + block_state.num_inference_steps, + block_state.strength, + block_state.device, + denoising_start=block_state.denoising_start if denoising_value_valid(block_state.denoising_start) else None, + ) + block_state.latent_timestep = block_state.timesteps[:1].repeat(block_state.batch_size * block_state.num_images_per_prompt) + + if block_state.denoising_end is not None and isinstance(block_state.denoising_end, float) and block_state.denoising_end > 0 and block_state.denoising_end < 1: + block_state.discrete_timestep_cutoff = int( + round( + components.scheduler.config.num_train_timesteps + - (block_state.denoising_end * components.scheduler.config.num_train_timesteps) + ) + ) + block_state.num_inference_steps = len(list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps))) + block_state.timesteps = block_state.timesteps[:block_state.num_inference_steps] + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLSetTimestepsStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "Step that sets the scheduler's timesteps for inference" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("denoising_end"), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time")] + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.device = components._execution_device + + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, block_state.num_inference_steps, block_state.device, block_state.timesteps, block_state.sigmas + ) + + if block_state.denoising_end is not None and isinstance(block_state.denoising_end, float) and block_state.denoising_end > 0 and block_state.denoising_end < 1: + 
block_state.discrete_timestep_cutoff = int( + round( + components.scheduler.config.num_train_timesteps + - (block_state.denoising_end * components.scheduler.config.num_train_timesteps) + ) + ) + block_state.num_inference_steps = len(list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps))) + block_state.timesteps = block_state.timesteps[:block_state.num_inference_steps] + + self.add_block_state(state, block_state) + return components, state + + +class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "Step that prepares the latents for the inpainting process" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("latents"), + InputParam("num_images_per_prompt", default=1), + InputParam("denoising_start"), + InputParam( + "strength", + default=0.9999, + description="Conceptually, indicates how much to transform the reference `image` (the masked portion of image for inpainting). Must be between 0 and 1. `image` " + "will be used as a starting point, adding more noise to it the larger the `strength`. The number of " + "denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will " + "be maximum and the denoising process will run for the full number of iterations specified in " + "`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of " + "`denoising_start` being declared as an integer, the value of `strength` will be ignored." + ), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam("generator"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + InputParam( + "latent_timestep", + required=True, + type_hint=torch.Tensor, + description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step." + ), + InputParam( + "image_latents", + required=True, + type_hint=torch.Tensor, + description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step." + ), + InputParam( + "mask", + required=True, + type_hint=torch.Tensor, + description="The mask for the inpainting generation. Can be generated in vae_encode step." + ), + InputParam( + "masked_image_latents", + type_hint=torch.Tensor, + description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step." 
+ ), + InputParam( + "dtype", + type_hint=torch.dtype, + description="The dtype of the model inputs" + ) + ] + + @property + def intermediates_outputs(self) -> List[str]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"), + OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for inpainting generation"), + OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)"), + OutputParam("noise", type_hint=torch.Tensor, description="The noise added to the image latents, used for inpainting generation")] + + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components + # YiYi TODO: update the _encode_vae_image so that we can use #Coped from + @staticmethod + def _encode_vae_image(components, image: torch.Tensor, generator: torch.Generator): + + latents_mean = latents_std = None + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) + + dtype = image.dtype + if components.vae.config.force_upcast: + image = image.float() + components.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(components.vae.encode(image), generator=generator) + + if components.vae.config.force_upcast: + components.vae.to(dtype) + + image_latents = image_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) + latents_std = latents_std.to(device=image_latents.device, dtype=dtype) + image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std + else: + image_latents = components.vae.config.scaling_factor * image_latents + + return image_latents + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents adding components as first argument + def prepare_latents_inpaint( + self, + components, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + add_noise=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // components.vae_scale_factor, + int(width) // components.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." 
+ "However, either the image or the noise timestep has not been provided." + ) + + if image.shape[1] == 4: + image_latents = image.to(device=device, dtype=dtype) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + elif return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(components, image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None and add_noise: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else components.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * components.scheduler.init_noise_sigma if is_strength_max else latents + elif add_noise: + noise = latents.to(device) + latents = noise * components.scheduler.init_noise_sigma + else: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = image_latents.to(device) + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents + # do not accept do_classifier_free_guidance + def prepare_mask_latents( + self, components, mask, masked_image, batch_size, height, width, dtype, device, generator + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + + if masked_image is not None and masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = None + + if masked_image is not None: + if masked_image_latents is None: + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator) + + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." 
+ ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + block_state.device = components._execution_device + + block_state.is_strength_max = block_state.strength == 1.0 + + # for non-inpainting specific unet, we do not need masked_image_latents + if hasattr(components,"unet") and components.unet is not None: + if components.unet.config.in_channels == 4: + block_state.masked_image_latents = None + + block_state.add_noise = True if block_state.denoising_start is None else False + + block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor + block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor + + block_state.latents, block_state.noise = self.prepare_latents_inpaint( + components, + block_state.batch_size * block_state.num_images_per_prompt, + components.num_channels_latents, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.latents, + image=block_state.image_latents, + timestep=block_state.latent_timestep, + is_strength_max=block_state.is_strength_max, + add_noise=block_state.add_noise, + return_noise=True, + return_image_latents=False, + ) + + # 7. Prepare mask latent variables + block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents( + components, + block_state.mask, + block_state.masked_image_latents, + block_state.batch_size * block_state.num_images_per_prompt, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + ) + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "Step that prepares the latents for the image-to-image generation process" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("latents"), + InputParam("num_images_per_prompt", default=1), + InputParam("denoising_start"), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam("generator"), + InputParam("latent_timestep", required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step."), + InputParam("image_latents", required=True, type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step."), + InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. 
Can be generated in input step."), + InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs")] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process")] + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + block_state.device = components._execution_device + block_state.add_noise = True if block_state.denoising_start is None else False + if block_state.latents is None: + block_state.latents = prepare_latents_img2img( + components.vae, + components.scheduler, + block_state.image_latents, + block_state.latent_timestep, + block_state.batch_size, + block_state.num_images_per_prompt, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.add_noise, + ) + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLPrepareLatentsStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "Prepare latents step that prepares the latents for the text-to-image generation process" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("height"), + InputParam("width"), + InputParam("latents"), + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam("generator"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + InputParam( + "dtype", + type_hint=torch.dtype, + description="The dtype of the model inputs" + ) + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "latents", + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process" + ) + ] + + + @staticmethod + def check_inputs(components, block_state): + if ( + block_state.height is not None + and block_state.height % components.vae_scale_factor != 0 + or block_state.width is not None + and block_state.width % components.vae_scale_factor != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {block_state.height} and {block_state.width}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with self -> components + @staticmethod + def prepare_latents(components, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // components.vae_scale_factor, + int(width) // components.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * components.scheduler.init_noise_sigma + return latents + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if block_state.dtype is None: + block_state.dtype = components.vae.dtype + + block_state.device = components._execution_device + + self.check_inputs(components, block_state) + + block_state.height = block_state.height or components.default_sample_size * components.vae_scale_factor + block_state.width = block_state.width or components.default_sample_size * components.vae_scale_factor + block_state.num_channels_latents = components.num_channels_latents + block_state.latents = self.prepare_latents( + components, + block_state.batch_size * block_state.num_images_per_prompt, + block_state.num_channels_latents, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.latents, + ) + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [ConfigSpec("requires_aesthetics_score", False),] + + @property + def description(self) -> str: + return ( + "Step that prepares the additional conditioning for the image-to-image/inpainting generation process" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("original_size"), + InputParam("target_size"), + InputParam("negative_original_size"), + InputParam("negative_target_size"), + InputParam("crops_coords_top_left", default=(0, 0)), + InputParam("negative_crops_coords_top_left", default=(0, 0)), + InputParam("num_images_per_prompt", default=1), + InputParam("aesthetic_score", default=6.0), + InputParam("negative_aesthetic_score", default=2.0), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam("latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."), + InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step."), + InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. 
Can be generated in input step."), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), + OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), + OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self -> components + @staticmethod + def _get_add_time_ids_img2img( + components, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if components.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + passed_add_embed_dim = ( + components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == components.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == components.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
+ ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + @staticmethod + def get_guidance_scale_embedding( + w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + block_state.vae_scale_factor = components.vae_scale_factor + + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * block_state.vae_scale_factor + block_state.width = block_state.width * block_state.vae_scale_factor + + block_state.original_size = block_state.original_size or (block_state.height, block_state.width) + block_state.target_size = block_state.target_size or (block_state.height, block_state.width) + + block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) + + if block_state.negative_original_size is None: + block_state.negative_original_size = block_state.original_size + if block_state.negative_target_size is None: + block_state.negative_target_size = block_state.target_size + + block_state.add_time_ids, block_state.negative_add_time_ids = self._get_add_time_ids_img2img( + components, + block_state.original_size, + block_state.crops_coords_top_left, + block_state.target_size, + block_state.aesthetic_score, + block_state.negative_aesthetic_score, + block_state.negative_original_size, + block_state.negative_crops_coords_top_left, + block_state.negative_target_size, + dtype=block_state.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=block_state.text_encoder_projection_dim, + ) + block_state.add_time_ids = block_state.add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) + block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) + + # Optionally get Guidance Scale Embedding for LCM + block_state.timestep_cond = None + if ( + hasattr(components, "unet") + and components.unet is not None + and 
components.unet.config.time_cond_proj_dim is not None + ): + # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! + block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(block_state.batch_size * block_state.num_images_per_prompt) + block_state.timestep_cond = self.get_guidance_scale_embedding( + block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim + ).to(device=block_state.device, dtype=block_state.latents.dtype) + + self.add_block_state(state, block_state) + return components, state + + +class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "Step that prepares the additional conditioning for the text-to-image generation process" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("original_size"), + InputParam("target_size"), + InputParam("negative_original_size"), + InputParam("negative_target_size"), + InputParam("crops_coords_top_left", default=(0, 0)), + InputParam("negative_crops_coords_top_left", default=(0, 0)), + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + InputParam( + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step." + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), + OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), + OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with self -> components + @staticmethod + def _get_add_time_ids( + components, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
+ ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + @staticmethod + def get_guidance_scale_embedding( + w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor + + block_state.original_size = block_state.original_size or (block_state.height, block_state.width) + block_state.target_size = block_state.target_size or (block_state.height, block_state.width) + + block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) + + block_state.add_time_ids = self._get_add_time_ids( + components, + block_state.original_size, + block_state.crops_coords_top_left, + block_state.target_size, + block_state.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=block_state.text_encoder_projection_dim, + ) + if block_state.negative_original_size is not None and block_state.negative_target_size is not None: + block_state.negative_add_time_ids = self._get_add_time_ids( + components, + block_state.negative_original_size, + block_state.negative_crops_coords_top_left, + block_state.negative_target_size, + block_state.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=block_state.text_encoder_projection_dim, + ) + else: + block_state.negative_add_time_ids = block_state.add_time_ids + + block_state.add_time_ids = block_state.add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) + block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) + + # Optionally get Guidance Scale Embedding for LCM + block_state.timestep_cond = None + if ( + hasattr(components, "unet") + and components.unet is not None + and components.unet.config.time_cond_proj_dim is not None + ): + # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. 
Guider scales should be different from this! + block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(block_state.batch_size * block_state.num_images_per_prompt) + block_state.timestep_cond = self.get_guidance_scale_embedding( + block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim + ).to(device=block_state.device, dtype=block_state.latents.dtype) + + self.add_block_state(state, block_state) + return components, state + + +class StableDiffusionXLControlNetInputStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("controlnet", ControlNetModel), + ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"), + ] + + @property + def description(self) -> str: + return "step that prepare inputs for controlnet" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("control_image", required=True), + InputParam("control_guidance_start", default=0.0), + InputParam("control_guidance_end", default=1.0), + InputParam("controlnet_conditioning_scale", default=1.0), + InputParam("guess_mode", default=False), + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "crops_coords", + type_hint=Optional[Tuple[int]], + description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." + ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [ + OutputParam("controlnet_cond", type_hint=torch.Tensor, description="The processed control image"), + OutputParam("control_guidance_start", type_hint=List[float], description="The controlnet guidance start values"), + OutputParam("control_guidance_end", type_hint=List[float], description="The controlnet guidance end values"), + OutputParam("conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), + OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), + OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), + ] + + + + # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + # 1. return image without apply any guidance + # 2. 
add crops_coords and resize_mode to preprocess() + @staticmethod + def prepare_control_image( + components, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + crops_coords=None, + ): + if crops_coords is not None: + image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) + else: + image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + + image_batch_size = image.shape[0] + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + image = image.to(device=device, dtype=dtype) + return image + + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + + block_state = self.get_block_state(state) + + # (1) prepare controlnet inputs + block_state.device = components._execution_device + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor + + controlnet = unwrap_module(components.controlnet) + + # (1.1) + # control_guidance_start/control_guidance_end (align format) + if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list): + block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start] + elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list): + block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end] + elif not isinstance(block_state.control_guidance_start, list) and not isinstance(block_state.control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + block_state.control_guidance_start, block_state.control_guidance_end = ( + mult * [block_state.control_guidance_start], + mult * [block_state.control_guidance_end], + ) + + # (1.2) + # controlnet_conditioning_scale (align format) + if isinstance(controlnet, MultiControlNetModel) and isinstance(block_state.controlnet_conditioning_scale, float): + block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * len(controlnet.nets) + + # (1.3) + # global_pool_conditions + block_state.global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + # (1.4) + # guess_mode + block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions + + # (1.5) + # control_image + if isinstance(controlnet, ControlNetModel): + block_state.control_image = self.prepare_control_image( + components, + image=block_state.control_image, + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=block_state.device, + dtype=controlnet.dtype, + crops_coords=block_state.crops_coords, + ) + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in block_state.control_image: + control_image = 
self.prepare_control_image( + components, + image=control_image_, + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=block_state.device, + dtype=controlnet.dtype, + crops_coords=block_state.crops_coords, + ) + + control_images.append(control_image) + + block_state.control_image = control_images + else: + assert False + + # (1.6) + # controlnet_keep + block_state.controlnet_keep = [] + for i in range(len(block_state.timesteps)): + keeps = [ + 1.0 - float(i / len(block_state.timesteps) < s or (i + 1) / len(block_state.timesteps) > e) + for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end) + ] + block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + block_state.controlnet_cond = block_state.control_image + block_state.conditioning_scale = block_state.controlnet_conditioning_scale + + + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("controlnet", ControlNetUnionModel), + ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"), + ] + + @property + def description(self) -> str: + return "step that prepares inputs for the ControlNetUnion model" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("control_image", required=True), + InputParam("control_mode", required=True), + InputParam("control_guidance_start", default=0.0), + InputParam("control_guidance_end", default=1.0), + InputParam("controlnet_conditioning_scale", default=1.0), + InputParam("guess_mode", default=False), + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Used to determine the shape of the control images. Can be generated in prepare_latent step." + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + InputParam( + "dtype", + required=True, + type_hint=torch.dtype, + description="The dtype of model tensor inputs. Can be generated in input step." + ), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Needed to determine `controlnet_keep`. Can be generated in set_timesteps step." + ), + InputParam( + "crops_coords", + type_hint=Optional[Tuple[int]], + description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." 
+ ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [ + OutputParam("controlnet_cond", type_hint=List[torch.Tensor], description="The processed control images"), + OutputParam("control_type_idx", type_hint=List[int], description="The control mode indices", kwargs_type="controlnet_kwargs"), + OutputParam("control_type", type_hint=torch.Tensor, description="The control type tensor that specifies which control type is active", kwargs_type="controlnet_kwargs"), + OutputParam("control_guidance_start", type_hint=float, description="The controlnet guidance start value"), + OutputParam("control_guidance_end", type_hint=float, description="The controlnet guidance end value"), + OutputParam("conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), + OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), + OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), + ] + + # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + # 1. return image without apply any guidance + # 2. add crops_coords and resize_mode to preprocess() + @staticmethod + def prepare_control_image( + components, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + crops_coords=None, + ): + if crops_coords is not None: + image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) + else: + image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + + image_batch_size = image.shape[0] + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + image = image.to(device=device, dtype=dtype) + return image + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + + block_state = self.get_block_state(state) + + controlnet = unwrap_module(components.controlnet) + + device = components._execution_device + dtype = block_state.dtype or components.controlnet.dtype + + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor + + + # control_guidance_start/control_guidance_end (align format) + if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list): + block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start] + elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list): + block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end] + + # guess_mode + block_state.global_pool_conditions = controlnet.config.global_pool_conditions + block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions + + # control_image + if not isinstance(block_state.control_image, list): + block_state.control_image = [block_state.control_image] + # control_mode + if not isinstance(block_state.control_mode, list): + block_state.control_mode = 
[block_state.control_mode] + + if len(block_state.control_image) != len(block_state.control_mode): + raise ValueError("Expected len(control_image) == len(control_type)") + + # control_type + block_state.num_control_type = controlnet.config.num_control_type + block_state.control_type = [0 for _ in range(block_state.num_control_type)] + for control_idx in block_state.control_mode: + block_state.control_type[control_idx] = 1 + block_state.control_type = torch.Tensor(block_state.control_type) + + block_state.control_type = block_state.control_type.reshape(1, -1).to(device, dtype=block_state.dtype) + repeat_by = block_state.batch_size * block_state.num_images_per_prompt // block_state.control_type.shape[0] + block_state.control_type = block_state.control_type.repeat_interleave(repeat_by, dim=0) + + # prepare control_image + for idx, _ in enumerate(block_state.control_image): + block_state.control_image[idx] = self.prepare_control_image( + components, + image=block_state.control_image[idx], + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=device, + dtype=dtype, + crops_coords=block_state.crops_coords, + ) + block_state.height, block_state.width = block_state.control_image[idx].shape[-2:] + + # controlnet_keep + block_state.controlnet_keep = [] + for i in range(len(block_state.timesteps)): + block_state.controlnet_keep.append( + 1.0 + - float(i / len(block_state.timesteps) < block_state.control_guidance_start or (i + 1) / len(block_state.timesteps) > block_state.control_guidance_end) + ) + block_state.control_type_idx = block_state.control_mode + block_state.controlnet_cond = block_state.control_image + block_state.conditioning_scale = block_state.controlnet_conditioning_scale + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLControlNetAutoInput(AutoPipelineBlocks): + + block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetInputStep] + block_names = ["controlnet_union", "controlnet"] + block_trigger_inputs = ["control_mode", "control_image"] + + + +# Before denoise +class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLInputStep, StableDiffusionXLSetTimestepsStep, StableDiffusionXLPrepareLatentsStep, StableDiffusionXLPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput] + block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"] + + @property + def description(self): + return "Before denoise step that prepare the inputs for the denoise step.\n" + \ + "This is a sequential pipeline blocks:\n" + \ + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ + " - `StableDiffusionXLSetTimestepsStep` is used to set the timesteps\n" + \ + " - `StableDiffusionXLPrepareLatentsStep` is used to prepare the latents\n" + \ + " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \ + " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" + + +class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLImg2ImgPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput] + block_names 
= ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"] + + @property + def description(self): + return "Before denoise step that prepare the inputs for the denoise step for img2img task.\n" + \ + "This is a sequential pipeline blocks:\n" + \ + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ + " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + \ + " - `StableDiffusionXLImg2ImgPrepareLatentsStep` is used to prepare the latents\n" + \ + " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \ + " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" + + +class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLInpaintPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput] + block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"] + + @property + def description(self): + return "Before denoise step that prepare the inputs for the denoise step for inpainting task.\n" + \ + "This is a sequential pipeline blocks:\n" + \ + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ + " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + \ + " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n" + \ + " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \ + " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" + + +class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintBeforeDenoiseStep, StableDiffusionXLImg2ImgBeforeDenoiseStep, StableDiffusionXLBeforeDenoiseStep] + block_names = ["inpaint", "img2img", "text2img"] + block_trigger_inputs = ["mask", "image_latents", None] + + @property + def description(self): + return "Before denoise step that prepare the inputs for the denoise step.\n" + \ + "This is an auto pipeline block that works for text2img, img2img and inpainting tasks as well as controlnet, controlnet_union.\n" + \ + " - `StableDiffusionXLInpaintBeforeDenoiseStep` (inpaint) is used when both `mask` and `image_latents` are provided.\n" + \ + " - `StableDiffusionXLImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n" + \ + " - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when both `image_latents` and `mask` are not provided.\n" + \ + " - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.\n" + \ + " - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided." + diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py new file mode 100644 index 000000000000..ca848e20984f --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py @@ -0,0 +1,215 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, List, Optional, Tuple, Union, Dict + +import PIL +import torch +import numpy as np +from collections import OrderedDict + +from ...image_processor import VaeImageProcessor, PipelineImageInput +from ...models import AutoencoderKL +from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor +from ...utils import logging + +from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from ...configuration_utils import FrozenDict + +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from ..modular_pipeline import ( + AutoPipelineBlocks, + PipelineBlock, + PipelineState, + SequentialPipelineBlocks, +) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + + + +class StableDiffusionXLDecodeStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config"), + ] + + @property + def description(self) -> str: + return "Step that decodes the denoised latents into images" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("output_type", default="pil"), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [InputParam("latents", required=True, type_hint=torch.Tensor, description="The denoised latents from the denoising step")] + + @property + def intermediates_outputs(self) -> List[str]: + return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array")] + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self -> components + @staticmethod + def upcast_vae(components): + dtype = components.vae.dtype + components.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + components.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + components.vae.post_quant_conv.to(dtype) + components.vae.decoder.conv_in.to(dtype) + components.vae.decoder.mid_block.to(dtype) + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if not block_state.output_type == "latent": + latents = block_state.latents + # make sure the VAE is in float32 mode, as it overflows in float16 + block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast + + if block_state.needs_upcasting: + self.upcast_vae(components) + latents = latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype) + elif 
latents.dtype != components.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + components.vae = components.vae.to(latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + block_state.has_latents_mean = ( + hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None + ) + block_state.has_latents_std = ( + hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None + ) + if block_state.has_latents_mean and block_state.has_latents_std: + block_state.latents_mean = ( + torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + block_state.latents_std = ( + torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean + else: + latents = latents / components.vae.config.scaling_factor + + block_state.images = components.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if block_state.needs_upcasting: + components.vae.to(dtype=torch.float16) + else: + block_state.images = block_state.latents + + # apply watermark if available + if hasattr(components, "watermark") and components.watermark is not None: + block_state.images = components.watermark.apply_watermark(block_state.images) + + block_state.images = components.image_processor.postprocess(block_state.images, output_type=block_state.output_type) + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return "A post-processing step that overlays the mask on the image (inpainting task only).\n" + \ + "only needed when you are using the `padding_mask_crop` option when pre-processing the image and mask" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("image", required=True), + InputParam("mask_image", required=True), + InputParam("padding_mask_crop"), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam("images", required=True, type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images from the decode step"), + InputParam("crops_coords", required=True, type_hint=Tuple[int, int], description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. 
Can be generated in vae_encode step.") + ] + + @property + def intermediates_outputs(self) -> List[str]: + return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images with the mask overlayed")] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if block_state.padding_mask_crop is not None and block_state.crops_coords is not None: + block_state.images = [components.image_processor.apply_overlay(block_state.mask_image, block_state.image, i, block_state.crops_coords) for i in block_state.images] + + self.add_block_state(state, block_state) + + return components, state + + + +class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLDecodeStep, StableDiffusionXLInpaintOverlayMaskStep] + block_names = ["decode", "mask_overlay"] + + @property + def description(self): + return "Inpaint decode step that decode the denoised latents into images outputs.\n" + \ + "This is a sequential pipeline blocks:\n" + \ + " - `StableDiffusionXLDecodeStep` is used to decode the denoised latents into images\n" + \ + " - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image" + + +class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep] + block_names = ["inpaint", "non-inpaint"] + block_trigger_inputs = ["padding_mask_crop", None] + + @property + def description(self): + return "Decode step that decode the denoised latents into images outputs.\n" + \ + "This is an auto pipeline block that works for inpainting and non-inpainting tasks.\n" + \ + " - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n" + \ + " - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided." + + diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py new file mode 100644 index 000000000000..bc567a6b034f --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py @@ -0,0 +1,1334 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
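The decode blocks above and the denoise blocks added below share the same composition pattern: `SequentialPipelineBlocks` runs its sub-blocks in declaration order, while `AutoPipelineBlocks` picks a single sub-block at runtime from `block_trigger_inputs`, with a `None` trigger marking the default branch (here, passing `padding_mask_crop` routes to the inpaint decode path, otherwise the plain decode runs). The sketch below restates that selection rule in isolation so it can be read without the rest of the PR; `pick_block` is a hypothetical helper written only for this illustration, not an API added by this diff.

from typing import Any, Dict, List, Optional

def pick_block(block_names: List[str],
               block_trigger_inputs: List[Optional[str]],
               provided_inputs: Dict[str, Any]) -> str:
    # Conceptual restatement of AutoPipelineBlocks-style dispatch: the first entry
    # whose trigger input was actually supplied wins; a None trigger is the fallback.
    for name, trigger in zip(block_names, block_trigger_inputs):
        if trigger is None:
            return name
        if provided_inputs.get(trigger) is not None:
            return name
    raise ValueError("no matching sub-block and no default branch declared")

# Mirrors StableDiffusionXLAutoDecodeStep above: names ["inpaint", "non-inpaint"]
# with triggers ["padding_mask_crop", None].
assert pick_block(["inpaint", "non-inpaint"], ["padding_mask_crop", None],
                  {"padding_mask_crop": 32}) == "inpaint"
assert pick_block(["inpaint", "non-inpaint"], ["padding_mask_crop", None],
                  {"output_type": "pil"}) == "non-inpaint"

The denoise steps defined in denoise.py below follow the same shape, dispatching on `mask` and `controlnet_cond` instead of `padding_mask_crop`.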
+ +import inspect +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from tqdm.auto import tqdm + +from ...configuration_utils import FrozenDict +from ...models import ControlNetModel, UNet2DConditionModel +from ...schedulers import EulerDiscreteScheduler +from ...utils import logging +from ...utils.torch_utils import unwrap_module + +from ...guiders import ClassifierFreeGuidance +from .modular_loader import StableDiffusionXLModularLoader +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from ..modular_pipeline import ( + PipelineBlock, + PipelineState, + AutoPipelineBlocks, + LoopSequentialPipelineBlocks, + BlockState, +) +from dataclasses import asdict + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + + +# YiYi experimenting composible denoise loop +# loop step (1): prepare latent input for denoiser +class StableDiffusionXLDenoiseLoopBeforeDenoiser(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "step within the denoising loop that prepare the latent input for the denoiser" + + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + ] + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + + block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + + return components, block_state + +# loop step (1): prepare latent input for denoiser (with inpainting) +class StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def description(self) -> str: + return "step within the denoising loop that prepare the latent input for the denoiser" + + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + InputParam( + "mask", + type_hint=Optional[torch.Tensor], + description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." + ), + InputParam( + "masked_image_latents", + type_hint=Optional[torch.Tensor], + description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." 
+ ), + ] + + + @staticmethod + def check_inputs(components, block_state): + + num_channels_unet = components.num_channels_unet + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + if block_state.mask is None or block_state.masked_image_latents is None: + raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") + num_channels_latents = block_state.latents.shape[1] + num_channels_mask = block_state.mask.shape[1] + num_channels_masked_image = block_state.masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: + raise ValueError( + f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" + f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `components.unet` or your `mask_image` or `image` input." + ) + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + + self.check_inputs(components, block_state) + + block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + if components.num_channels_unet == 9: + block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) + + + return components, block_state + +# loop step (2): denoise the latents with guidance +class StableDiffusionXLDenoiseLoopDenoiser(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoise the latents with guidance" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("cross_attention_kwargs"), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "timestep_cond", + type_hint=Optional[torch.Tensor], + description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." + ), + InputParam( + kwargs_type="guider_input_fields", + description=( + "All conditional model inputs that need to be prepared with guider. " + "It should contain prompt_embeds/negative_prompt_embeds, " + "add_time_ids/negative_add_time_ids, " + "pooled_prompt_embeds/negative_pooled_prompt_embeds, " + "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)." + "please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ) + ), + + ] + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int) -> PipelineState: + + # Map the keys we'll see on each `guider_state_batch` (e.g. 
guider_state_batch.prompt_embeds) + # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) + guider_input_fields ={ + "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"), + "time_ids": ("add_time_ids", "negative_add_time_ids"), + "text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), + "image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"), + } + + + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # Prepare mini‐batches according to guidance method and `guider_input_fields` + # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds. + # e.g. for CFG, we prepare two batches: one for uncond, one for cond + # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds + # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds + guider_state = components.guider.prepare_inputs(block_state, guider_input_fields) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + components.guider.prepare_models(components.unet) + cond_kwargs = guider_state_batch.as_dict() + cond_kwargs = {k:v for k,v in cond_kwargs.items() if k in guider_input_fields} + prompt_embeds = cond_kwargs.pop("prompt_embeds") + + # Predict the noise residual + # store the noise_pred in guider_state_batch so that we can apply guidance across all batches + guider_state_batch.noise_pred = components.unet( + block_state.scaled_latents, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=block_state.timestep_cond, + cross_attention_kwargs=block_state.cross_attention_kwargs, + added_cond_kwargs=cond_kwargs, + return_dict=False, + )[0] + components.guider.cleanup_models(components.unet) + + # Perform guidance + block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state) + + return components, block_state + +# loop step (2): denoise the latents with guidance (with controlnet) +class StableDiffusionXLControlNetDenoiseLoopDenoiser(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec("controlnet", ControlNetModel), + ] + + @property + def description(self) -> str: + return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("cross_attention_kwargs"), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "controlnet_cond", + required=True, + type_hint=torch.Tensor, + description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "conditioning_scale", + type_hint=float, + description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "guess_mode", + required=True, + type_hint=bool, + description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." 
+ ), + InputParam( + "controlnet_keep", + required=True, + type_hint=List[float], + description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "timestep_cond", + type_hint=Optional[torch.Tensor], + description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + kwargs_type="guider_input_fields", + description=( + "All conditional model inputs that need to be prepared with guider. " + "It should contain prompt_embeds/negative_prompt_embeds, " + "add_time_ids/negative_add_time_ids, " + "pooled_prompt_embeds/negative_pooled_prompt_embeds, " + "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)." + "please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ) + ), + InputParam( + kwargs_type="controlnet_kwargs", + description=( + "additional kwargs for controlnet (e.g. control_type_idx and control_type from the controlnet union input step )" + "please add `kwargs_type=controlnet_kwargs` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ) + ) + ] + + @staticmethod + def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + + accepted_kwargs = set(inspect.signature(func).parameters.keys()) + extra_kwargs = {} + for key, value in kwargs.items(): + if key in accepted_kwargs and key not in exclude_kwargs: + extra_kwargs[key] = value + + return extra_kwargs + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + + extra_controlnet_kwargs = self.prepare_extra_kwargs(components.controlnet.forward, **block_state.controlnet_kwargs) + + # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) + # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) + guider_input_fields ={ + "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"), + "time_ids": ("add_time_ids", "negative_add_time_ids"), + "text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), + "image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"), + } + + + # cond_scale for the timestep (controlnet input) + if isinstance(block_state.controlnet_keep[i], list): + block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] + else: + controlnet_cond_scale = block_state.conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + block_state.cond_scale = controlnet_cond_scale * block_state.controlnet_keep[i] + + # default controlnet output/unet input for guess mode + conditional path + block_state.down_block_res_samples_zeros = None + block_state.mid_block_res_sample_zeros = None + + # guided denoiser step + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # Prepare mini‐batches according to guidance method and `guider_input_fields` + # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds. + # e.g. 
for CFG, we prepare two batches: one for uncond, one for cond + # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds + # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds + guider_state = components.guider.prepare_inputs(block_state, guider_input_fields) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + components.guider.prepare_models(components.unet) + + # Prepare additional conditionings + added_cond_kwargs = { + "text_embeds": guider_state_batch.text_embeds, + "time_ids": guider_state_batch.time_ids, + } + if hasattr(guider_state_batch, "image_embeds") and guider_state_batch.image_embeds is not None: + added_cond_kwargs["image_embeds"] = guider_state_batch.image_embeds + + # Prepare controlnet additional conditionings + controlnet_added_cond_kwargs = { + "text_embeds": guider_state_batch.text_embeds, + "time_ids": guider_state_batch.time_ids, + } + # run controlnet for the guidance batch + if block_state.guess_mode and not components.guider.is_conditional: + # guider always run uncond batch first, so these tensors should be set already + down_block_res_samples = block_state.down_block_res_samples_zeros + mid_block_res_sample = block_state.mid_block_res_sample_zeros + else: + down_block_res_samples, mid_block_res_sample = components.controlnet( + block_state.scaled_latents, + t, + encoder_hidden_states=guider_state_batch.prompt_embeds, + controlnet_cond=block_state.controlnet_cond, + conditioning_scale=block_state.cond_scale, + guess_mode=block_state.guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + **extra_controlnet_kwargs, + ) + + # assign it to block_state so it will be available for the uncond guidance batch + if block_state.down_block_res_samples_zeros is None: + block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in down_block_res_samples] + if block_state.mid_block_res_sample_zeros is None: + block_state.mid_block_res_sample_zeros = torch.zeros_like(mid_block_res_sample) + + # Predict the noise + # store the noise_pred in guider_state_batch so we can apply guidance across all batches + guider_state_batch.noise_pred = components.unet( + block_state.scaled_latents, + t, + encoder_hidden_states=guider_state_batch.prompt_embeds, + timestep_cond=block_state.timestep_cond, + cross_attention_kwargs=block_state.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + )[0] + components.guider.cleanup_models(components.unet) + + # Perform guidance + block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state) + + return components, block_state + +# loop step (3): scheduler step to update latents +class StableDiffusionXLDenoiseLoopAfterDenoiser(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. 
Using ControlNet to condition the denoising process" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("eta", default=0.0), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam("generator"), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + #YiYi TODO: move this out of here + @staticmethod + def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + + accepted_kwargs = set(inspect.signature(func).parameters.keys()) + extra_kwargs = {} + for key, value in kwargs.items(): + if key in accepted_kwargs and key not in exclude_kwargs: + extra_kwargs[key] = value + + return extra_kwargs + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + + # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) + + + # Perform scheduler step using the predicted output + block_state.latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **block_state.scheduler_step_kwargs, return_dict=False)[0] + + if block_state.latents.dtype != block_state.latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + block_state.latents = block_state.latents.to(block_state.latents_dtype) + + return components, block_state + +# loop step (3): scheduler step to update latents (with inpainting) +class StableDiffusionXLInpaintDenoiseLoopAfterDenoiser(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def description(self) -> str: + return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("eta", default=0.0), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam("generator"), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "mask", + type_hint=Optional[torch.Tensor], + description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." + ), + InputParam( + "noise", + type_hint=Optional[torch.Tensor], + description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." + ), + InputParam( + "image_latents", + type_hint=Optional[torch.Tensor], + description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." 
+ ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + @staticmethod + def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + + accepted_kwargs = set(inspect.signature(func).parameters.keys()) + extra_kwargs = {} + for key, value in kwargs.items(): + if key in accepted_kwargs and key not in exclude_kwargs: + extra_kwargs[key] = value + + return extra_kwargs + + def check_inputs(self, components, block_state): + if components.num_channels_unet == 4: + if block_state.image_latents is None: + raise ValueError(f"image_latents is required for this step {self.__class__.__name__}") + if block_state.mask is None: + raise ValueError(f"mask is required for this step {self.__class__.__name__}") + if block_state.noise is None: + raise ValueError(f"noise is required for this step {self.__class__.__name__}") + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + + self.check_inputs(components, block_state) + + # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) + + + # Perform scheduler step using the predicted output + block_state.latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **block_state.scheduler_step_kwargs, return_dict=False)[0] + + if block_state.latents.dtype != block_state.latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + block_state.latents = block_state.latents.to(block_state.latents_dtype) + + # adjust latent for inpainting + if components.num_channels_unet == 4: + block_state.init_latents_proper = block_state.image_latents + if i < len(block_state.timesteps) - 1: + block_state.noise_timestep = block_state.timesteps[i + 1] + block_state.init_latents_proper = components.scheduler.add_noise( + block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) + ) + + block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents + + + + return components, block_state + + +# the loop wrapper that iterates over the timesteps +class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks): + + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" + ) + + @property + def loop_expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def loop_intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." 
+ ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." + ), + ] + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False + if block_state.disable_guidance: + components.guider.disable() + else: + components.guider.enable() + + block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): + progress_bar.update() + + self.add_block_state(state, block_state) + + return components, state + + +# composing the denoising loops +class StableDiffusionXLDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [StableDiffusionXLDenoiseLoopBeforeDenoiser, StableDiffusionXLDenoiseLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + +# control_cond +class StableDiffusionXLControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [StableDiffusionXLDenoiseLoopBeforeDenoiser, StableDiffusionXLControlNetDenoiseLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + +# mask +class StableDiffusionXLInpaintDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser, StableDiffusionXLDenoiseLoopDenoiser, StableDiffusionXLInpaintDenoiseLoopAfterDenoiser] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + +# control_cond + mask +class StableDiffusionXLInpaintControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser, StableDiffusionXLControlNetDenoiseLoopDenoiser, StableDiffusionXLInpaintDenoiseLoopAfterDenoiser] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + + +# all task without controlnet +class StableDiffusionXLDenoiseStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintDenoiseLoop, StableDiffusionXLDenoiseLoop] + block_names = ["inpaint_denoise", "denoise"] + block_trigger_inputs = ["mask", None] + +# all task with controlnet +class StableDiffusionXLControlNetDenoiseStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintControlNetDenoiseLoop, StableDiffusionXLControlNetDenoiseLoop] + block_names = ["inpaint_controlnet_denoise", "controlnet_denoise"] + block_trigger_inputs = ["mask", None] + +# all task with or without controlnet +class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDenoiseStep] + block_names = ["controlnet_denoise", "denoise"] + block_trigger_inputs = ["controlnet_cond", None] + + + + + + + +# YiYi Notes: alternatively, this is you can just write the denoise loop using a pipeline block, easier but not composible +# class StableDiffusionXLDenoiseStep(PipelineBlock): + +# model_name = 
"stable-diffusion-xl" + +# @property +# def expected_components(self) -> List[ComponentSpec]: +# return [ +# ComponentSpec( +# "guider", +# ClassifierFreeGuidance, +# config=FrozenDict({"guidance_scale": 7.5}), +# default_creation_method="from_config"), +# ComponentSpec("scheduler", EulerDiscreteScheduler), +# ComponentSpec("unet", UNet2DConditionModel), +# ] + +# @property +# def description(self) -> str: +# return ( +# "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" +# ) + +# @property +# def inputs(self) -> List[Tuple[str, Any]]: +# return [ +# InputParam("cross_attention_kwargs"), +# InputParam("generator"), +# InputParam("eta", default=0.0), +# InputParam("num_images_per_prompt", default=1), +# ] + +# @property +# def intermediates_inputs(self) -> List[str]: +# return [ +# InputParam( +# "latents", +# required=True, +# type_hint=torch.Tensor, +# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." +# ), +# InputParam( +# "batch_size", +# required=True, +# type_hint=int, +# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." +# ), +# InputParam( +# "timesteps", +# required=True, +# type_hint=torch.Tensor, +# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam( +# "num_inference_steps", +# required=True, +# type_hint=int, +# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam( +# "pooled_prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_pooled_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " +# ), +# InputParam( +# "add_time_ids", +# required=True, +# type_hint=torch.Tensor, +# description="The time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." +# ), +# InputParam( +# "negative_add_time_ids", +# type_hint=Optional[torch.Tensor], +# description="The negative time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." +# ), +# InputParam( +# "prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " +# ), +# InputParam( +# "timestep_cond", +# type_hint=Optional[torch.Tensor], +# description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." +# ), +# InputParam( +# "mask", +# type_hint=Optional[torch.Tensor], +# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." 
+# ), +# InputParam( +# "masked_image_latents", +# type_hint=Optional[torch.Tensor], +# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "noise", +# type_hint=Optional[torch.Tensor], +# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." +# ), +# InputParam( +# "image_latents", +# type_hint=Optional[torch.Tensor], +# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# InputParam( +# "negative_ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# ] + +# @property +# def intermediates_outputs(self) -> List[OutputParam]: +# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + +# @staticmethod +# def check_inputs(components, block_state): + +# num_channels_unet = components.unet.config.in_channels +# if num_channels_unet == 9: +# # default case for runwayml/stable-diffusion-inpainting +# if block_state.mask is None or block_state.masked_image_latents is None: +# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") +# num_channels_latents = block_state.latents.shape[1] +# num_channels_mask = block_state.mask.shape[1] +# num_channels_masked_image = block_state.masked_image_latents.shape[1] +# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: +# raise ValueError( +# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" +# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" +# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" +# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" +# " `components.unet` or your `mask_image` or `image` input." +# ) + +# # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components +# @staticmethod +# def prepare_extra_step_kwargs(components, generator, eta): +# # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature +# # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+# # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 +# # and should be between [0, 1] + +# accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) +# extra_step_kwargs = {} +# if accepts_eta: +# extra_step_kwargs["eta"] = eta + +# # check if the scheduler accepts generator +# accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) +# if accepts_generator: +# extra_step_kwargs["generator"] = generator +# return extra_step_kwargs + +# @torch.no_grad() +# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + +# block_state = self.get_block_state(state) +# self.check_inputs(components, block_state) + +# block_state.num_channels_unet = components.unet.config.in_channels +# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False +# if block_state.disable_guidance: +# components.guider.disable() +# else: +# components.guider.enable() + +# # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline +# block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) +# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) + +# components.guider.set_input_fields( +# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), +# add_time_ids=("add_time_ids", "negative_add_time_ids"), +# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), +# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), +# ) + +# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: +# for i, t in enumerate(block_state.timesteps): +# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) +# guider_data = components.guider.prepare_inputs(block_state) + +# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + +# # Prepare for inpainting +# if block_state.num_channels_unet == 9: +# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) + +# for batch in guider_data: +# components.guider.prepare_models(components.unet) + +# # Prepare additional conditionings +# batch.added_cond_kwargs = { +# "text_embeds": batch.pooled_prompt_embeds, +# "time_ids": batch.add_time_ids, +# } +# if batch.ip_adapter_embeds is not None: +# batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds + +# # Predict the noise residual +# batch.noise_pred = components.unet( +# block_state.scaled_latents, +# t, +# encoder_hidden_states=batch.prompt_embeds, +# timestep_cond=block_state.timestep_cond, +# cross_attention_kwargs=block_state.cross_attention_kwargs, +# added_cond_kwargs=batch.added_cond_kwargs, +# return_dict=False, +# )[0] +# components.guider.cleanup_models(components.unet) + +# # Perform guidance +# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data) + +# # Perform scheduler step using the predicted output +# block_state.latents_dtype = block_state.latents.dtype +# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] + +# if block_state.latents.dtype != block_state.latents_dtype: +# if 
torch.backends.mps.is_available(): +# # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 +# block_state.latents = block_state.latents.to(block_state.latents_dtype) + +# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: +# block_state.init_latents_proper = block_state.image_latents +# if i < len(block_state.timesteps) - 1: +# block_state.noise_timestep = block_state.timesteps[i + 1] +# block_state.init_latents_proper = components.scheduler.add_noise( +# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) +# ) + +# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents + +# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): +# progress_bar.update() + +# self.add_block_state(state, block_state) + +# return components, state + + + +# class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): + +# model_name = "stable-diffusion-xl" + +# @property +# def expected_components(self) -> List[ComponentSpec]: +# return [ +# ComponentSpec( +# "guider", +# ClassifierFreeGuidance, +# config=FrozenDict({"guidance_scale": 7.5}), +# default_creation_method="from_config"), +# ComponentSpec("scheduler", EulerDiscreteScheduler), +# ComponentSpec("unet", UNet2DConditionModel), +# ComponentSpec("controlnet", ControlNetModel), +# ] + +# @property +# def description(self) -> str: +# return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" + +# @property +# def inputs(self) -> List[Tuple[str, Any]]: +# return [ +# InputParam("num_images_per_prompt", default=1), +# InputParam("cross_attention_kwargs"), +# InputParam("generator"), +# InputParam("eta", default=0.0), +# InputParam("controlnet_conditioning_scale", type_hint=float, default=1.0), # can expect either input or intermediate input, (intermediate input if both are passed) +# ] + +# @property +# def intermediates_inputs(self) -> List[str]: +# return [ +# InputParam( +# "controlnet_cond", +# required=True, +# type_hint=torch.Tensor, +# description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "control_guidance_start", +# required=True, +# type_hint=float, +# description="The control guidance start value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "control_guidance_end", +# required=True, +# type_hint=float, +# description="The control guidance end value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "conditioning_scale", +# type_hint=float, +# description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "guess_mode", +# required=True, +# type_hint=bool, +# description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "controlnet_keep", +# required=True, +# type_hint=List[float], +# description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." 
+# ), +# InputParam( +# "latents", +# required=True, +# type_hint=torch.Tensor, +# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." +# ), +# InputParam( +# "batch_size", +# required=True, +# type_hint=int, +# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." +# ), +# InputParam( +# "timesteps", +# required=True, +# type_hint=torch.Tensor, +# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam( +# "prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "add_time_ids", +# required=True, +# type_hint=torch.Tensor, +# description="The time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." +# ), +# InputParam( +# "negative_add_time_ids", +# type_hint=Optional[torch.Tensor], +# description="The negative time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." +# ), +# InputParam( +# "pooled_prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_pooled_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "timestep_cond", +# type_hint=Optional[torch.Tensor], +# description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" +# ), +# InputParam( +# "mask", +# type_hint=Optional[torch.Tensor], +# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "masked_image_latents", +# type_hint=Optional[torch.Tensor], +# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "noise", +# type_hint=Optional[torch.Tensor], +# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." +# ), +# InputParam( +# "image_latents", +# type_hint=Optional[torch.Tensor], +# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "crops_coords", +# type_hint=Optional[Tuple[int]], +# description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." +# ), +# InputParam( +# "ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." 
+# ), +# InputParam( +# "negative_ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# InputParam( +# "num_inference_steps", +# required=True, +# type_hint=int, +# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam(kwargs_type="controlnet_kwargs", description="additional kwargs for controlnet") +# ] + +# @property +# def intermediates_outputs(self) -> List[OutputParam]: +# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + +# @staticmethod +# def check_inputs(components, block_state): + +# num_channels_unet = components.unet.config.in_channels +# if num_channels_unet == 9: +# # default case for runwayml/stable-diffusion-inpainting +# if block_state.mask is None or block_state.masked_image_latents is None: +# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") +# num_channels_latents = block_state.latents.shape[1] +# num_channels_mask = block_state.mask.shape[1] +# num_channels_masked_image = block_state.masked_image_latents.shape[1] +# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: +# raise ValueError( +# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" +# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" +# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" +# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" +# " `components.unet` or your `mask_image` or `image` input." +# ) +# @staticmethod +# def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + +# accepted_kwargs = set(inspect.signature(func).parameters.keys()) +# extra_kwargs = {} +# for key, value in kwargs.items(): +# if key in accepted_kwargs and key not in exclude_kwargs: +# extra_kwargs[key] = value + +# return extra_kwargs + + +# @torch.no_grad() +# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + +# block_state = self.get_block_state(state) +# self.check_inputs(components, block_state) +# block_state.device = components._execution_device +# print(f" block_state: {block_state}") + +# controlnet = unwrap_module(components.controlnet) + +# # Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline +# block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) +# block_state.extra_controlnet_kwargs = self.prepare_extra_kwargs(controlnet.forward, exclude_kwargs=["controlnet_cond", "conditioning_scale", "guess_mode"], **block_state.controlnet_kwargs) + +# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) + +# # (1) setup guider +# # disable for LCMs +# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False +# if block_state.disable_guidance: +# components.guider.disable() +# else: +# components.guider.enable() +# components.guider.set_input_fields( +# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), +# add_time_ids=("add_time_ids", "negative_add_time_ids"), +# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), +# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), +# ) + +# # (5) Denoise loop +# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: +# for i, t in enumerate(block_state.timesteps): + +# # prepare latent input for unet +# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) +# # adjust latent input for inpainting +# block_state.num_channels_unet = components.unet.config.in_channels +# if block_state.num_channels_unet == 9: +# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) + + +# # cond_scale (controlnet input) +# if isinstance(block_state.controlnet_keep[i], list): +# block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] +# else: +# block_state.controlnet_cond_scale = block_state.conditioning_scale +# if isinstance(block_state.controlnet_cond_scale, list): +# block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0] +# block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i] + +# # default controlnet output/unet input for guess mode + conditional path +# block_state.down_block_res_samples_zeros = None +# block_state.mid_block_res_sample_zeros = None + +# # guided denoiser step +# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) +# guider_state = components.guider.prepare_inputs(block_state) + +# for guider_state_batch in guider_state: +# components.guider.prepare_models(components.unet) + +# # Prepare additional conditionings +# guider_state_batch.added_cond_kwargs = { +# "text_embeds": guider_state_batch.pooled_prompt_embeds, +# "time_ids": guider_state_batch.add_time_ids, +# } +# if guider_state_batch.ip_adapter_embeds is not None: +# guider_state_batch.added_cond_kwargs["image_embeds"] = guider_state_batch.ip_adapter_embeds + +# # Prepare controlnet additional conditionings +# guider_state_batch.controlnet_added_cond_kwargs = { +# "text_embeds": guider_state_batch.pooled_prompt_embeds, +# "time_ids": guider_state_batch.add_time_ids, +# } + +# if block_state.guess_mode and not components.guider.is_conditional: +# # guider always run uncond batch first, so these tensors should be set already +# guider_state_batch.down_block_res_samples = block_state.down_block_res_samples_zeros +# guider_state_batch.mid_block_res_sample = 
block_state.mid_block_res_sample_zeros +# else: +# guider_state_batch.down_block_res_samples, guider_state_batch.mid_block_res_sample = components.controlnet( +# block_state.scaled_latents, +# t, +# encoder_hidden_states=guider_state_batch.prompt_embeds, +# controlnet_cond=block_state.controlnet_cond, +# conditioning_scale=block_state.conditioning_scale, +# guess_mode=block_state.guess_mode, +# added_cond_kwargs=guider_state_batch.controlnet_added_cond_kwargs, +# return_dict=False, +# **block_state.extra_controlnet_kwargs, +# ) + +# if block_state.down_block_res_samples_zeros is None: +# block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in guider_state_batch.down_block_res_samples] +# if block_state.mid_block_res_sample_zeros is None: +# block_state.mid_block_res_sample_zeros = torch.zeros_like(guider_state_batch.mid_block_res_sample) + + + +# guider_state_batch.noise_pred = components.unet( +# block_state.scaled_latents, +# t, +# encoder_hidden_states=guider_state_batch.prompt_embeds, +# timestep_cond=block_state.timestep_cond, +# cross_attention_kwargs=block_state.cross_attention_kwargs, +# added_cond_kwargs=guider_state_batch.added_cond_kwargs, +# down_block_additional_residuals=guider_state_batch.down_block_res_samples, +# mid_block_additional_residual=guider_state_batch.mid_block_res_sample, +# return_dict=False, +# )[0] +# components.guider.cleanup_models(components.unet) + +# # Perform guidance +# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_state) + +# # Perform scheduler step using the predicted output +# block_state.latents_dtype = block_state.latents.dtype +# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] + +# if block_state.latents.dtype != block_state.latents_dtype: +# if torch.backends.mps.is_available(): +# # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 +# block_state.latents = block_state.latents.to(block_state.latents_dtype) + +# # adjust latent for inpainting +# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: +# block_state.init_latents_proper = block_state.image_latents +# if i < len(block_state.timesteps) - 1: +# block_state.noise_timestep = block_state.timesteps[i + 1] +# block_state.init_latents_proper = components.scheduler.add_noise( +# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) +# ) + +# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents + +# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): +# progress_bar.update() + +# self.add_block_state(state, block_state) + +# return components, state \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py new file mode 100644 index 000000000000..ca4efe2c4a7f --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py @@ -0,0 +1,858 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, List, Optional, Tuple, Union, Dict + +import PIL +import torch +from collections import OrderedDict + +from ...image_processor import VaeImageProcessor, PipelineImageInput +from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin +from ...models import ControlNetModel, ImageProjection, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel +from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor +from ...models.lora import adjust_lora_scale_text_encoder +from ...utils import ( + USE_PEFT_BACKEND, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor, unwrap_module +from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel +from ...configuration_utils import FrozenDict + +from transformers import ( + CLIPTextModel, + CLIPImageProcessor, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...schedulers import EulerDiscreteScheduler +from ...guiders import ClassifierFreeGuidance + +from .modular_loader import StableDiffusionXLModularLoader +from ..modular_pipeline import PipelineBlock, PipelineState, AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam, ConfigSpec + +import numpy as np + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class StableDiffusionXLIPAdapterStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + + @property + def description(self) -> str: + return ( + "IP Adapter step that handles all the ip adapter related tasks: Load/unload ip adapter weights into unet, prepare ip adapter image embeddings, etc" + " See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)" + " for more details" + ) + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("image_encoder", CLIPVisionModelWithProjection), + ComponentSpec("feature_extractor", CLIPImageProcessor, config=FrozenDict({"size": 224, "crop_size": 224}), default_creation_method="from_config"), + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + 
InputParam( + "ip_adapter_image", + PipelineImageInput, + required=True, + description="The image(s) to be used as ip adapter" + ) + ] + + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [ + OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"), + OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings") + ] + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self -> components + @staticmethod + def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(components.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = components.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = components.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = components.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = components.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, prepare_unconditional_embeds + ): + image_embeds = [] + if prepare_unconditional_embeds: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." 
+ ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + components, single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if prepare_unconditional_embeds: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if prepare_unconditional_embeds: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if prepare_unconditional_embeds: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 + block_state.device = components._execution_device + + block_state.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds( + components, + ip_adapter_image=block_state.ip_adapter_image, + ip_adapter_image_embeds=None, + device=block_state.device, + num_images_per_prompt=1, + prepare_unconditional_embeds=block_state.prepare_unconditional_embeds, + ) + if block_state.prepare_unconditional_embeds: + block_state.negative_ip_adapter_embeds = [] + for i, image_embeds in enumerate(block_state.ip_adapter_embeds): + negative_image_embeds, image_embeds = image_embeds.chunk(2) + block_state.negative_ip_adapter_embeds.append(negative_image_embeds) + block_state.ip_adapter_embeds[i] = image_embeds + + self.add_block_state(state, block_state) + return components, state + + +class StableDiffusionXLTextEncoderStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return( + "Text Encoder step that generate text_embeddings to guide the image generation" + ) + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("text_encoder", CLIPTextModel), + ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), + ComponentSpec("tokenizer", CLIPTokenizer), + ComponentSpec("tokenizer_2", CLIPTokenizer), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ] + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [ConfigSpec("force_zeros_for_empty_prompt", True)] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("prompt"), + InputParam("prompt_2"), + InputParam("negative_prompt"), + InputParam("negative_prompt_2"), + InputParam("cross_attention_kwargs"), + InputParam("clip_skip"), + ] + + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [ + 
OutputParam("prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields",description="text embeddings used to guide the image generation"), + OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative text embeddings used to guide the image generation"), + OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="pooled text embeddings used to guide the image generation"), + OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative pooled text embeddings used to guide the image generation"), + ] + + @staticmethod + def check_inputs(block_state): + + if block_state.prompt is not None and (not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") + elif block_state.prompt_2 is not None and (not isinstance(block_state.prompt_2, str) and not isinstance(block_state.prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(block_state.prompt_2)}") + + @staticmethod + def encode_prompt( + components, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prepare_unconditional_embeds: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prepare_unconditional_embeds (`bool`): + whether to use prepare unconditional embeddings or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
+ If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or components._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin): + components._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if components.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(components.text_encoder, lora_scale) + else: + scale_lora_layers(components.text_encoder, lora_scale) + + if components.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(components.text_encoder_2, lora_scale) + else: + scale_lora_layers(components.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [components.tokenizer, components.tokenizer_2] if components.tokenizer is not None else [components.tokenizer_2] + text_encoders = ( + [components.text_encoder, components.text_encoder_2] if components.text_encoder is not None else [components.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(components, TextualInversionLoaderMixin): + prompt = components.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. 
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt + if prepare_unconditional_embeds and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif prepare_unconditional_embeds and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(components, TextualInversionLoaderMixin): + negative_prompt = components.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if components.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=components.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if prepare_unconditional_embeds: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if components.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + 
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if prepare_unconditional_embeds: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if components.text_encoder is not None: + if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(components.text_encoder, lora_scale) + + if components.text_encoder_2 is not None: + if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(components.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + # Get inputs and intermediates + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 + block_state.device = components._execution_device + + # Encode input prompt + block_state.text_encoder_lora_scale = ( + block_state.cross_attention_kwargs.get("scale", None) if block_state.cross_attention_kwargs is not None else None + ) + ( + block_state.prompt_embeds, + block_state.negative_prompt_embeds, + block_state.pooled_prompt_embeds, + block_state.negative_pooled_prompt_embeds, + ) = self.encode_prompt( + components, + block_state.prompt, + block_state.prompt_2, + block_state.device, + 1, + block_state.prepare_unconditional_embeds, + block_state.negative_prompt, + block_state.negative_prompt_2, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + lora_scale=block_state.text_encoder_lora_scale, + clip_skip=block_state.clip_skip, + ) + # Add outputs + self.add_block_state(state, block_state) + return components, state + + +class StableDiffusionXLVaeEncoderStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + + @property + def description(self) -> str: + return ( + "Vae Encoder step that encode the input image into a latent representation" + ) + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config"), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("image", required=True), + InputParam("height"), + InputParam("width"), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam("generator"), + InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + InputParam("preprocess_kwargs", type_hint=Optional[dict], description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]")] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representing the reference image for 
image-to-image/inpainting generation")] + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components + # YiYi TODO: update the _encode_vae_image so that we can use #Coped from + def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): + + latents_mean = latents_std = None + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) + + dtype = image.dtype + if components.vae.config.force_upcast: + image = image.float() + components.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(components.vae.encode(image), generator=generator) + + if components.vae.config.force_upcast: + components.vae.to(dtype) + + image_latents = image_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) + latents_std = latents_std.to(device=image_latents.device, dtype=dtype) + image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std + else: + image_latents = components.vae.config.scaling_factor * image_latents + + return image_latents + + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.preprocess_kwargs = block_state.preprocess_kwargs or {} + block_state.device = components._execution_device + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + + block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs) + block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) + + block_state.batch_size = block_state.image.shape[0] + + # if generator is a list, make sure the length of it matches the length of images (both should be batch_size) + if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch" + f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + + block_state.image_latents = self._encode_vae_image(components, image=block_state.image, generator=block_state.generator) + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config"), + ComponentSpec( + "mask_processor", + VaeImageProcessor, + config=FrozenDict({"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True}), + default_creation_method="from_config"), + ] + + + @property + def description(self) -> str: + return ( + "Vae encoder step that prepares the image and mask for the inpainting process" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("height"), + InputParam("width"), + InputParam("image", required=True), + InputParam("mask_image", required=True), + InputParam("padding_mask_crop"), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + InputParam("generator"), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"), + OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"), + OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)"), + OutputParam("crops_coords", type_hint=Optional[Tuple[int, int]], description="The crop coordinates to use for the preprocess/postprocess of the image and mask")] + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components + # YiYi TODO: update the _encode_vae_image so that we can use #Coped from + def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): + + latents_mean = latents_std = None + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) + + dtype = image.dtype + if components.vae.config.force_upcast: + image = image.float() + components.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(components.vae.encode(image), generator=generator) + + if components.vae.config.force_upcast: + components.vae.to(dtype) + + image_latents = image_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) + latents_std = latents_std.to(device=image_latents.device, dtype=dtype) + image_latents = 
(image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + else: + image_latents = components.vae.config.scaling_factor * image_latents + + return image_latents + + # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents + # do not accept do_classifier_free_guidance + def prepare_mask_latents( + self, components, mask, masked_image, batch_size, height, width, dtype, device, generator + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + + if masked_image is not None and masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = None + + if masked_image is not None: + if masked_image_latents is None: + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator) + + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." 
+ ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + + block_state = self.get_block_state(state) + + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + block_state.device = components._execution_device + + if block_state.padding_mask_crop is not None: + block_state.crops_coords = components.mask_processor.get_crop_region(block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop) + block_state.resize_mode = "fill" + else: + block_state.crops_coords = None + block_state.resize_mode = "default" + + block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, crops_coords=block_state.crops_coords, resize_mode=block_state.resize_mode) + block_state.image = block_state.image.to(dtype=torch.float32) + + block_state.mask = components.mask_processor.preprocess(block_state.mask_image, height=block_state.height, width=block_state.width, resize_mode=block_state.resize_mode, crops_coords=block_state.crops_coords) + block_state.masked_image = block_state.image * (block_state.mask < 0.5) + + block_state.batch_size = block_state.image.shape[0] + block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) + block_state.image_latents = self._encode_vae_image(components, image=block_state.image, generator=block_state.generator) + + # 7. Prepare mask latent variables + block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents( + components, + block_state.mask, + block_state.masked_image, + block_state.batch_size, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + ) + + self.add_block_state(state, block_state) + + + return components, state + + + +# auto blocks (YiYi TODO: maybe move all the auto blocks to a separate file) +# Encode +class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep] + block_names = ["inpaint", "img2img"] + block_trigger_inputs = ["mask_image", "image"] + + @property + def description(self): + return "Vae encoder step that encode the image inputs into their latent representations.\n" + \ + "This is an auto pipeline block that works for both inpainting and img2img tasks.\n" + \ + " - `StableDiffusionXLInpaintVaeEncoderStep` (inpaint) is used when both `mask_image` and `image` are provided.\n" + \ + " - `StableDiffusionXLVaeEncoderStep` (img2img) is used when only `image` is provided." + + +class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks, ModularIPAdapterMixin): + block_classes = [StableDiffusionXLIPAdapterStep] + block_names = ["ip_adapter"] + block_trigger_inputs = ["ip_adapter_image"] + + @property + def description(self): + return "Run IP Adapter step if `ip_adapter_image` is provided." 
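The two `Auto*` blocks defined above are thin dispatchers: as their descriptions state, an `AutoPipelineBlocks` subclass selects the first sub-block whose trigger input is actually provided, so `mask_image` routes to the inpaint VAE encoder, a bare `image` to the img2img encoder, and the IP-Adapter step only runs when `ip_adapter_image` is passed. Below is a minimal standalone sketch of that selection rule; the names are toy stand-ins, not the diffusers implementation (which handles more cases than this toy does).

from collections import OrderedDict


class ToyAutoBlock:
    """Minimal stand-in for the trigger-based dispatch described above."""

    def __init__(self, blocks, trigger_inputs):
        # blocks: OrderedDict mapping block name -> callable
        # trigger_inputs: one trigger name per block, checked in declaration order
        self.blocks = blocks
        self.trigger_inputs = trigger_inputs

    def __call__(self, **inputs):
        for (name, block), trigger in zip(self.blocks.items(), self.trigger_inputs):
            if inputs.get(trigger) is not None:
                return name, block(**inputs)
        raise ValueError(f"None of the trigger inputs {self.trigger_inputs} were provided.")


def inpaint_encode(image=None, mask_image=None, **kwargs):
    return f"encode {image} with mask {mask_image}"


def img2img_encode(image=None, **kwargs):
    return f"encode {image}"


auto_vae_encode = ToyAutoBlock(
    OrderedDict([("inpaint", inpaint_encode), ("img2img", img2img_encode)]),
    trigger_inputs=["mask_image", "image"],
)

print(auto_vae_encode(image="dog.png", mask_image="mask.png"))  # -> inpaint branch
print(auto_vae_encode(image="dog.png"))                         # -> img2img branch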
+ diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py new file mode 100644 index 000000000000..4af942af64e6 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py @@ -0,0 +1,174 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List, Optional, Tuple, Union, Dict +import PIL +import torch +import numpy as np + +from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin +from ...image_processor import PipelineImageInput +from ...pipelines.pipeline_utils import StableDiffusionMixin +from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from ...utils import logging + +from ..modular_pipeline import ModularLoader +from ..modular_pipeline_utils import InputParam, OutputParam + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + + +# YiYi TODO: move to a different file? stable_diffusion_xl_module should have its own folder? +# YiYi Notes: model specific components: +## (1) it should inherit from ModularLoader +## (2) acts like a container that holds components and configs +## (3) define default config (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents +## (4) inherit from model-specic loader class (e.g. StableDiffusionXLLoraLoaderMixin) +## (5) how to use together with Components_manager? 
+class StableDiffusionXLModularLoader( + ModularLoader, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + ModularIPAdapterMixin, +): + @property + def default_sample_size(self): + default_sample_size = 128 + if hasattr(self, "unet") and self.unet is not None: + default_sample_size = self.unet.config.sample_size + return default_sample_size + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor + + @property + def num_channels_unet(self): + num_channels_unet = 4 + if hasattr(self, "unet") and self.unet is not None: + num_channels_unet = self.unet.config.in_channels + return num_channels_unet + + @property + def num_channels_latents(self): + num_channels_latents = 4 + if hasattr(self, "vae") and self.vae is not None: + num_channels_latents = self.vae.config.latent_channels + return num_channels_latents + + + +# YiYi Notes: not used yet, maintain a list of schema that can be used across all pipeline blocks +SDXL_INPUTS_SCHEMA = { + "prompt": InputParam("prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"), + "prompt_2": InputParam("prompt_2", type_hint=Union[str, List[str]], description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2"), + "negative_prompt": InputParam("negative_prompt", type_hint=Union[str, List[str]], description="The prompt or prompts not to guide the image generation"), + "negative_prompt_2": InputParam("negative_prompt_2", type_hint=Union[str, List[str]], description="The negative prompt or prompts for text_encoder_2"), + "cross_attention_kwargs": InputParam("cross_attention_kwargs", type_hint=Optional[dict], description="Kwargs dictionary passed to the AttentionProcessor"), + "clip_skip": InputParam("clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"), + "image": InputParam("image", type_hint=PipelineImageInput, required=True, description="The image(s) to modify for img2img or inpainting"), + "mask_image": InputParam("mask_image", type_hint=PipelineImageInput, required=True, description="Mask image for inpainting, white pixels will be repainted"), + "generator": InputParam("generator", type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], description="Generator(s) for deterministic generation"), + "height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"), + "width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"), + "num_images_per_prompt": InputParam("num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"), + "num_inference_steps": InputParam("num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"), + "timesteps": InputParam("timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"), + "sigmas": InputParam("sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"), + "denoising_end": InputParam("denoising_end", type_hint=Optional[float], description="Fraction of denoising process to complete before termination"), + # YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999 + "strength": InputParam("strength", type_hint=float, default=0.3, description="How 
much to transform the reference image"), + "denoising_start": InputParam("denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"), + "latents": InputParam("latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"), + "padding_mask_crop": InputParam("padding_mask_crop", type_hint=Optional[Tuple[int, int]], description="Size of margin in crop for image and mask"), + "original_size": InputParam("original_size", type_hint=Optional[Tuple[int, int]], description="Original size of the image for SDXL's micro-conditioning"), + "target_size": InputParam("target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"), + "negative_original_size": InputParam("negative_original_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on image resolution"), + "negative_target_size": InputParam("negative_target_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on target resolution"), + "crops_coords_top_left": InputParam("crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Top-left coordinates for SDXL's micro-conditioning"), + "negative_crops_coords_top_left": InputParam("negative_crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Negative conditioning crop coordinates"), + "aesthetic_score": InputParam("aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"), + "negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"), + "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"), + "output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"), + "ip_adapter_image": InputParam("ip_adapter_image", type_hint=PipelineImageInput, required=True, description="Image(s) to be used as IP adapter"), + "control_image": InputParam("control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"), + "control_guidance_start": InputParam("control_guidance_start", type_hint=Union[float, List[float]], default=0.0, description="When ControlNet starts applying"), + "control_guidance_end": InputParam("control_guidance_end", type_hint=Union[float, List[float]], default=1.0, description="When ControlNet stops applying"), + "controlnet_conditioning_scale": InputParam("controlnet_conditioning_scale", type_hint=Union[float, List[float]], default=1.0, description="Scale factor for ControlNet outputs"), + "guess_mode": InputParam("guess_mode", type_hint=bool, default=False, description="Enables ControlNet encoder to recognize input without prompts"), + "control_mode": InputParam("control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet") +} + + +SDXL_INTERMEDIATE_INPUTS_SCHEMA = { + "prompt_embeds": InputParam("prompt_embeds", type_hint=torch.Tensor, required=True, description="Text embeddings used to guide image generation"), + "negative_prompt_embeds": InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"), + "pooled_prompt_embeds": InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"), + "negative_pooled_prompt_embeds": 
InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"), + "batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"), + "dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + "preprocess_kwargs": InputParam("preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"), + "latents": InputParam("latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"), + "timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"), + "num_inference_steps": InputParam("num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"), + "latent_timestep": InputParam("latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"), + "image_latents": InputParam("image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"), + "mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"), + "masked_image_latents": InputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"), + "add_time_ids": InputParam("add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"), + "negative_add_time_ids": InputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"), + "timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), + "noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), + "crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"), + "ip_adapter_embeds": InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"), + "negative_ip_adapter_embeds": InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"), + "images": InputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], required=True, description="Generated images") +} + + +SDXL_INTERMEDIATE_OUTPUTS_SCHEMA = { + "prompt_embeds": OutputParam("prompt_embeds", type_hint=torch.Tensor, description="Text embeddings used to guide image generation"), + "negative_prompt_embeds": OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"), + "pooled_prompt_embeds": OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="Pooled text embeddings"), + "negative_pooled_prompt_embeds": OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"), + "batch_size": OutputParam("batch_size", type_hint=int, description="Number of prompts"), + "dtype": OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + "image_latents": OutputParam("image_latents", type_hint=torch.Tensor, description="Latents representing reference image"), + "mask": OutputParam("mask", type_hint=torch.Tensor, description="Mask for inpainting"), + "masked_image_latents": OutputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"), + "crops_coords": OutputParam("crops_coords", 
type_hint=Optional[Tuple[int]], description="Crop coordinates"), + "timesteps": OutputParam("timesteps", type_hint=torch.Tensor, description="Timesteps for inference"), + "num_inference_steps": OutputParam("num_inference_steps", type_hint=int, description="Number of denoising steps"), + "latent_timestep": OutputParam("latent_timestep", type_hint=torch.Tensor, description="Initial noise level timestep"), + "add_time_ids": OutputParam("add_time_ids", type_hint=torch.Tensor, description="Time ids for conditioning"), + "negative_add_time_ids": OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"), + "timestep_cond": OutputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), + "latents": OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents"), + "noise": OutputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), + "ip_adapter_embeds": OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"), + "negative_ip_adapter_embeds": OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"), + "images": OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="Generated images") +} + + +SDXL_OUTPUTS_SCHEMA = { + "images": OutputParam("images", type_hint=Union[Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput], description="The final generated images") +} + diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py new file mode 100644 index 000000000000..6d909ab5a4a0 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py @@ -0,0 +1,128 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
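An aside on the schema dictionaries defined in modular_loader.py above, which the notes mark as "not used yet": the intent is a single source of truth for `InputParam`/`OutputParam` definitions shared across blocks. A hypothetical sketch of how a block might consume them once wired in; the `select_inputs` helper is made up, and the import assumes this branch is installed.

from diffusers.modular_pipelines.stable_diffusion_xl.modular_loader import SDXL_INPUTS_SCHEMA


def select_inputs(*names):
    # Look up shared InputParam definitions by name so blocks stay consistent.
    return [SDXL_INPUTS_SCHEMA[name] for name in names]


# e.g. a block could then declare (hypothetical usage):
#     @property
#     def inputs(self):
#         return select_inputs("prompt", "negative_prompt", "height", "width")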
+ +from collections import OrderedDict + +# Import all the necessary block classes +from .denoise import ( + StableDiffusionXLAutoDenoiseStep, + StableDiffusionXLDenoiseStep, + StableDiffusionXLControlNetDenoiseStep +) +from .before_denoise import ( + StableDiffusionXLAutoBeforeDenoiseStep, + StableDiffusionXLInputStep, + StableDiffusionXLSetTimestepsStep, + StableDiffusionXLPrepareLatentsStep, + StableDiffusionXLPrepareAdditionalConditioningStep, + StableDiffusionXLImg2ImgSetTimestepsStep, + StableDiffusionXLImg2ImgPrepareLatentsStep, + StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, + StableDiffusionXLInpaintPrepareLatentsStep, + StableDiffusionXLControlNetInputStep, + StableDiffusionXLControlNetUnionInputStep +) +from .encoders import ( + StableDiffusionXLTextEncoderStep, + StableDiffusionXLAutoIPAdapterStep, + StableDiffusionXLAutoVaeEncoderStep, + StableDiffusionXLVaeEncoderStep, + StableDiffusionXLInpaintVaeEncoderStep, + StableDiffusionXLIPAdapterStep +) +from .decoders import ( + StableDiffusionXLDecodeStep, + StableDiffusionXLInpaintDecodeStep, + StableDiffusionXLAutoDecodeStep +) + + +# YiYi notes: comment out for now, work on this later +# block mapping +TEXT2IMAGE_BLOCKS = OrderedDict([ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), + ("input", StableDiffusionXLInputStep), + ("set_timesteps", StableDiffusionXLSetTimestepsStep), + ("prepare_latents", StableDiffusionXLPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLDenoiseStep), + ("decode", StableDiffusionXLDecodeStep) +]) + +IMAGE2IMAGE_BLOCKS = OrderedDict([ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), + ("image_encoder", StableDiffusionXLVaeEncoderStep), + ("input", StableDiffusionXLInputStep), + ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), + ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLDenoiseStep), + ("decode", StableDiffusionXLDecodeStep) +]) + +INPAINT_BLOCKS = OrderedDict([ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), + ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep), + ("input", StableDiffusionXLInputStep), + ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), + ("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLDenoiseStep), + ("decode", StableDiffusionXLInpaintDecodeStep) +]) + +CONTROLNET_BLOCKS = OrderedDict([ + ("controlnet_input", StableDiffusionXLControlNetInputStep), + ("denoise", StableDiffusionXLControlNetDenoiseStep), +]) + +CONTROLNET_UNION_BLOCKS = OrderedDict([ + ("controlnet_input", StableDiffusionXLControlNetUnionInputStep), + ("denoise", StableDiffusionXLControlNetDenoiseStep), +]) + +IP_ADAPTER_BLOCKS = OrderedDict([ + ("ip_adapter", StableDiffusionXLIPAdapterStep), +]) + +AUTO_BLOCKS = OrderedDict([ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), + ("image_encoder", StableDiffusionXLAutoVaeEncoderStep), + ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), + ("denoise", StableDiffusionXLAutoDenoiseStep), + ("decode", StableDiffusionXLAutoDecodeStep) +]) + 
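Each mapping above lists step classes in their execution order, and a sequential pipeline assembled from one of them runs the blocks back to back over a shared state. A rough standalone illustration of that contract, with toy callables standing in for the real `PipelineBlock` classes (this is not the `SequentialPipelineBlocks` implementation):

from collections import OrderedDict


def run_blocks(blocks, state):
    """Run each step in order; every step reads from and writes to the shared state."""
    for name, block in blocks.items():
        state = block(state)
    return state


def text_encoder(state):
    state["prompt_embeds"] = f"embeds({state['prompt']})"
    return state


def denoise(state):
    state["latents"] = f"denoised with {state['prompt_embeds']}"
    return state


def decode(state):
    state["images"] = f"decoded({state['latents']})"
    return state


TOY_TEXT2IMAGE_BLOCKS = OrderedDict([
    ("text_encoder", text_encoder),
    ("denoise", denoise),
    ("decode", decode),
])

print(run_blocks(TOY_TEXT2IMAGE_BLOCKS, {"prompt": "a photo of a cat"}))

The real blocks additionally declare `expected_components`, `inputs`, and `intermediates_outputs`, so a pipeline assembled from these mappings also knows which models each step needs and which values it produces for later steps.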
+AUTO_CORE_BLOCKS = OrderedDict([ + ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), + ("denoise", StableDiffusionXLAutoDenoiseStep), +]) + + +SDXL_SUPPORTED_BLOCKS = { + "text2img": TEXT2IMAGE_BLOCKS, + "img2img": IMAGE2IMAGE_BLOCKS, + "inpaint": INPAINT_BLOCKS, + "controlnet": CONTROLNET_BLOCKS, + "controlnet_union": CONTROLNET_UNION_BLOCKS, + "ip_adapter": IP_ADAPTER_BLOCKS, + "auto": AUTO_BLOCKS +} + + + diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py new file mode 100644 index 000000000000..637c7ac306d7 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py @@ -0,0 +1,43 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List, Optional, Tuple, Union, Dict +from ...utils import logging +from ..modular_pipeline import SequentialPipelineBlocks + +from .denoise import StableDiffusionXLAutoDenoiseStep +from .before_denoise import StableDiffusionXLAutoBeforeDenoiseStep +from .decoders import StableDiffusionXLAutoDecodeStep +from .encoders import StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableDiffusionXLAutoPipeline(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, StableDiffusionXLAutoBeforeDenoiseStep, StableDiffusionXLAutoDenoiseStep, StableDiffusionXLAutoDecodeStep] + block_names = ["text_encoder", "ip_adapter", "image_encoder", "before_denoise", "denoise", "decoder"] + + @property + def description(self): + return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n" + \ + "- for image-to-image generation, you need to provide either `image` or `image_latents`\n" + \ + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n" + \ + "- to run the controlnet workflow, you need to provide `control_image`\n" + \ + "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n" + \ + "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n" + \ + "- for text-to-image generation, all you need to provide is `prompt`" + + + + diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 0567eb687c62..011f23ed371c 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -47,7 +47,6 @@ "AutoPipelineForInpainting", "AutoPipelineForText2Image", ] - _import_structure["modular_pipeline"] = ["ModularLoader"] _import_structure["consistency_models"] = ["ConsistencyModelPipeline"] _import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"] 
_import_structure["ddim"] = ["DDIMPipeline"] @@ -330,8 +329,6 @@ "StableDiffusionXLInpaintPipeline", "StableDiffusionXLInstructPix2PixPipeline", "StableDiffusionXLPipeline", - "StableDiffusionXLModularLoader", - "StableDiffusionXLAutoPipeline", ] ) _import_structure["stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"] @@ -481,7 +478,6 @@ from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline from .dit import DiTPipeline from .latent_diffusion import LDMSuperResolutionPipeline - from .modular_pipeline import ModularLoader from .pipeline_utils import ( AudioPipelineOutput, DiffusionPipeline, @@ -706,9 +702,7 @@ StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLInstructPix2PixPipeline, - StableDiffusionXLModularLoader, StableDiffusionXLPipeline, - StableDiffusionXLAutoPipeline, ) from .stable_video_diffusion import StableVideoDiffusionPipeline from .t2i_adapter import ( diff --git a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py index 006836fe30d4..8088fbcfceba 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py @@ -29,18 +29,6 @@ _import_structure["pipeline_stable_diffusion_xl_img2img"] = ["StableDiffusionXLImg2ImgPipeline"] _import_structure["pipeline_stable_diffusion_xl_inpaint"] = ["StableDiffusionXLInpaintPipeline"] _import_structure["pipeline_stable_diffusion_xl_instruct_pix2pix"] = ["StableDiffusionXLInstructPix2PixPipeline"] - _import_structure["pipeline_stable_diffusion_xl_modular"] = [ - "StableDiffusionXLControlNetDenoiseStep", - "StableDiffusionXLDecodeLatentsStep", - "StableDiffusionXLDenoiseStep", - "StableDiffusionXLInputStep", - "StableDiffusionXLModularLoader", - "StableDiffusionXLPrepareAdditionalConditioningStep", - "StableDiffusionXLPrepareLatentsStep", - "StableDiffusionXLSetTimestepsStep", - "StableDiffusionXLTextEncoderStep", - "StableDiffusionXLAutoPipeline", - ] if is_transformers_available() and is_flax_available(): from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState @@ -60,18 +48,6 @@ from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline - from .pipeline_stable_diffusion_xl_modular import ( - StableDiffusionXLControlNetDenoiseStep, - StableDiffusionXLDecodeLatentsStep, - StableDiffusionXLDenoiseStep, - StableDiffusionXLInputStep, - StableDiffusionXLModularLoader, - StableDiffusionXLPrepareAdditionalConditioningStep, - StableDiffusionXLPrepareLatentsStep, - StableDiffusionXLSetTimestepsStep, - StableDiffusionXLTextEncoderStep, - StableDiffusionXLAutoPipeline, - ) try: if not (is_transformers_available() and is_flax_available()): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py deleted file mode 100644 index 5ae9e63851db..000000000000 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ /dev/null @@ -1,3713 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Any, List, Optional, Tuple, Union, Dict - -import PIL -import torch -from collections import OrderedDict - -from ...image_processor import VaeImageProcessor, PipelineImageInput -from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin -from ...models import ControlNetModel, ImageProjection, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel -from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor -from ...models.lora import adjust_lora_scale_text_encoder -from ...utils import ( - USE_PEFT_BACKEND, - logging, - scale_lora_layers, - unscale_lora_layers, -) -from ...utils.torch_utils import randn_tensor, unwrap_module -from ..controlnet.multicontrolnet import MultiControlNetModel -from ..modular_pipeline import ( - AutoPipelineBlocks, - ModularLoader, - PipelineBlock, - PipelineState, - InputParam, - OutputParam, - SequentialPipelineBlocks, - ComponentSpec, - ConfigSpec, -) -from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin -from .pipeline_output import ( - StableDiffusionXLPipelineOutput, -) - -from transformers import ( - CLIPTextModel, - CLIPImageProcessor, - CLIPTextModelWithProjection, - CLIPTokenizer, - CLIPVisionModelWithProjection, -) - -from ...schedulers import EulerDiscreteScheduler -from ...guiders import ClassifierFreeGuidance -from ...configuration_utils import FrozenDict - -import numpy as np - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps -def retrieve_timesteps( - scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, - **kwargs, -): - r""" - Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles - custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. - - Args: - scheduler (`SchedulerMixin`): - The scheduler to get timesteps from. - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` - must be `None`. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - timesteps (`List[int]`, *optional*): - Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, - `num_inference_steps` and `sigmas` must be `None`. - sigmas (`List[float]`, *optional*): - Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, - `num_inference_steps` and `timesteps` must be `None`. - - Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. 
- """ - if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_timesteps: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accept_sigmas: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" sigmas schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents( - encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" -): - if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": - return encoder_output.latent_dist.sample(generator) - elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": - return encoder_output.latent_dist.mode() - elif hasattr(encoder_output, "latents"): - return encoder_output.latents - else: - raise AttributeError("Could not access latents of provided encoder_output") - - - -# YiYi Notes: I think we do not need this, we can add loader methods on the components class -class StableDiffusionXLLoraStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Lora step that handles all the lora related tasks: load/unload lora weights into unet and text encoders, manage lora adapters etc" - " See [StableDiffusionXLLoraLoaderMixin](https://huggingface.co/docs/diffusers/api/loaders/lora#diffusers.loaders.StableDiffusionXLLoraLoaderMixin)" - " for more details" - ) - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("text_encoder", CLIPTextModel), - ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), - ComponentSpec("unet", UNet2DConditionModel), - ] - - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - raise EnvironmentError("StableDiffusionXLLoraStep is desgined to be used to load lora weights, __call__ is not implemented") - - -class StableDiffusionXLIPAdapterStep(PipelineBlock, ModularIPAdapterMixin): - model_name = "stable-diffusion-xl" - - - @property - def description(self) -> str: - return ( - "IP Adapter step that handles all the ip adapter related tasks: Load/unload ip adapter weights into unet, prepare ip adapter image embeddings, etc" - " See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)" - " for more details" - ) - - @property - def expected_components(self) -> List[ComponentSpec]: - return 
[ - ComponentSpec("image_encoder", CLIPVisionModelWithProjection), - ComponentSpec("feature_extractor", CLIPImageProcessor, config=FrozenDict({"size": 224, "crop_size": 224}), default_creation_method="from_config"), - ComponentSpec("unet", UNet2DConditionModel), - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ] - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam( - "ip_adapter_image", - PipelineImageInput, - required=True, - description="The image(s) to be used as ip adapter" - ) - ] - - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [ - OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"), - OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings") - ] - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self -> components - def encode_image(self, components, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(components.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = components.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = components.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = components.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = components.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds - - # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds - def prepare_ip_adapter_image_embeds( - self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, prepare_unconditional_embeds - ): - image_embeds = [] - if prepare_unconditional_embeds: - negative_image_embeds = [] - if ip_adapter_image_embeds is None: - if not isinstance(ip_adapter_image, list): - ip_adapter_image = [ip_adapter_image] - - if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers): - raise ValueError( - f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." 
- ) - - for single_ip_adapter_image, image_proj_layer in zip( - ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers - ): - output_hidden_state = not isinstance(image_proj_layer, ImageProjection) - single_image_embeds, single_negative_image_embeds = self.encode_image( - components, single_ip_adapter_image, device, 1, output_hidden_state - ) - - image_embeds.append(single_image_embeds[None, :]) - if prepare_unconditional_embeds: - negative_image_embeds.append(single_negative_image_embeds[None, :]) - else: - for single_image_embeds in ip_adapter_image_embeds: - if prepare_unconditional_embeds: - single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - negative_image_embeds.append(single_negative_image_embeds) - image_embeds.append(single_image_embeds) - - ip_adapter_image_embeds = [] - for i, single_image_embeds in enumerate(image_embeds): - single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) - if prepare_unconditional_embeds: - single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - - single_image_embeds = single_image_embeds.to(device=device) - ip_adapter_image_embeds.append(single_image_embeds) - - return ip_adapter_image_embeds - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - - data.prepare_unconditional_embeds = pipeline.guider.num_conditions > 1 - data.device = pipeline._execution_device - - data.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds( - pipeline, - ip_adapter_image=data.ip_adapter_image, - ip_adapter_image_embeds=None, - device=data.device, - num_images_per_prompt=1, - prepare_unconditional_embeds=data.prepare_unconditional_embeds, - ) - if data.prepare_unconditional_embeds: - data.negative_ip_adapter_embeds = [] - for i, image_embeds in enumerate(data.ip_adapter_embeds): - negative_image_embeds, image_embeds = image_embeds.chunk(2) - data.negative_ip_adapter_embeds.append(negative_image_embeds) - data.ip_adapter_embeds[i] = image_embeds - - self.add_block_state(state, data) - return pipeline, state - - -class StableDiffusionXLTextEncoderStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return( - "Text Encoder step that generate text_embeddings to guide the image generation" - ) - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("text_encoder", CLIPTextModel), - ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), - ComponentSpec("tokenizer", CLIPTokenizer), - ComponentSpec("tokenizer_2", CLIPTokenizer), - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ] - - @property - def expected_configs(self) -> List[ConfigSpec]: - return [ConfigSpec("force_zeros_for_empty_prompt", True)] - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("prompt"), - InputParam("prompt_2"), - InputParam("negative_prompt"), - InputParam("negative_prompt_2"), - InputParam("cross_attention_kwargs"), - InputParam("clip_skip"), - ] - - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [ - OutputParam("prompt_embeds", type_hint=torch.Tensor, description="text embeddings used to guide the image generation"), - 
OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="negative text embeddings used to guide the image generation"), - OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="pooled text embeddings used to guide the image generation"), - OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="negative pooled text embeddings used to guide the image generation"), - ] - - def check_inputs(self, pipeline, data): - - if data.prompt is not None and (not isinstance(data.prompt, str) and not isinstance(data.prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(data.prompt)}") - elif data.prompt_2 is not None and (not isinstance(data.prompt_2, str) and not isinstance(data.prompt_2, list)): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(data.prompt_2)}") - - def encode_prompt( - self, - components, - prompt: str, - prompt_2: Optional[str] = None, - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - prepare_unconditional_embeds: bool = True, - negative_prompt: Optional[str] = None, - negative_prompt_2: Optional[str] = None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - pooled_prompt_embeds: Optional[torch.Tensor] = None, - negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, - lora_scale: Optional[float] = None, - clip_skip: Optional[int] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in both text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - prepare_unconditional_embeds (`bool`): - whether to use prepare unconditional embeddings or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - negative_prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and - `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - pooled_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. 
If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` - input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - """ - device = device or self._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin): - components._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if components.text_encoder is not None: - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(components.text_encoder, lora_scale) - else: - scale_lora_layers(components.text_encoder, lora_scale) - - if components.text_encoder_2 is not None: - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(components.text_encoder_2, lora_scale) - else: - scale_lora_layers(components.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - - if prompt is not None: - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - # Define tokenizers and text encoders - tokenizers = [components.tokenizer, components.tokenizer_2] if components.tokenizer is not None else [components.tokenizer_2] - text_encoders = ( - [components.text_encoder, components.text_encoder_2] if components.text_encoder is not None else [components.text_encoder_2] - ) - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # textual inversion: process multi-vector tokens if necessary - prompt_embeds_list = [] - prompts = [prompt, prompt_2] - for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): - if isinstance(components, TextualInversionLoaderMixin): - prompt = components.maybe_convert_prompt(prompt, tokenizer) - - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {tokenizer.model_max_length} tokens: {removed_text}" - ) - - prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) - - # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] - if clip_skip is None: - prompt_embeds = prompt_embeds.hidden_states[-2] - else: - # "2" because SDXL always indexes from the penultimate layer. 
- prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] - - prompt_embeds_list.append(prompt_embeds) - - prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) - - # get unconditional embeddings for classifier free guidance - zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt - if prepare_unconditional_embeds and negative_prompt_embeds is None and zero_out_negative_prompt: - negative_prompt_embeds = torch.zeros_like(prompt_embeds) - negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) - elif prepare_unconditional_embeds and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or negative_prompt - - # normalize str to list - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_2 = ( - batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 - ) - - uncond_tokens: List[str] - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = [negative_prompt, negative_prompt_2] - - negative_prompt_embeds_list = [] - for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): - if isinstance(components, TextualInversionLoaderMixin): - negative_prompt = components.maybe_convert_prompt(negative_prompt, tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = tokenizer( - negative_prompt, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - negative_prompt_embeds = text_encoder( - uncond_input.input_ids.to(device), - output_hidden_states=True, - ) - # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] - negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] - - negative_prompt_embeds_list.append(negative_prompt_embeds) - - negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - - if components.text_encoder_2 is not None: - prompt_embeds = prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device) - else: - prompt_embeds = prompt_embeds.to(dtype=components.unet.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - if prepare_unconditional_embeds: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - if components.text_encoder_2 is not None: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device) - else: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.unet.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - 
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) - if prepare_unconditional_embeds: - negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) - - if components.text_encoder is not None: - if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(components.text_encoder, lora_scale) - - if components.text_encoder_2 is not None: - if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(components.text_encoder_2, lora_scale) - - return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds - - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - # Get inputs and intermediates - data = self.get_block_state(state) - self.check_inputs(pipeline, data) - - data.prepare_unconditional_embeds = pipeline.guider.num_conditions > 1 - data.device = pipeline._execution_device - - # Encode input prompt - data.text_encoder_lora_scale = ( - data.cross_attention_kwargs.get("scale", None) if data.cross_attention_kwargs is not None else None - ) - ( - data.prompt_embeds, - data.negative_prompt_embeds, - data.pooled_prompt_embeds, - data.negative_pooled_prompt_embeds, - ) = self.encode_prompt( - pipeline, - data.prompt, - data.prompt_2, - data.device, - 1, - data.prepare_unconditional_embeds, - data.negative_prompt, - data.negative_prompt_2, - prompt_embeds=None, - negative_prompt_embeds=None, - pooled_prompt_embeds=None, - negative_pooled_prompt_embeds=None, - lora_scale=data.text_encoder_lora_scale, - clip_skip=data.clip_skip, - ) - # Add outputs - self.add_block_state(state, data) - return pipeline, state - - -class StableDiffusionXLVaeEncoderStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - - @property - def description(self) -> str: - return ( - "Vae Encoder step that encode the input image into a latent representation" - ) - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("vae", AutoencoderKL), - ComponentSpec( - "image_processor", - VaeImageProcessor, - config=FrozenDict({"vae_scale_factor": 8}), - default_creation_method="from_config"), - ] - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("image", required=True), - InputParam("generator"), - InputParam("height"), - InputParam("width"), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), - InputParam("preprocess_kwargs", type_hint=Optional[dict], description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]")] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation")] - - # Modified from 
diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components - # YiYi TODO: update the _encode_vae_image so that we can use #Coped from - def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): - - latents_mean = latents_std = None - if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: - latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: - latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) - - dtype = image.dtype - if components.vae.config.force_upcast: - image = image.float() - components.vae.to(dtype=torch.float32) - - if isinstance(generator, list): - image_latents = [ - retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(image.shape[0]) - ] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = retrieve_latents(components.vae.encode(image), generator=generator) - - if components.vae.config.force_upcast: - components.vae.to(dtype) - - image_latents = image_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) - latents_std = latents_std.to(device=image_latents.device, dtype=dtype) - image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std - else: - image_latents = components.vae.config.scaling_factor * image_latents - - return image_latents - - - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - data.preprocess_kwargs = data.preprocess_kwargs or {} - data.device = pipeline._execution_device - data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype - - data.image = pipeline.image_processor.preprocess(data.image, height=data.height, width=data.width, **data.preprocess_kwargs) - data.image = data.image.to(device=data.device, dtype=data.dtype) - - data.batch_size = data.image.shape[0] - - # if generator is a list, make sure the length of it matches the length of images (both should be batch_size) - if isinstance(data.generator, list) and len(data.generator) != data.batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(data.generator)}, but requested an effective batch" - f" size of {data.batch_size}. Make sure the batch size matches the length of the generators." 
- ) - - - data.image_latents = self._encode_vae_image(pipeline,image=data.image, generator=data.generator) - - self.add_block_state(state, data) - - return pipeline, state - - -class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("vae", AutoencoderKL), - ComponentSpec( - "image_processor", - VaeImageProcessor, - config=FrozenDict({"vae_scale_factor": 8}), - default_creation_method="from_config"), - ComponentSpec( - "mask_processor", - VaeImageProcessor, - config=FrozenDict({"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True}), - default_creation_method="from_config"), - ] - - - @property - def description(self) -> str: - return ( - "Vae encoder step that prepares the image and mask for the inpainting process" - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("height"), - InputParam("width"), - InputParam("generator"), - InputParam("image", required=True), - InputParam("mask_image", required=True), - InputParam("padding_mask_crop"), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs")] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"), - OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"), - OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)"), - OutputParam("crops_coords", type_hint=Optional[Tuple[int, int]], description="The crop coordinates to use for the preprocess/postprocess of the image and mask")] - - # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components - # YiYi TODO: update the _encode_vae_image so that we can use #Coped from - def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): - - latents_mean = latents_std = None - if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: - latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: - latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) - - dtype = image.dtype - if components.vae.config.force_upcast: - image = image.float() - components.vae.to(dtype=torch.float32) - - if isinstance(generator, list): - image_latents = [ - retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(image.shape[0]) - ] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = retrieve_latents(components.vae.encode(image), generator=generator) - - if components.vae.config.force_upcast: - components.vae.to(dtype) - - image_latents = image_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) - latents_std = latents_std.to(device=image_latents.device, dtype=dtype) - image_latents = (image_latents - latents_mean) * 
self.vae.config.scaling_factor / latents_std - else: - image_latents = components.vae.config.scaling_factor * image_latents - - return image_latents - - # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents - # do not accept do_classifier_free_guidance - def prepare_mask_latents( - self, components, mask, masked_image, batch_size, height, width, dtype, device, generator - ): - # resize the mask to latents shape as we concatenate the mask to the latents - # we do that before converting to dtype to avoid breaking in case we're using cpu_offload - # and half precision - mask = torch.nn.functional.interpolate( - mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor) - ) - mask = mask.to(device=device, dtype=dtype) - - # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method - if mask.shape[0] < batch_size: - if not batch_size % mask.shape[0] == 0: - raise ValueError( - "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" - f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" - " of masks that you pass is divisible by the total requested batch size." - ) - mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) - - if masked_image is not None and masked_image.shape[1] == 4: - masked_image_latents = masked_image - else: - masked_image_latents = None - - if masked_image is not None: - if masked_image_latents is None: - masked_image = masked_image.to(device=device, dtype=dtype) - masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator) - - if masked_image_latents.shape[0] < batch_size: - if not batch_size % masked_image_latents.shape[0] == 0: - raise ValueError( - "The passed images and the required batch size don't match. Images are supposed to be duplicated" - f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." - " Make sure the number of images that you pass is divisible by the total requested batch size." 
- ) - masked_image_latents = masked_image_latents.repeat( - batch_size // masked_image_latents.shape[0], 1, 1, 1 - ) - - # aligning device to prevent device errors when concating it with the latent model input - masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) - - return mask, masked_image_latents - - - - @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - - data = self.get_block_state(state) - - data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype - data.device = pipeline._execution_device - - if data.padding_mask_crop is not None: - data.crops_coords = pipeline.mask_processor.get_crop_region(data.mask_image, data.width, data.height, pad=data.padding_mask_crop) - data.resize_mode = "fill" - else: - data.crops_coords = None - data.resize_mode = "default" - - data.image = pipeline.image_processor.preprocess(data.image, height=data.height, width=data.width, crops_coords=data.crops_coords, resize_mode=data.resize_mode) - data.image = data.image.to(dtype=torch.float32) - - data.mask = pipeline.mask_processor.preprocess(data.mask_image, height=data.height, width=data.width, resize_mode=data.resize_mode, crops_coords=data.crops_coords) - data.masked_image = data.image * (data.mask < 0.5) - - data.batch_size = data.image.shape[0] - data.image = data.image.to(device=data.device, dtype=data.dtype) - data.image_latents = self._encode_vae_image(pipeline, image=data.image, generator=data.generator) - - # 7. Prepare mask latent variables - data.mask, data.masked_image_latents = self.prepare_mask_latents( - pipeline, - data.mask, - data.masked_image, - data.batch_size, - data.height, - data.width, - data.dtype, - data.device, - data.generator, - ) - - self.add_block_state(state, data) - - - return pipeline, state - - -class StableDiffusionXLInputStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Input processing step that:\n" - " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" - " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n" - "All input tensors are expected to have either batch_size=1 or match the batch_size\n" - "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" - "have a final batch_size of batch_size * num_images_per_prompt." - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam("prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated text embeddings. Can be generated from text_encoder step."), - InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Pre-generated negative text embeddings. Can be generated from text_encoder step."), - InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated pooled text embeddings. Can be generated from text_encoder step."), - InputParam("negative_pooled_prompt_embeds", description="Pre-generated negative pooled text embeddings. Can be generated from text_encoder step."), - InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated image embeddings for IP-Adapter. 
Can be generated from ip_adapter step."), - InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step."), - ] - - @property - def intermediates_outputs(self) -> List[str]: - return [ - OutputParam("batch_size", type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), - OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs (determined by `prompt_embeds`)"), - OutputParam("prompt_embeds", type_hint=torch.Tensor, description="text embeddings used to guide the image generation"), - OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="negative text embeddings used to guide the image generation"), - OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="pooled text embeddings used to guide the image generation"), - OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="negative pooled text embeddings used to guide the image generation"), - OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="image embeddings for IP-Adapter"), - OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="negative image embeddings for IP-Adapter"), - ] - - def check_inputs(self, pipeline, data): - - if data.prompt_embeds is not None and data.negative_prompt_embeds is not None: - if data.prompt_embeds.shape != data.negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {data.prompt_embeds.shape} != `negative_prompt_embeds`" - f" {data.negative_prompt_embeds.shape}." - ) - - if data.prompt_embeds is not None and data.pooled_prompt_embeds is None: - raise ValueError( - "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." - ) - - if data.negative_prompt_embeds is not None and data.negative_pooled_prompt_embeds is None: - raise ValueError( - "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." - ) - - if data.ip_adapter_embeds is not None and not isinstance(data.ip_adapter_embeds, list): - raise ValueError("`ip_adapter_embeds` must be a list") - - if data.negative_ip_adapter_embeds is not None and not isinstance(data.negative_ip_adapter_embeds, list): - raise ValueError("`negative_ip_adapter_embeds` must be a list") - - if data.ip_adapter_embeds is not None and data.negative_ip_adapter_embeds is not None: - for i, ip_adapter_embed in enumerate(data.ip_adapter_embeds): - if ip_adapter_embed.shape != data.negative_ip_adapter_embeds[i].shape: - raise ValueError( - "`ip_adapter_embeds` and `negative_ip_adapter_embeds` must have the same shape when passed directly, but" - f" got: `ip_adapter_embeds` {ip_adapter_embed.shape} != `negative_ip_adapter_embeds`" - f" {data.negative_ip_adapter_embeds[i].shape}." 
- ) - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - self.check_inputs(pipeline, data) - - data.batch_size = data.prompt_embeds.shape[0] - data.dtype = data.prompt_embeds.dtype - - _, seq_len, _ = data.prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - data.prompt_embeds = data.prompt_embeds.repeat(1, data.num_images_per_prompt, 1) - data.prompt_embeds = data.prompt_embeds.view(data.batch_size * data.num_images_per_prompt, seq_len, -1) - - if data.negative_prompt_embeds is not None: - _, seq_len, _ = data.negative_prompt_embeds.shape - data.negative_prompt_embeds = data.negative_prompt_embeds.repeat(1, data.num_images_per_prompt, 1) - data.negative_prompt_embeds = data.negative_prompt_embeds.view(data.batch_size * data.num_images_per_prompt, seq_len, -1) - - data.pooled_prompt_embeds = data.pooled_prompt_embeds.repeat(1, data.num_images_per_prompt, 1) - data.pooled_prompt_embeds = data.pooled_prompt_embeds.view(data.batch_size * data.num_images_per_prompt, -1) - - if data.negative_pooled_prompt_embeds is not None: - data.negative_pooled_prompt_embeds = data.negative_pooled_prompt_embeds.repeat(1, data.num_images_per_prompt, 1) - data.negative_pooled_prompt_embeds = data.negative_pooled_prompt_embeds.view(data.batch_size * data.num_images_per_prompt, -1) - - if data.ip_adapter_embeds is not None: - for i, ip_adapter_embed in enumerate(data.ip_adapter_embeds): - data.ip_adapter_embeds[i] = torch.cat([ip_adapter_embed] * data.num_images_per_prompt, dim=0) - - if data.negative_ip_adapter_embeds is not None: - for i, negative_ip_adapter_embed in enumerate(data.negative_ip_adapter_embeds): - data.negative_ip_adapter_embeds[i] = torch.cat([negative_ip_adapter_embed] * data.num_images_per_prompt, dim=0) - - self.add_block_state(state, data) - - return pipeline, state - - -class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("scheduler", EulerDiscreteScheduler), - ] - - @property - def description(self) -> str: - return ( - "Step that sets the timesteps for the scheduler and determines the initial noise level (latent_timestep) for image-to-image/inpainting generation.\n" + \ - "The latent_timestep is calculated from the `strength` parameter - higher strength means starting from a noisier version of the input image." - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("num_inference_steps", default=50), - InputParam("timesteps"), - InputParam("sigmas"), - InputParam("denoising_end"), - InputParam("strength", default=0.3), - InputParam("denoising_start"), - # YiYi TODO: do we need num_images_per_prompt here? 
- InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), - ] - - @property - def intermediates_outputs(self) -> List[str]: - return [ - OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), - OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time"), - OutputParam("latent_timestep", type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image generation") - ] - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps with self -> components - def get_timesteps(self, components, num_inference_steps, strength, device, denoising_start=None): - # get the original timestep using init_timestep - if denoising_start is None: - init_timestep = min(int(num_inference_steps * strength), num_inference_steps) - t_start = max(num_inference_steps - init_timestep, 0) - - timesteps = components.scheduler.timesteps[t_start * components.scheduler.order :] - if hasattr(components.scheduler, "set_begin_index"): - components.scheduler.set_begin_index(t_start * components.scheduler.order) - - return timesteps, num_inference_steps - t_start - - else: - # Strength is irrelevant if we directly request a timestep to start at; - # that is, strength is determined by the denoising_start instead. - discrete_timestep_cutoff = int( - round( - components.scheduler.config.num_train_timesteps - - (denoising_start * components.scheduler.config.num_train_timesteps) - ) - ) - - num_inference_steps = (components.scheduler.timesteps < discrete_timestep_cutoff).sum().item() - if components.scheduler.order == 2 and num_inference_steps % 2 == 0: - # if the scheduler is a 2nd order scheduler we might have to do +1 - # because `num_inference_steps` might be even given that every timestep - # (except the highest one) is duplicated. If `num_inference_steps` is even it would - # mean that we cut the timesteps in the middle of the denoising step - # (between 1st and 2nd derivative) which leads to incorrect results. 
By adding 1 - # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler - num_inference_steps = num_inference_steps + 1 - - # because t_n+1 >= t_n, we slice the timesteps starting from the end - t_start = len(components.scheduler.timesteps) - num_inference_steps - timesteps = components.scheduler.timesteps[t_start:] - if hasattr(components.scheduler, "set_begin_index"): - components.scheduler.set_begin_index(t_start) - return timesteps, num_inference_steps - - - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - - data.device = pipeline._execution_device - - data.timesteps, data.num_inference_steps = retrieve_timesteps( - pipeline.scheduler, data.num_inference_steps, data.device, data.timesteps, data.sigmas - ) - - def denoising_value_valid(dnv): - return isinstance(dnv, float) and 0 < dnv < 1 - - data.timesteps, data.num_inference_steps = self.get_timesteps( - pipeline, - data.num_inference_steps, - data.strength, - data.device, - denoising_start=data.denoising_start if denoising_value_valid(data.denoising_start) else None, - ) - data.latent_timestep = data.timesteps[:1].repeat(data.batch_size * data.num_images_per_prompt) - - if data.denoising_end is not None and isinstance(data.denoising_end, float) and data.denoising_end > 0 and data.denoising_end < 1: - data.discrete_timestep_cutoff = int( - round( - pipeline.scheduler.config.num_train_timesteps - - (data.denoising_end * pipeline.scheduler.config.num_train_timesteps) - ) - ) - data.num_inference_steps = len(list(filter(lambda ts: ts >= data.discrete_timestep_cutoff, data.timesteps))) - data.timesteps = data.timesteps[:data.num_inference_steps] - - self.add_block_state(state, data) - - return pipeline, state - - -class StableDiffusionXLSetTimestepsStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("scheduler", EulerDiscreteScheduler), - ] - - @property - def description(self) -> str: - return ( - "Step that sets the scheduler's timesteps for inference" - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("num_inference_steps", default=50), - InputParam("timesteps"), - InputParam("sigmas"), - InputParam("denoising_end"), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), - OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time")] - - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - - data.device = pipeline._execution_device - - data.timesteps, data.num_inference_steps = retrieve_timesteps( - pipeline.scheduler, data.num_inference_steps, data.device, data.timesteps, data.sigmas - ) - - if data.denoising_end is not None and isinstance(data.denoising_end, float) and data.denoising_end > 0 and data.denoising_end < 1: - data.discrete_timestep_cutoff = int( - round( - pipeline.scheduler.config.num_train_timesteps - - (data.denoising_end * pipeline.scheduler.config.num_train_timesteps) - ) - ) - data.num_inference_steps = len(list(filter(lambda ts: ts >= data.discrete_timestep_cutoff, data.timesteps))) - data.timesteps = data.timesteps[:data.num_inference_steps] - - self.add_block_state(state, data) - return pipeline, 
state - - -class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("scheduler", EulerDiscreteScheduler), - ] - - @property - def description(self) -> str: - return ( - "Step that prepares the latents for the inpainting process" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("generator"), - InputParam("latents"), - InputParam("num_images_per_prompt", default=1), - InputParam("denoising_start"), - InputParam( - "strength", - default=0.9999, - description="Conceptually, indicates how much to transform the reference `image` (the masked portion of image for inpainting). Must be between 0 and 1. `image` " - "will be used as a starting point, adding more noise to it the larger the `strength`. The number of " - "denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will " - "be maximum and the denoising process will run for the full number of iterations specified in " - "`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of " - "`denoising_start` being declared as an integer, the value of `strength` will be ignored." - ), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "latent_timestep", - required=True, - type_hint=torch.Tensor, - description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step." - ), - InputParam( - "image_latents", - required=True, - type_hint=torch.Tensor, - description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step." - ), - InputParam( - "mask", - required=True, - type_hint=torch.Tensor, - description="The mask for the inpainting generation. Can be generated in vae_encode step." - ), - InputParam( - "masked_image_latents", - type_hint=torch.Tensor, - description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step." 
- ), - InputParam( - "dtype", - type_hint=torch.dtype, - description="The dtype of the model inputs" - ) - ] - - @property - def intermediates_outputs(self) -> List[str]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"), - OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for inpainting generation"), - OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)"), - OutputParam("noise", type_hint=torch.Tensor, description="The noise added to the image latents, used for inpainting generation")] - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents with self -> components - def prepare_latents_inpaint( - self, - components, - batch_size, - num_channels_latents, - height, - width, - dtype, - device, - generator, - latents=None, - image=None, - timestep=None, - is_strength_max=True, - add_noise=True, - return_noise=False, - return_image_latents=False, - ): - shape = ( - batch_size, - num_channels_latents, - int(height) // components.vae_scale_factor, - int(width) // components.vae_scale_factor, - ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if (image is None or timestep is None) and not is_strength_max: - raise ValueError( - "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." - "However, either the image or the noise timestep has not been provided." - ) - - if image.shape[1] == 4: - image_latents = image.to(device=device, dtype=dtype) - image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) - elif return_image_latents or (latents is None and not is_strength_max): - image = image.to(device=device, dtype=dtype) - image_latents = self._encode_vae_image(components, image=image, generator=generator) - image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) - - if latents is None and add_noise: - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - # if strength is 1. 
then initialise the latents to noise, else initial to image + noise - latents = noise if is_strength_max else components.scheduler.add_noise(image_latents, noise, timestep) - # if pure noise then scale the initial latents by the Scheduler's init sigma - latents = latents * components.scheduler.init_noise_sigma if is_strength_max else latents - elif add_noise: - noise = latents.to(device) - latents = noise * components.scheduler.init_noise_sigma - else: - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = image_latents.to(device) - - outputs = (latents,) - - if return_noise: - outputs += (noise,) - - if return_image_latents: - outputs += (image_latents,) - - return outputs - - # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents - # do not accept do_classifier_free_guidance - def prepare_mask_latents( - self, components, mask, masked_image, batch_size, height, width, dtype, device, generator - ): - # resize the mask to latents shape as we concatenate the mask to the latents - # we do that before converting to dtype to avoid breaking in case we're using cpu_offload - # and half precision - mask = torch.nn.functional.interpolate( - mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor) - ) - mask = mask.to(device=device, dtype=dtype) - - # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method - if mask.shape[0] < batch_size: - if not batch_size % mask.shape[0] == 0: - raise ValueError( - "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" - f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" - " of masks that you pass is divisible by the total requested batch size." - ) - mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) - - if masked_image is not None and masked_image.shape[1] == 4: - masked_image_latents = masked_image - else: - masked_image_latents = None - - if masked_image is not None: - if masked_image_latents is None: - masked_image = masked_image.to(device=device, dtype=dtype) - masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator) - - if masked_image_latents.shape[0] < batch_size: - if not batch_size % masked_image_latents.shape[0] == 0: - raise ValueError( - "The passed images and the required batch size don't match. Images are supposed to be duplicated" - f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." - " Make sure the number of images that you pass is divisible by the total requested batch size." 
- ) - masked_image_latents = masked_image_latents.repeat( - batch_size // masked_image_latents.shape[0], 1, 1, 1 - ) - - # aligning device to prevent device errors when concating it with the latent model input - masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) - - return mask, masked_image_latents - - - @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - - data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype - data.device = pipeline._execution_device - - data.is_strength_max = data.strength == 1.0 - - # for non-inpainting specific unet, we do not need masked_image_latents - if hasattr(pipeline,"unet") and pipeline.unet is not None: - if pipeline.unet.config.in_channels == 4: - data.masked_image_latents = None - - data.add_noise = True if data.denoising_start is None else False - - data.height = data.image_latents.shape[-2] * pipeline.vae_scale_factor - data.width = data.image_latents.shape[-1] * pipeline.vae_scale_factor - - data.latents, data.noise = self.prepare_latents_inpaint( - pipeline, - data.batch_size * data.num_images_per_prompt, - pipeline.num_channels_latents, - data.height, - data.width, - data.dtype, - data.device, - data.generator, - data.latents, - image=data.image_latents, - timestep=data.latent_timestep, - is_strength_max=data.is_strength_max, - add_noise=data.add_noise, - return_noise=True, - return_image_latents=False, - ) - - # 7. Prepare mask latent variables - data.mask, data.masked_image_latents = self.prepare_mask_latents( - pipeline, - data.mask, - data.masked_image_latents, - data.batch_size * data.num_images_per_prompt, - data.height, - data.width, - data.dtype, - data.device, - data.generator, - ) - - self.add_block_state(state, data) - - return pipeline, state - - -class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("vae", AutoencoderKL), - ComponentSpec("scheduler", EulerDiscreteScheduler), - ] - - @property - def description(self) -> str: - return ( - "Step that prepares the latents for the image-to-image generation process" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("generator"), - InputParam("latents"), - InputParam("num_images_per_prompt", default=1), - InputParam("denoising_start"), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam("latent_timestep", required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step."), - InputParam("image_latents", required=True, type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step."), - InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. 
Can be generated in input step."), - InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs")] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process")] - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents with self -> components - # YiYi TODO: refactor using _encode_vae_image - def prepare_latents_img2img( - self, components, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True - ): - if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): - raise ValueError( - f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" - ) - - # Offload text encoder if `enable_model_cpu_offload` was enabled - if hasattr(components, "final_offload_hook") and components.final_offload_hook is not None: - components.text_encoder_2.to("cpu") - torch.cuda.empty_cache() - - image = image.to(device=device, dtype=dtype) - - batch_size = batch_size * num_images_per_prompt - - if image.shape[1] == 4: - init_latents = image - - else: - latents_mean = latents_std = None - if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: - latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: - latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) - # make sure the VAE is in float32 mode, as it overflows in float16 - if components.vae.config.force_upcast: - image = image.float() - components.vae.to(dtype=torch.float32) - - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
- ) - - elif isinstance(generator, list): - if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: - image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) - elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " - ) - - init_latents = [ - retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(batch_size) - ] - init_latents = torch.cat(init_latents, dim=0) - else: - init_latents = retrieve_latents(components.vae.encode(image), generator=generator) - - if components.vae.config.force_upcast: - components.vae.to(dtype) - - init_latents = init_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=device, dtype=dtype) - latents_std = latents_std.to(device=device, dtype=dtype) - init_latents = (init_latents - latents_mean) * components.vae.config.scaling_factor / latents_std - else: - init_latents = components.vae.config.scaling_factor * init_latents - - if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: - # expand init_latents for batch_size - additional_image_per_prompt = batch_size // init_latents.shape[0] - init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) - elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." - ) - else: - init_latents = torch.cat([init_latents], dim=0) - - if add_noise: - shape = init_latents.shape - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - # get latents - init_latents = components.scheduler.add_noise(init_latents, noise, timestep) - - latents = init_latents - - return latents - - @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - - data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype - data.device = pipeline._execution_device - data.add_noise = True if data.denoising_start is None else False - if data.latents is None: - data.latents = self.prepare_latents_img2img( - pipeline, - data.image_latents, - data.latent_timestep, - data.batch_size, - data.num_images_per_prompt, - data.dtype, - data.device, - data.generator, - data.add_noise, - ) - - self.add_block_state(state, data) - - return pipeline, state - - -class StableDiffusionXLPrepareLatentsStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("scheduler", EulerDiscreteScheduler), - ] - - @property - def description(self) -> str: - return ( - "Prepare latents step that prepares the latents for the text-to-image generation process" - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("height"), - InputParam("width"), - InputParam("generator"), - InputParam("latents"), - InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." 
- ), - InputParam( - "dtype", - type_hint=torch.dtype, - description="The dtype of the model inputs" - ) - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [ - OutputParam( - "latents", - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process" - ) - ] - - - @staticmethod - def check_inputs(pipeline, data): - if ( - data.height is not None - and data.height % pipeline.vae_scale_factor != 0 - or data.width is not None - and data.width % pipeline.vae_scale_factor != 0 - ): - raise ValueError( - f"`height` and `width` have to be divisible by {pipeline.vae_scale_factor} but are {data.height} and {data.width}." - ) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with self -> components - def prepare_latents(self, components, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = ( - batch_size, - num_channels_latents, - int(height) // components.vae_scale_factor, - int(width) // components.vae_scale_factor, - ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * components.scheduler.init_noise_sigma - return latents - - - @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - - if data.dtype is None: - data.dtype = pipeline.vae.dtype - - data.device = pipeline._execution_device - - self.check_inputs(pipeline, data) - - data.height = data.height or pipeline.default_sample_size * pipeline.vae_scale_factor - data.width = data.width or pipeline.default_sample_size * pipeline.vae_scale_factor - data.num_channels_latents = pipeline.num_channels_latents - data.latents = self.prepare_latents( - pipeline, - data.batch_size * data.num_images_per_prompt, - data.num_channels_latents, - data.height, - data.width, - data.dtype, - data.device, - data.generator, - data.latents, - ) - - self.add_block_state(state, data) - - return pipeline, state - - -class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_configs(self) -> List[ConfigSpec]: - return [ConfigSpec("requires_aesthetics_score", False),] - - @property - def description(self) -> str: - return ( - "Step that prepares the additional conditioning for the image-to-image/inpainting generation process" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("original_size"), - InputParam("target_size"), - InputParam("negative_original_size"), - InputParam("negative_target_size"), - InputParam("crops_coords_top_left", default=(0, 0)), - InputParam("negative_crops_coords_top_left", default=(0, 0)), - InputParam("num_images_per_prompt", default=1), - InputParam("aesthetic_score", default=6.0), - InputParam("negative_aesthetic_score", default=2.0), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam("latents", required=True, type_hint=torch.Tensor, 
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."), - InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step."), - InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("add_time_ids", type_hint=torch.Tensor, description="The time ids to condition the denoising process"), - OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="The negative time ids to condition the denoising process"), - OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self -> components - def _get_add_time_ids_img2img( - self, - components, - original_size, - crops_coords_top_left, - target_size, - aesthetic_score, - negative_aesthetic_score, - negative_original_size, - negative_crops_coords_top_left, - negative_target_size, - dtype, - text_encoder_projection_dim=None, - ): - if components.config.requires_aesthetics_score: - add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) - add_neg_time_ids = list( - negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) - ) - else: - add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) - - passed_add_embed_dim = ( - components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim - ) - expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features - - if ( - expected_add_embed_dim > passed_add_embed_dim - and (expected_add_embed_dim - passed_add_embed_dim) == components.unet.config.addition_time_embed_dim - ): - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." - ) - elif ( - expected_add_embed_dim < passed_add_embed_dim - and (passed_add_embed_dim - expected_add_embed_dim) == components.unet.config.addition_time_embed_dim - ): - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." - ) - elif expected_add_embed_dim != passed_add_embed_dim: - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. 
Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." - ) - - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) - - return add_time_ids, add_neg_time_ids - - # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - def get_guidance_scale_embedding( - self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 - ) -> torch.Tensor: - """ - See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 - - Args: - w (`torch.Tensor`): - Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. - embedding_dim (`int`, *optional*, defaults to 512): - Dimension of the embeddings to generate. - dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): - Data type of the generated embeddings. - - Returns: - `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. - """ - assert len(w.shape) == 1 - w = w * 1000.0 - - half_dim = embedding_dim // 2 - emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) - emb = w.to(dtype)[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1)) - assert emb.shape == (w.shape[0], embedding_dim) - return emb - - @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - data.device = pipeline._execution_device - - data.vae_scale_factor = pipeline.vae_scale_factor - - data.height, data.width = data.latents.shape[-2:] - data.height = data.height * data.vae_scale_factor - data.width = data.width * data.vae_scale_factor - - data.original_size = data.original_size or (data.height, data.width) - data.target_size = data.target_size or (data.height, data.width) - - data.text_encoder_projection_dim = int(data.pooled_prompt_embeds.shape[-1]) - - if data.negative_original_size is None: - data.negative_original_size = data.original_size - if data.negative_target_size is None: - data.negative_target_size = data.target_size - - data.add_time_ids, data.negative_add_time_ids = self._get_add_time_ids_img2img( - pipeline, - data.original_size, - data.crops_coords_top_left, - data.target_size, - data.aesthetic_score, - data.negative_aesthetic_score, - data.negative_original_size, - data.negative_crops_coords_top_left, - data.negative_target_size, - dtype=data.pooled_prompt_embeds.dtype, - text_encoder_projection_dim=data.text_encoder_projection_dim, - ) - data.add_time_ids = data.add_time_ids.repeat(data.batch_size * data.num_images_per_prompt, 1).to(device=data.device) - data.negative_add_time_ids = data.negative_add_time_ids.repeat(data.batch_size * data.num_images_per_prompt, 1).to(device=data.device) - - # Optionally get Guidance Scale Embedding for LCM - data.timestep_cond = None - if ( - hasattr(pipeline, "unet") - and pipeline.unet is not None - and pipeline.unet.config.time_cond_proj_dim is not None - ): - # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! 
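# Illustrative sketch of what the two statements below produce (values are assumptions for the example:
# guidance_scale=7.5 and a UNet distilled with time_cond_proj_dim=256):
#   w = torch.tensor(7.5 - 1).repeat(2)                             # batch_size * num_images_per_prompt = 2 -> shape (2,)
#   cond = self.get_guidance_scale_embedding(w, embedding_dim=256)  # sin/cos features -> shape (2, 256)
# The result is passed to the UNet as `timestep_cond` in place of classifier-free guidance for LCM-style UNets.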
- data.guidance_scale_tensor = torch.tensor(pipeline.guider.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt) - data.timestep_cond = self.get_guidance_scale_embedding( - data.guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim - ).to(device=data.device, dtype=data.latents.dtype) - - self.add_block_state(state, data) - return pipeline, state - - -class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Step that prepares the additional conditioning for the text-to-image generation process" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("original_size"), - InputParam("target_size"), - InputParam("negative_original_size"), - InputParam("negative_target_size"), - InputParam("crops_coords_top_left", default=(0, 0)), - InputParam("negative_crops_coords_top_left", default=(0, 0)), - InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - InputParam( - "pooled_prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step." - ), - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("add_time_ids", type_hint=torch.Tensor, description="The time ids to condition the denoising process"), - OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="The negative time ids to condition the denoising process"), - OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with self -> components - def _get_add_time_ids( - self, components, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None - ): - add_time_ids = list(original_size + crops_coords_top_left + target_size) - - passed_add_embed_dim = ( - components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim - ) - expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features - - if expected_add_embed_dim != passed_add_embed_dim: - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
- ) - - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - return add_time_ids - - # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - def get_guidance_scale_embedding( - self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 - ) -> torch.Tensor: - """ - See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 - - Args: - w (`torch.Tensor`): - Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. - embedding_dim (`int`, *optional*, defaults to 512): - Dimension of the embeddings to generate. - dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): - Data type of the generated embeddings. - - Returns: - `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. - """ - assert len(w.shape) == 1 - w = w * 1000.0 - - half_dim = embedding_dim // 2 - emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) - emb = w.to(dtype)[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1)) - assert emb.shape == (w.shape[0], embedding_dim) - return emb - - @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - data.device = pipeline._execution_device - - data.height, data.width = data.latents.shape[-2:] - data.height = data.height * pipeline.vae_scale_factor - data.width = data.width * pipeline.vae_scale_factor - - data.original_size = data.original_size or (data.height, data.width) - data.target_size = data.target_size or (data.height, data.width) - - data.text_encoder_projection_dim = int(data.pooled_prompt_embeds.shape[-1]) - - data.add_time_ids = self._get_add_time_ids( - pipeline, - data.original_size, - data.crops_coords_top_left, - data.target_size, - data.pooled_prompt_embeds.dtype, - text_encoder_projection_dim=data.text_encoder_projection_dim, - ) - if data.negative_original_size is not None and data.negative_target_size is not None: - data.negative_add_time_ids = self._get_add_time_ids( - pipeline, - data.negative_original_size, - data.negative_crops_coords_top_left, - data.negative_target_size, - data.pooled_prompt_embeds.dtype, - text_encoder_projection_dim=data.text_encoder_projection_dim, - ) - else: - data.negative_add_time_ids = data.add_time_ids - - data.add_time_ids = data.add_time_ids.repeat(data.batch_size * data.num_images_per_prompt, 1).to(device=data.device) - data.negative_add_time_ids = data.negative_add_time_ids.repeat(data.batch_size * data.num_images_per_prompt, 1).to(device=data.device) - - # Optionally get Guidance Scale Embedding for LCM - data.timestep_cond = None - if ( - hasattr(pipeline, "unet") - and pipeline.unet is not None - and pipeline.unet.config.time_cond_proj_dim is not None - ): - # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! 
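# When `unet.config.time_cond_proj_dim` is set (e.g. for an LCM-distilled UNet), the denoise steps in this
# file call `pipeline.guider.disable()`, so the guidance scale reaches the UNet only through this
# `timestep_cond` embedding rather than through a conditional/unconditional classifier-free guidance batch.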
- data.guidance_scale_tensor = torch.tensor(pipeline.guider.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt) - data.timestep_cond = self.get_guidance_scale_embedding( - data.guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim - ).to(device=data.device, dtype=data.latents.dtype) - - self.add_block_state(state, data) - return pipeline, state - - -class StableDiffusionXLDenoiseStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ComponentSpec("scheduler", EulerDiscreteScheduler), - ComponentSpec("unet", UNet2DConditionModel), - ] - - @property - def description(self) -> str: - return ( - "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("cross_attention_kwargs"), - InputParam("generator"), - InputParam("eta", default=0.0), - InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "pooled_prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_pooled_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " - ), - InputParam( - "add_time_ids", - required=True, - type_hint=torch.Tensor, - description="The time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." - ), - InputParam( - "negative_add_time_ids", - type_hint=Optional[torch.Tensor], - description="The negative time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." - ), - InputParam( - "prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " - ), - InputParam( - "timestep_cond", - type_hint=Optional[torch.Tensor], - description="The guidance scale embedding to use for Latent Consistency Models(LCMs). 
Can be generated in prepare_additional_conditioning step." - ), - InputParam( - "mask", - type_hint=Optional[torch.Tensor], - description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "masked_image_latents", - type_hint=Optional[torch.Tensor], - description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "noise", - type_hint=Optional[torch.Tensor], - description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." - ), - InputParam( - "image_latents", - type_hint=Optional[torch.Tensor], - description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." - ), - InputParam( - "negative_ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." - ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - - - def check_inputs(self, pipeline, data): - - num_channels_unet = pipeline.unet.config.in_channels - if num_channels_unet == 9: - # default case for runwayml/stable-diffusion-inpainting - if data.mask is None or data.masked_image_latents is None: - raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") - num_channels_latents = data.latents.shape[1] - num_channels_mask = data.mask.shape[1] - num_channels_masked_image = data.masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: - raise ValueError( - f"Incorrect configuration settings! The config of `pipeline.unet`: {pipeline.unet.config} expects" - f" {pipeline.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `pipeline.unet` or your `mask_image` or `image` input." - ) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components - def prepare_extra_step_kwargs(self, components, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - - data = self.get_block_state(state) - self.check_inputs(pipeline, data) - - data.num_channels_unet = pipeline.unet.config.in_channels - data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False - if data.disable_guidance: - pipeline.guider.disable() - else: - pipeline.guider.enable() - - # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) - data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) - - pipeline.guider.set_input_fields( - prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), - add_time_ids=("add_time_ids", "negative_add_time_ids"), - pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), - ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), - ) - - with self.progress_bar(total=data.num_inference_steps) as progress_bar: - for i, t in enumerate(data.timesteps): - pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) - guider_data = pipeline.guider.prepare_inputs(data) - - data.scaled_latents = pipeline.scheduler.scale_model_input(data.latents, t) - - # Prepare for inpainting - if data.num_channels_unet == 9: - data.scaled_latents = torch.cat([data.scaled_latents, data.mask, data.masked_image_latents], dim=1) - - for batch in guider_data: - pipeline.guider.prepare_models(pipeline.unet) - - # Prepare additional conditionings - batch.added_cond_kwargs = { - "text_embeds": batch.pooled_prompt_embeds, - "time_ids": batch.add_time_ids, - } - if batch.ip_adapter_embeds is not None: - batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds - - # Predict the noise residual - batch.noise_pred = pipeline.unet( - data.scaled_latents, - t, - encoder_hidden_states=batch.prompt_embeds, - timestep_cond=data.timestep_cond, - cross_attention_kwargs=data.cross_attention_kwargs, - added_cond_kwargs=batch.added_cond_kwargs, - return_dict=False, - )[0] - pipeline.guider.cleanup_models(pipeline.unet) - - # Perform guidance - data.noise_pred, scheduler_step_kwargs = pipeline.guider(guider_data) - - # Perform scheduler step using the predicted output - data.latents_dtype = data.latents.dtype - data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - - if data.latents.dtype != data.latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - data.latents = data.latents.to(data.latents_dtype) - - if data.num_channels_unet == 4 and data.mask is not None and data.image_latents is not None: - data.init_latents_proper = data.image_latents - if i < len(data.timesteps) - 1: - data.noise_timestep = data.timesteps[i + 1] - data.init_latents_proper = pipeline.scheduler.add_noise( - data.init_latents_proper, data.noise, torch.tensor([data.noise_timestep]) - ) - - data.latents = (1 - data.mask) * data.init_latents_proper + data.mask * data.latents - - if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): - progress_bar.update() - - self.add_block_state(state, data) - - return pipeline, state - - -class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ComponentSpec("scheduler", EulerDiscreteScheduler), - ComponentSpec("unet", UNet2DConditionModel), - ComponentSpec("controlnet", ControlNetModel), - ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"), - ] - - @property - def description(self) -> str: - return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("control_image", required=True), - InputParam("control_guidance_start", default=0.0), - InputParam("control_guidance_end", default=1.0), - InputParam("controlnet_conditioning_scale", default=1.0), - InputParam("guess_mode", default=False), - InputParam("num_images_per_prompt", default=1), - InputParam("cross_attention_kwargs"), - InputParam("generator"), - InputParam("eta", default=0.0), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." 
-            ),
-            InputParam(
-                "add_time_ids",
-                required=True,
-                type_hint=torch.Tensor,
-                description="The time ids used to condition the denoising process. Can be generated in prepare_additional_conditioning step."
-            ),
-            InputParam(
-                "negative_add_time_ids",
-                type_hint=Optional[torch.Tensor],
-                description="The negative time ids used to condition the denoising process. Can be generated in prepare_additional_conditioning step."
-            ),
-            InputParam(
-                "pooled_prompt_embeds",
-                required=True,
-                type_hint=torch.Tensor,
-                description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step."
-            ),
-            InputParam(
-                "negative_pooled_prompt_embeds",
-                type_hint=Optional[torch.Tensor],
-                description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step."
-            ),
-            InputParam(
-                "timestep_cond",
-                type_hint=Optional[torch.Tensor],
-                description="The guidance scale embedding to use for Latent Consistency Models (LCMs), can be generated by prepare_additional_conditioning step"
-            ),
-            InputParam(
-                "mask",
-                type_hint=Optional[torch.Tensor],
-                description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step."
-            ),
-            InputParam(
-                "masked_image_latents",
-                type_hint=Optional[torch.Tensor],
-                description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step."
-            ),
-            InputParam(
-                "noise",
-                type_hint=Optional[torch.Tensor],
-                description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step."
-            ),
-            InputParam(
-                "image_latents",
-                type_hint=Optional[torch.Tensor],
-                description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step."
-            ),
-            InputParam(
-                "crops_coords",
-                type_hint=Optional[Tuple[int]],
-                description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step."
-            ),
-            InputParam(
-                "ip_adapter_embeds",
-                type_hint=Optional[torch.Tensor],
-                description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step."
-            ),
-            InputParam(
-                "negative_ip_adapter_embeds",
-                type_hint=Optional[torch.Tensor],
-                description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step."
-            ),
-        ]
-
-    @property
-    def intermediates_outputs(self) -> List[OutputParam]:
-        return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
-
-    def check_inputs(self, pipeline, data):
-
-        num_channels_unet = pipeline.unet.config.in_channels
-        if num_channels_unet == 9:
-            # default case for runwayml/stable-diffusion-inpainting
-            if data.mask is None or data.masked_image_latents is None:
-                raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet")
-            num_channels_latents = data.latents.shape[1]
-            num_channels_mask = data.mask.shape[1]
-            num_channels_masked_image = data.masked_image_latents.shape[1]
-            if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet:
-                raise ValueError(
-                    f"Incorrect configuration settings! The config of `pipeline.unet`: {pipeline.unet.config} expects"
-                    f" {pipeline.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
-                    f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
-                    f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
-                    " `pipeline.unet` or your `mask_image` or `image` input."
-                )
-
-    # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image
-    # 1. return image without applying any guidance
-    # 2. add crops_coords and resize_mode to preprocess()
-    def prepare_control_image(
-        self,
-        components,
-        image,
-        width,
-        height,
-        batch_size,
-        num_images_per_prompt,
-        device,
-        dtype,
-        crops_coords=None,
-    ):
-        if crops_coords is not None:
-            image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32)
-        else:
-            image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
-
-        image_batch_size = image.shape[0]
-        if image_batch_size == 1:
-            repeat_by = batch_size
-        else:
-            # image batch size is the same as prompt batch size
-            repeat_by = num_images_per_prompt
-
-        image = image.repeat_interleave(repeat_by, dim=0)
-        image = image.to(device=device, dtype=dtype)
-        return image
-
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components
-    def prepare_extra_step_kwargs(self, components, generator, eta):
-        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
-        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - - data = self.get_block_state(state) - self.check_inputs(pipeline, data) - - data.num_channels_unet = pipeline.unet.config.in_channels - - # (1) prepare controlnet inputs - data.device = pipeline._execution_device - data.height, data.width = data.latents.shape[-2:] - data.height = data.height * pipeline.vae_scale_factor - data.width = data.width * pipeline.vae_scale_factor - - controlnet = unwrap_module(pipeline.controlnet) - - # (1.1) - # control_guidance_start/control_guidance_end (align format) - if not isinstance(data.control_guidance_start, list) and isinstance(data.control_guidance_end, list): - data.control_guidance_start = len(data.control_guidance_end) * [data.control_guidance_start] - elif not isinstance(data.control_guidance_end, list) and isinstance(data.control_guidance_start, list): - data.control_guidance_end = len(data.control_guidance_start) * [data.control_guidance_end] - elif not isinstance(data.control_guidance_start, list) and not isinstance(data.control_guidance_end, list): - mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 - data.control_guidance_start, data.control_guidance_end = ( - mult * [data.control_guidance_start], - mult * [data.control_guidance_end], - ) - - # (1.2) - # controlnet_conditioning_scale (align format) - if isinstance(controlnet, MultiControlNetModel) and isinstance(data.controlnet_conditioning_scale, float): - data.controlnet_conditioning_scale = [data.controlnet_conditioning_scale] * len(controlnet.nets) - - # (1.3) - # global_pool_conditions - data.global_pool_conditions = ( - controlnet.config.global_pool_conditions - if isinstance(controlnet, ControlNetModel) - else controlnet.nets[0].config.global_pool_conditions - ) - # (1.4) - # guess_mode - data.guess_mode = data.guess_mode or data.global_pool_conditions - - # (1.5) - # control_image - if isinstance(controlnet, ControlNetModel): - data.control_image = self.prepare_control_image( - pipeline, - image=data.control_image, - width=data.width, - height=data.height, - batch_size=data.batch_size * data.num_images_per_prompt, - num_images_per_prompt=data.num_images_per_prompt, - device=data.device, - dtype=controlnet.dtype, - crops_coords=data.crops_coords, - ) - elif isinstance(controlnet, MultiControlNetModel): - control_images = [] - - for control_image_ in data.control_image: - control_image = self.prepare_control_image( - pipeline, - image=control_image_, - width=data.width, - height=data.height, - batch_size=data.batch_size * data.num_images_per_prompt, - num_images_per_prompt=data.num_images_per_prompt, - device=data.device, - dtype=controlnet.dtype, - crops_coords=data.crops_coords, - ) - - control_images.append(control_image) - - data.control_image = control_images - else: - assert False - - # (1.6) - # controlnet_keep - data.controlnet_keep = [] - for i in range(len(data.timesteps)): - keeps = [ - 1.0 - float(i / len(data.timesteps) < s or (i + 
1) / len(data.timesteps) > e) - for s, e in zip(data.control_guidance_start, data.control_guidance_end) - ] - data.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) - - # (2) Prepare conditional inputs for unet using the guider - data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False - if data.disable_guidance: - pipeline.guider.disable() - else: - pipeline.guider.enable() - - # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) - data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) - - pipeline.guider.set_input_fields( - prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), - add_time_ids=("add_time_ids", "negative_add_time_ids"), - pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), - ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), - ) - - # (5) Denoise loop - with self.progress_bar(total=data.num_inference_steps) as progress_bar: - for i, t in enumerate(data.timesteps): - pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) - guider_data = pipeline.guider.prepare_inputs(data) - - data.scaled_latents = pipeline.scheduler.scale_model_input(data.latents, t) - - if isinstance(data.controlnet_keep[i], list): - data.cond_scale = [c * s for c, s in zip(data.controlnet_conditioning_scale, data.controlnet_keep[i])] - else: - data.controlnet_cond_scale = data.controlnet_conditioning_scale - if isinstance(data.controlnet_cond_scale, list): - data.controlnet_cond_scale = data.controlnet_cond_scale[0] - data.cond_scale = data.controlnet_cond_scale * data.controlnet_keep[i] - - for batch in guider_data: - pipeline.guider.prepare_models(pipeline.unet) - - # Prepare additional conditionings - batch.added_cond_kwargs = { - "text_embeds": batch.pooled_prompt_embeds, - "time_ids": batch.add_time_ids, - } - if batch.ip_adapter_embeds is not None: - batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds - - # Prepare controlnet additional conditionings - batch.controlnet_added_cond_kwargs = { - "text_embeds": batch.pooled_prompt_embeds, - "time_ids": batch.add_time_ids, - } - - # Will always be run atleast once with every guider - if pipeline.guider.is_conditional or not data.guess_mode: - data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet( - data.scaled_latents, - t, - encoder_hidden_states=batch.prompt_embeds, - controlnet_cond=data.control_image, - conditioning_scale=data.cond_scale, - guess_mode=data.guess_mode, - added_cond_kwargs=batch.controlnet_added_cond_kwargs, - return_dict=False, - ) - - batch.down_block_res_samples = data.down_block_res_samples - batch.mid_block_res_sample = data.mid_block_res_sample - - if pipeline.guider.is_unconditional and data.guess_mode: - batch.down_block_res_samples = [torch.zeros_like(d) for d in data.down_block_res_samples] - batch.mid_block_res_sample = torch.zeros_like(data.mid_block_res_sample) - - # Prepare for inpainting - if data.num_channels_unet == 9: - data.scaled_latents = torch.cat([data.scaled_latents, data.mask, data.masked_image_latents], dim=1) - - batch.noise_pred = pipeline.unet( - data.scaled_latents, - t, - encoder_hidden_states=batch.prompt_embeds, - timestep_cond=data.timestep_cond, - cross_attention_kwargs=data.cross_attention_kwargs, - 
added_cond_kwargs=batch.added_cond_kwargs, - down_block_additional_residuals=batch.down_block_res_samples, - mid_block_additional_residual=batch.mid_block_res_sample, - return_dict=False, - )[0] - pipeline.guider.cleanup_models(pipeline.unet) - - # Perform guidance - data.noise_pred, scheduler_step_kwargs = pipeline.guider(guider_data) - - # Perform scheduler step using the predicted output - data.latents_dtype = data.latents.dtype - data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - - if data.latents.dtype != data.latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - data.latents = data.latents.to(data.latents_dtype) - - if data.num_channels_unet == 4 and data.mask is not None and data.image_latents is not None: - data.init_latents_proper = data.image_latents - if i < len(data.timesteps) - 1: - data.noise_timestep = data.timesteps[i + 1] - data.init_latents_proper = pipeline.scheduler.add_noise( - data.init_latents_proper, data.noise, torch.tensor([data.noise_timestep]) - ) - - data.latents = (1 - data.mask) * data.init_latents_proper + data.mask * data.latents - - if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): - progress_bar.update() - - self.add_block_state(state, data) - - return pipeline, state - - -class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("unet", UNet2DConditionModel), - ComponentSpec("controlnet", ControlNetUnionModel), - ComponentSpec("scheduler", EulerDiscreteScheduler), - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ComponentSpec( - "control_image_processor", - VaeImageProcessor, - config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), - default_creation_method="from_config"), - ] - - @property - def description(self) -> str: - return " The denoising step for the controlnet union model, works for inpainting, image-to-image, and text-to-image tasks" - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("control_image", required=True), - InputParam("control_guidance_start", default=0.0), - InputParam("control_guidance_end", default=1.0), - InputParam("control_mode", required=True), - InputParam("controlnet_conditioning_scale", default=1.0), - InputParam("guess_mode", default=False), - InputParam("num_images_per_prompt", default=1), - InputParam("cross_attention_kwargs"), - InputParam("generator"), - InputParam("eta", default=0.0), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." 
- ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step. See: https://github.com/huggingface/diffusers/issues/4208" - ), - InputParam( - "add_time_ids", - required=True, - type_hint=torch.Tensor, - description="The time ids used to condition the denoising process. Can be generated in prepare_additional_conditioning step." - ), - InputParam( - "negative_add_time_ids", - type_hint=Optional[torch.Tensor], - description="The negative time ids used to condition the denoising process. Can be generated in prepare_additional_conditioning step. " - ), - InputParam( - "pooled_prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_pooled_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. See: https://github.com/huggingface/diffusers/issues/4208" - ), - InputParam( - "timestep_cond", - type_hint=Optional[torch.Tensor], - description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." - ), - InputParam( - "mask", - type_hint=Optional[torch.Tensor], - description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "masked_image_latents", - type_hint=Optional[torch.Tensor], - description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "noise", - type_hint=Optional[torch.Tensor], - description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." - ), - InputParam( - "image_latents", - type_hint=Optional[torch.Tensor], - description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "crops_coords", - type_hint=Optional[Tuple[int]], - description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." - ), - InputParam( - "negative_ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." 
- ), - ] - - @property - def intermediates_outputs(self) -> List[str]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - - def check_inputs(self, pipeline, data): - - num_channels_unet = pipeline.unet.config.in_channels - if num_channels_unet == 9: - # default case for runwayml/stable-diffusion-inpainting - if data.mask is None or data.masked_image_latents is None: - raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") - num_channels_latents = data.latents.shape[1] - num_channels_mask = data.mask.shape[1] - num_channels_masked_image = data.masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: - raise ValueError( - f"Incorrect configuration settings! The config of `pipeline.unet`: {pipeline.unet.config} expects" - f" {pipeline.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `pipeline.unet` or your `mask_image` or `image` input." - ) - - - # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image - # 1. return image without apply any guidance - # 2. add crops_coords and resize_mode to preprocess() - def prepare_control_image( - self, - components, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - crops_coords=None, - ): - if crops_coords is not None: - image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) - else: - image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - image_batch_size = image.shape[0] - - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - - image = image.to(device=device, dtype=dtype) - - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components - def prepare_extra_step_kwargs(self, components, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
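# `prepare_extra_step_kwargs` above uses a small standard-library pattern worth calling out:
# introspect the callable's signature and forward only the keyword arguments it accepts. A
# minimal, hedged sketch of that pattern in isolation (`toy_step` below is purely illustrative,
# not a diffusers API):
import inspect

def filter_kwargs_for(fn, **candidate_kwargs):
    # keep only the kwargs that actually appear in fn's signature
    accepted = set(inspect.signature(fn).parameters.keys())
    return {k: v for k, v in candidate_kwargs.items() if k in accepted}

def toy_step(model_output, timestep, sample, eta=0.0, return_dict=True):
    return sample

filtered = filter_kwargs_for(toy_step, eta=0.0, generator=None)
# -> {"eta": 0.0}; "generator" is dropped because toy_step does not accept it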
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - self.check_inputs(pipeline, data) - - data.num_channels_unet = pipeline.unet.config.in_channels - - # (1) prepare controlnet inputs - data.device = pipeline._execution_device - data.height, data.width = data.latents.shape[-2:] - data.height = data.height * pipeline.vae_scale_factor - data.width = data.width * pipeline.vae_scale_factor - - controlnet = unwrap_module(pipeline.controlnet) - - # (1.1) - # control guidance - if not isinstance(data.control_guidance_start, list) and isinstance(data.control_guidance_end, list): - data.control_guidance_start = len(data.control_guidance_end) * [data.control_guidance_start] - elif not isinstance(data.control_guidance_end, list) and isinstance(data.control_guidance_start, list): - data.control_guidance_end = len(data.control_guidance_start) * [data.control_guidance_end] - - # (1.2) - # global_pool_conditions & guess_mode - data.global_pool_conditions = controlnet.config.global_pool_conditions - data.guess_mode = data.guess_mode or data.global_pool_conditions - - # (1.3) - # control_type - data.num_control_type = controlnet.config.num_control_type - - # (1.4) - # control_type - if not isinstance(data.control_image, list): - data.control_image = [data.control_image] - - if not isinstance(data.control_mode, list): - data.control_mode = [data.control_mode] - - if len(data.control_image) != len(data.control_mode): - raise ValueError("Expected len(control_image) == len(control_type)") - - data.control_type = [0 for _ in range(data.num_control_type)] - for control_idx in data.control_mode: - data.control_type[control_idx] = 1 - - data.control_type = torch.Tensor(data.control_type) - - # (1.5) - # prepare control_image - for idx, _ in enumerate(data.control_image): - data.control_image[idx] = self.prepare_control_image( - pipeline, - image=data.control_image[idx], - width=data.width, - height=data.height, - batch_size=data.batch_size * data.num_images_per_prompt, - num_images_per_prompt=data.num_images_per_prompt, - device=data.device, - dtype=controlnet.dtype, - crops_coords=data.crops_coords, - ) - data.height, data.width = data.control_image[idx].shape[-2:] - - # (1.6) - # controlnet_keep - data.controlnet_keep = [] - for i in range(len(data.timesteps)): - data.controlnet_keep.append( - 1.0 - - float(i / len(data.timesteps) < data.control_guidance_start or (i + 1) / len(data.timesteps) > data.control_guidance_end) - ) - - # (2) Prepare conditional inputs for unet using the guider - # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale - data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False - if data.disable_guidance: - pipeline.guider.disable() - else: - pipeline.guider.enable() - - data.control_type = data.control_type.reshape(1, -1).to(data.device, dtype=data.prompt_embeds.dtype) - repeat_by = data.batch_size * data.num_images_per_prompt 
// data.control_type.shape[0] - data.control_type = data.control_type.repeat_interleave(repeat_by, dim=0) - - # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) - data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) - - pipeline.guider.set_input_fields( - prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), - add_time_ids=("add_time_ids", "negative_add_time_ids"), - pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), - ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), - ) - - with self.progress_bar(total=data.num_inference_steps) as progress_bar: - for i, t in enumerate(data.timesteps): - pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) - guider_data = pipeline.guider.prepare_inputs(data) - - data.scaled_latents = pipeline.scheduler.scale_model_input(data.latents, t) - - if isinstance(data.controlnet_keep[i], list): - data.cond_scale = [c * s for c, s in zip(data.controlnet_conditioning_scale, data.controlnet_keep[i])] - else: - data.controlnet_cond_scale = data.controlnet_conditioning_scale - if isinstance(data.controlnet_cond_scale, list): - data.controlnet_cond_scale = data.controlnet_cond_scale[0] - data.cond_scale = data.controlnet_cond_scale * data.controlnet_keep[i] - - for batch in guider_data: - pipeline.guider.prepare_models(pipeline.unet) - - # Prepare additional conditionings - batch.added_cond_kwargs = { - "text_embeds": batch.pooled_prompt_embeds, - "time_ids": batch.add_time_ids, - } - if batch.ip_adapter_embeds is not None: - batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds - - # Prepare controlnet additional conditionings - batch.controlnet_added_cond_kwargs = { - "text_embeds": batch.pooled_prompt_embeds, - "time_ids": batch.add_time_ids, - } - - # Will always be run atleast once with every guider - if pipeline.guider.is_conditional or not data.guess_mode: - data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet( - data.scaled_latents, - t, - encoder_hidden_states=batch.prompt_embeds, - controlnet_cond=data.control_image, - control_type=data.control_type, - control_type_idx=data.control_mode, - conditioning_scale=data.cond_scale, - guess_mode=data.guess_mode, - added_cond_kwargs=batch.controlnet_added_cond_kwargs, - return_dict=False, - ) - - batch.down_block_res_samples = data.down_block_res_samples - batch.mid_block_res_sample = data.mid_block_res_sample - - if pipeline.guider.is_unconditional and data.guess_mode: - batch.down_block_res_samples = [torch.zeros_like(d) for d in data.down_block_res_samples] - batch.mid_block_res_sample = torch.zeros_like(data.mid_block_res_sample) - - if data.num_channels_unet == 9: - data.scaled_latents = torch.cat([data.scaled_latents, data.mask, data.masked_image_latents], dim=1) - - batch.noise_pred = pipeline.unet( - data.scaled_latents, - t, - encoder_hidden_states=batch.prompt_embeds, - timestep_cond=data.timestep_cond, - cross_attention_kwargs=data.cross_attention_kwargs, - added_cond_kwargs=batch.added_cond_kwargs, - down_block_additional_residuals=batch.down_block_res_samples, - mid_block_additional_residual=batch.mid_block_res_sample, - return_dict=False, - )[0] - pipeline.guider.cleanup_models(pipeline.unet) - - # Perform guidance - data.noise_pred, scheduler_step_kwargs = pipeline.guider(guider_data) - - # 
Perform scheduler step using the predicted output - data.latents_dtype = data.latents.dtype - data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - - if data.latents.dtype != data.latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (e.g. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - data.latents = data.latents.to(data.latents_dtype) - - if data.num_channels_unet == 4 and data.mask is not None and data.image_latents is not None: - data.init_latents_proper = data.image_latents - if i < len(data.timesteps) - 1: - data.noise_timestep = data.timesteps[i + 1] - data.init_latents_proper = pipeline.scheduler.add_noise( - data.init_latents_proper, data.noise, torch.tensor([data.noise_timestep]) - ) - data.latents = (1 - data.mask) * data.init_latents_proper + data.mask * data.latents - - if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): - progress_bar.update() - - self.add_block_state(state, data) - - return pipeline, state - - -class StableDiffusionXLDecodeLatentsStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("vae", AutoencoderKL), - ComponentSpec( - "image_processor", - VaeImageProcessor, - config=FrozenDict({"vae_scale_factor": 8}), - default_creation_method="from_config"), - ] - - @property - def description(self) -> str: - return "Step that decodes the denoised latents into images" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("output_type", default="pil"), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [InputParam("latents", required=True, type_hint=torch.Tensor, description="The denoised latents from the denoising step")] - - @property - def intermediates_outputs(self) -> List[str]: - return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array")] - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self -> components - def upcast_vae(self, components): - dtype = components.vae.dtype - components.vae.to(dtype=torch.float32) - use_torch_2_0_or_xformers = isinstance( - components.vae.decoder.mid_block.attentions[0].processor, - ( - AttnProcessor2_0, - XFormersAttnProcessor, - ), - ) - # if xformers or torch_2_0 is used attention block does not need - # to be in float32 which can save lots of memory - if use_torch_2_0_or_xformers: - components.vae.post_quant_conv.to(dtype) - components.vae.decoder.conv_in.to(dtype) - components.vae.decoder.mid_block.to(dtype) - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - - if not data.output_type == "latent": - # make sure the VAE is in float32 mode, as it overflows in float16 - data.needs_upcasting = pipeline.vae.dtype == torch.float16 and pipeline.vae.config.force_upcast - - if data.needs_upcasting: - self.upcast_vae(pipeline) - data.latents = data.latents.to(next(iter(pipeline.vae.post_quant_conv.parameters())).dtype) - elif data.latents.dtype != pipeline.vae.dtype: - if torch.backends.mps.is_available(): - # some platforms (e.g. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - pipeline.vae = pipeline.vae.to(data.latents.dtype) - - # unscale/denormalize the latents - # denormalize with the mean and std if available and not None - data.has_latents_mean = ( - hasattr(pipeline.vae.config, "latents_mean") and pipeline.vae.config.latents_mean is not None - ) - data.has_latents_std = ( - hasattr(pipeline.vae.config, "latents_std") and pipeline.vae.config.latents_std is not None - ) - if data.has_latents_mean and data.has_latents_std: - data.latents_mean = ( - torch.tensor(pipeline.vae.config.latents_mean).view(1, 4, 1, 1).to(data.latents.device, data.latents.dtype) - ) - data.latents_std = ( - torch.tensor(pipeline.vae.config.latents_std).view(1, 4, 1, 1).to(data.latents.device, data.latents.dtype) - ) - data.latents = data.latents * data.latents_std / pipeline.vae.config.scaling_factor + data.latents_mean - else: - data.latents = data.latents / pipeline.vae.config.scaling_factor - - data.images = pipeline.vae.decode(data.latents, return_dict=False)[0] - - # cast back to fp16 if needed - if data.needs_upcasting: - pipeline.vae.to(dtype=torch.float16) - else: - data.images = data.latents - - # apply watermark if available - if hasattr(pipeline, "watermark") and pipeline.watermark is not None: - data.images = pipeline.watermark.apply_watermark(data.images) - - data.images = pipeline.image_processor.postprocess(data.images, output_type=data.output_type) - - self.add_block_state(state, data) - - return pipeline, state - - -class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return "A post-processing step that overlays the mask on the image (inpainting task only).\n" + \ - "only needed when you are using the `padding_mask_crop` option when pre-processing the image and mask" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("image", required=True), - InputParam("mask_image", required=True), - InputParam("padding_mask_crop"), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam("images", required=True, type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images from the decode step"), - InputParam("crops_coords", required=True, type_hint=Tuple[int, int], description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.") - ] - - @property - def intermediates_outputs(self) -> List[str]: - return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images with the mask overlayed")] - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - - if data.padding_mask_crop is not None and data.crops_coords is not None: - data.images = [pipeline.image_processor.apply_overlay(data.mask_image, data.image, i, data.crops_coords) for i in data.images] - - self.add_block_state(state, data) - - return pipeline, state - - -class StableDiffusionXLOutputStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return "final step to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple." 
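# The latent denormalization inside `StableDiffusionXLDecodeLatentsStep.__call__` above boils down
# to a small piece of arithmetic. A hedged, standalone sketch of it (the helper name, shapes, and
# the example statistics are illustrative, not values read from a real VAE config):
import torch

def denormalize_latents(latents, scaling_factor, latents_mean=None, latents_std=None):
    # undo per-channel normalization when statistics are available, otherwise just
    # divide by the scaling factor, mirroring the two branches in the decode step
    if latents_mean is not None and latents_std is not None:
        mean = torch.tensor(latents_mean).view(1, -1, 1, 1).to(latents.device, latents.dtype)
        std = torch.tensor(latents_std).view(1, -1, 1, 1).to(latents.device, latents.dtype)
        return latents * std / scaling_factor + mean
    return latents / scaling_factor

# example with a dummy 4-channel latent batch and made-up statistics
_ = denormalize_latents(torch.randn(1, 4, 128, 128), scaling_factor=0.13, latents_mean=[0.0] * 4, latents_std=[1.0] * 4)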
- - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [InputParam("return_dict", default=True)] - - @property - def intermediates_inputs(self) -> List[str]: - return [InputParam("images", required=True, type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images from the decode step.")] - - @property - def intermediates_outputs(self) -> List[str]: - return [OutputParam("images", description="The final images output, can be a tuple or a `StableDiffusionXLPipelineOutput`")] - - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - - if not data.return_dict: - data.images = (data.images,) - else: - data.images = StableDiffusionXLPipelineOutput(images=data.images) - self.add_block_state(state, data) - return pipeline, state - - -# Encode -class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep] - block_names = ["inpaint", "img2img"] - block_trigger_inputs = ["mask_image", "image"] - - @property - def description(self): - return "Vae encoder step that encodes the image inputs into their latent representations.\n" + \ - "This is an auto pipeline block that works for both inpainting and img2img tasks.\n" + \ - " - `StableDiffusionXLInpaintVaeEncoderStep` (inpaint) is used when both `mask_image` and `image` are provided.\n" + \ - " - `StableDiffusionXLVaeEncoderStep` (img2img) is used when only `image` is provided." - - -# Before denoise -class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLInputStep, StableDiffusionXLSetTimestepsStep, StableDiffusionXLPrepareLatentsStep, StableDiffusionXLPrepareAdditionalConditioningStep] - block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"] - - @property - def description(self): - return "Before denoise step that prepares the inputs for the denoise step.\n" + \ - "This is a sequential pipeline of blocks:\n" + \ - " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ - " - `StableDiffusionXLSetTimestepsStep` is used to set the timesteps\n" + \ - " - `StableDiffusionXLPrepareLatentsStep` is used to prepare the latents\n" + \ - " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning" - - -class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLImg2ImgPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep] - block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"] - - @property - def description(self): - return "Before denoise step that prepares the inputs for the denoise step for the img2img task.\n" + \ - "This is a sequential pipeline of blocks:\n" + \ - " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ - " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + \ - " - `StableDiffusionXLImg2ImgPrepareLatentsStep` is used to prepare the latents\n" + \ - " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning" - - -class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, 
StableDiffusionXLInpaintPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep] - block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"] - - @property - def description(self): - return "Before denoise step that prepares the inputs for the denoise step for the inpainting task.\n" + \ - "This is a sequential pipeline of blocks:\n" + \ - " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ - " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + \ - " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n" + \ - " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning" - - -class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLInpaintBeforeDenoiseStep, StableDiffusionXLImg2ImgBeforeDenoiseStep, StableDiffusionXLBeforeDenoiseStep] - block_names = ["inpaint", "img2img", "text2img"] - block_trigger_inputs = ["mask", "image_latents", None] - - @property - def description(self): - return "Before denoise step that prepares the inputs for the denoise step.\n" + \ - "This is an auto pipeline block that works for text2img, img2img and inpainting tasks.\n" + \ - " - `StableDiffusionXLInpaintBeforeDenoiseStep` (inpaint) is used when both `mask` and `image_latents` are provided.\n" + \ - " - `StableDiffusionXLImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n" + \ - " - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when neither `image_latents` nor `mask` is provided." - - -# Denoise -class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLControlNetUnionDenoiseStep, StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDenoiseStep] - block_names = ["controlnet_union", "controlnet", "unet"] - block_trigger_inputs = ["control_mode", "control_image", None] - - @property - def description(self): - return "Denoise step that denoises the latents.\n" + \ - "This is an auto pipeline block that works for controlnet, controlnet_union, and no controlnet.\n" + \ - " - `StableDiffusionXLControlNetUnionDenoiseStep` (controlnet_union) is used when both `control_mode` and `control_image` are provided.\n" + \ - " - `StableDiffusionXLControlNetDenoiseStep` (controlnet) is used when `control_image` is provided.\n" + \ - " - `StableDiffusionXLDenoiseStep` (unet only) is used when neither `control_mode` nor `control_image` is provided." - -# After denoise -class StableDiffusionXLDecodeStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLOutputStep] - block_names = ["decode", "output"] - - @property - def description(self): - return """Decode step that decodes the denoised latents into image outputs. 
-This is a sequential pipeline of blocks: - - `StableDiffusionXLDecodeLatentsStep` is used to decode the denoised latents into images - - `StableDiffusionXLOutputStep` is used to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple.""" - - -class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLInpaintOverlayMaskStep, StableDiffusionXLOutputStep] - block_names = ["decode", "mask_overlay", "output"] - - @property - def description(self): - return "Inpaint decode step that decodes the denoised latents into image outputs.\n" + \ - "This is a sequential pipeline of blocks:\n" + \ - " - `StableDiffusionXLDecodeLatentsStep` is used to decode the denoised latents into images\n" + \ - " - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image\n" + \ - " - `StableDiffusionXLOutputStep` is used to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple." - - -class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep] - block_names = ["inpaint", "non-inpaint"] - block_trigger_inputs = ["padding_mask_crop", None] - - @property - def description(self): - return "Decode step that decodes the denoised latents into image outputs.\n" + \ - "This is an auto pipeline block that works for inpainting and non-inpainting tasks.\n" + \ - " - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n" + \ - " - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided." - - -class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks, ModularIPAdapterMixin): - block_classes = [StableDiffusionXLIPAdapterStep] - block_names = ["ip_adapter"] - block_trigger_inputs = ["ip_adapter_image"] - - @property - def description(self): - return "Run the IP Adapter step if `ip_adapter_image` is provided." - - -class StableDiffusionXLAutoPipeline(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, StableDiffusionXLAutoBeforeDenoiseStep, StableDiffusionXLAutoDenoiseStep, StableDiffusionXLAutoDecodeStep] - block_names = ["text_encoder", "ip_adapter", "image_encoder", "before_denoise", "denoise", "decode"] - - @property - def description(self): - return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n" + \ - "- for image-to-image generation, you need to provide either `image` or `image_latents`\n" + \ - "- for inpainting, you need to provide `mask_image` and `image`; optionally you can provide `padding_mask_crop`\n" + \ - "- to run the controlnet workflow, you need to provide `control_image`\n" + \ - "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n" + \ - "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n" + \ - "- for text-to-image generation, all you need to provide is `prompt`" - -# TODO(yiyi, aryan): We need another step before the text encoder to set the `num_inference_steps` attribute for the guider so that -# things like when to do guidance and how many conditions need to be prepared can be determined. Currently, this is done by -# always assuming you want to do guidance in the Guiders. 
So, negative embeddings are prepared regardless of what the -# configuration of guider is. - -# block mapping -TEXT2IMAGE_BLOCKS = OrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), - ("input", StableDiffusionXLInputStep), - ("set_timesteps", StableDiffusionXLSetTimestepsStep), - ("prepare_latents", StableDiffusionXLPrepareLatentsStep), - ("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLDenoiseStep), - ("decode", StableDiffusionXLDecodeStep) -]) - -IMAGE2IMAGE_BLOCKS = OrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), - ("image_encoder", StableDiffusionXLVaeEncoderStep), - ("input", StableDiffusionXLInputStep), - ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), - ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep), - ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLDenoiseStep), - ("decode", StableDiffusionXLDecodeStep) -]) - -INPAINT_BLOCKS = OrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), - ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep), - ("input", StableDiffusionXLInputStep), - ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), - ("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep), - ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLDenoiseStep), - ("decode", StableDiffusionXLInpaintDecodeStep) -]) - -CONTROLNET_BLOCKS = OrderedDict([ - ("denoise", StableDiffusionXLControlNetDenoiseStep), -]) - -CONTROLNET_UNION_BLOCKS = OrderedDict([ - ("denoise", StableDiffusionXLControlNetUnionDenoiseStep), -]) - -IP_ADAPTER_BLOCKS = OrderedDict([ - ("ip_adapter", StableDiffusionXLIPAdapterStep), -]) - -AUTO_BLOCKS = OrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), - ("image_encoder", StableDiffusionXLAutoVaeEncoderStep), - ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), - ("denoise", StableDiffusionXLAutoDenoiseStep), - ("decode", StableDiffusionXLAutoDecodeStep) -]) - -AUTO_CORE_BLOCKS = OrderedDict([ - ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), - ("denoise", StableDiffusionXLAutoDenoiseStep), -]) - - -SDXL_SUPPORTED_BLOCKS = { - "text2img": TEXT2IMAGE_BLOCKS, - "img2img": IMAGE2IMAGE_BLOCKS, - "inpaint": INPAINT_BLOCKS, - "controlnet": CONTROLNET_BLOCKS, - "controlnet_union": CONTROLNET_UNION_BLOCKS, - "ip_adapter": IP_ADAPTER_BLOCKS, - "auto": AUTO_BLOCKS -} - - -# YiYi Notes: model specific components: -## (1) it should inherit from ModularLoader -## (2) acts like a container that holds components and configs -## (3) define default config (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents -## (4) inherit from model-specic loader class (e.g. StableDiffusionXLLoraLoaderMixin) -## (5) how to use together with Components_manager? 
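# The block mappings above are plain OrderedDicts of step names to block classes, so a preset can
# be turned into a composite block with the same `block_classes` / `block_names` pattern used by
# the classes in this file. A hedged sketch (the subclass name is illustrative, and how the
# resulting blocks get attached to a loader / components manager is not shown here):
class Text2ImageBlocksPreset(SequentialPipelineBlocks):
    block_classes = list(TEXT2IMAGE_BLOCKS.values())
    block_names = list(TEXT2IMAGE_BLOCKS.keys())

    @property
    def description(self):
        return "Text-to-image preset assembled from the TEXT2IMAGE_BLOCKS mapping."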
-class StableDiffusionXLModularLoader( - ModularLoader, - StableDiffusionMixin, - TextualInversionLoaderMixin, - StableDiffusionXLLoraLoaderMixin, - ModularIPAdapterMixin, -): - @property - def default_sample_size(self): - default_sample_size = 128 - if hasattr(self, "unet") and self.unet is not None: - default_sample_size = self.unet.config.sample_size - return default_sample_size - - @property - def vae_scale_factor(self): - vae_scale_factor = 8 - if hasattr(self, "vae") and self.vae is not None: - vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - return vae_scale_factor - - @property - def num_channels_unet(self): - num_channels_unet = 4 - if hasattr(self, "unet") and self.unet is not None: - num_channels_unet = self.unet.config.in_channels - return num_channels_unet - - @property - def num_channels_latents(self): - num_channels_latents = 4 - if hasattr(self, "vae") and self.vae is not None: - num_channels_latents = self.vae.config.latent_channels - return num_channels_latents - - - - - -# YiYi Notes: not used yet, maintain a list of schema that can be used across all pipeline blocks -SDXL_INPUTS_SCHEMA = { - "prompt": InputParam("prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"), - "prompt_2": InputParam("prompt_2", type_hint=Union[str, List[str]], description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2"), - "negative_prompt": InputParam("negative_prompt", type_hint=Union[str, List[str]], description="The prompt or prompts not to guide the image generation"), - "negative_prompt_2": InputParam("negative_prompt_2", type_hint=Union[str, List[str]], description="The negative prompt or prompts for text_encoder_2"), - "cross_attention_kwargs": InputParam("cross_attention_kwargs", type_hint=Optional[dict], description="Kwargs dictionary passed to the AttentionProcessor"), - "clip_skip": InputParam("clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"), - "image": InputParam("image", type_hint=PipelineImageInput, required=True, description="The image(s) to modify for img2img or inpainting"), - "mask_image": InputParam("mask_image", type_hint=PipelineImageInput, required=True, description="Mask image for inpainting, white pixels will be repainted"), - "generator": InputParam("generator", type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], description="Generator(s) for deterministic generation"), - "height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"), - "width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"), - "num_images_per_prompt": InputParam("num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"), - "num_inference_steps": InputParam("num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"), - "timesteps": InputParam("timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"), - "sigmas": InputParam("sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"), - "denoising_end": InputParam("denoising_end", type_hint=Optional[float], description="Fraction of denoising process to complete before termination"), - # YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999 - "strength": InputParam("strength", type_hint=float, default=0.3, 
description="How much to transform the reference image"), - "denoising_start": InputParam("denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"), - "latents": InputParam("latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"), - "padding_mask_crop": InputParam("padding_mask_crop", type_hint=Optional[Tuple[int, int]], description="Size of margin in crop for image and mask"), - "original_size": InputParam("original_size", type_hint=Optional[Tuple[int, int]], description="Original size of the image for SDXL's micro-conditioning"), - "target_size": InputParam("target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"), - "negative_original_size": InputParam("negative_original_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on image resolution"), - "negative_target_size": InputParam("negative_target_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on target resolution"), - "crops_coords_top_left": InputParam("crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Top-left coordinates for SDXL's micro-conditioning"), - "negative_crops_coords_top_left": InputParam("negative_crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Negative conditioning crop coordinates"), - "aesthetic_score": InputParam("aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"), - "negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"), - "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"), - "output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"), - "return_dict": InputParam("return_dict", type_hint=bool, default=True, description="Whether to return a StableDiffusionXLPipelineOutput"), - "ip_adapter_image": InputParam("ip_adapter_image", type_hint=PipelineImageInput, required=True, description="Image(s) to be used as IP adapter"), - "control_image": InputParam("control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"), - "control_guidance_start": InputParam("control_guidance_start", type_hint=Union[float, List[float]], default=0.0, description="When ControlNet starts applying"), - "control_guidance_end": InputParam("control_guidance_end", type_hint=Union[float, List[float]], default=1.0, description="When ControlNet stops applying"), - "controlnet_conditioning_scale": InputParam("controlnet_conditioning_scale", type_hint=Union[float, List[float]], default=1.0, description="Scale factor for ControlNet outputs"), - "guess_mode": InputParam("guess_mode", type_hint=bool, default=False, description="Enables ControlNet encoder to recognize input without prompts"), - "control_mode": InputParam("control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet") -} - - -SDXL_INTERMEDIATE_INPUTS_SCHEMA = { - "prompt_embeds": InputParam("prompt_embeds", type_hint=torch.Tensor, required=True, description="Text embeddings used to guide image generation"), - "negative_prompt_embeds": InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"), - "pooled_prompt_embeds": 
InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"), - "negative_pooled_prompt_embeds": InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"), - "batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"), - "dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), - "preprocess_kwargs": InputParam("preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"), - "latents": InputParam("latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"), - "timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"), - "num_inference_steps": InputParam("num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"), - "latent_timestep": InputParam("latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"), - "image_latents": InputParam("image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"), - "mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"), - "masked_image_latents": InputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"), - "add_time_ids": InputParam("add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"), - "negative_add_time_ids": InputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"), - "timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), - "noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), - "crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"), - "ip_adapter_embeds": InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"), - "negative_ip_adapter_embeds": InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"), - "images": InputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], required=True, description="Generated images") -} - - -SDXL_INTERMEDIATE_OUTPUTS_SCHEMA = { - "prompt_embeds": OutputParam("prompt_embeds", type_hint=torch.Tensor, description="Text embeddings used to guide image generation"), - "negative_prompt_embeds": OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"), - "pooled_prompt_embeds": OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="Pooled text embeddings"), - "negative_pooled_prompt_embeds": OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"), - "batch_size": OutputParam("batch_size", type_hint=int, description="Number of prompts"), - "dtype": OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), - "image_latents": OutputParam("image_latents", type_hint=torch.Tensor, description="Latents representing reference image"), - "mask": OutputParam("mask", type_hint=torch.Tensor, description="Mask for inpainting"), - "masked_image_latents": 
OutputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"), - "crops_coords": OutputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"), - "timesteps": OutputParam("timesteps", type_hint=torch.Tensor, description="Timesteps for inference"), - "num_inference_steps": OutputParam("num_inference_steps", type_hint=int, description="Number of denoising steps"), - "latent_timestep": OutputParam("latent_timestep", type_hint=torch.Tensor, description="Initial noise level timestep"), - "add_time_ids": OutputParam("add_time_ids", type_hint=torch.Tensor, description="Time ids for conditioning"), - "negative_add_time_ids": OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"), - "timestep_cond": OutputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), - "latents": OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents"), - "noise": OutputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), - "ip_adapter_embeds": OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"), - "negative_ip_adapter_embeds": OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"), - "images": OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="Generated images") -} - - -SDXL_OUTPUTS_SCHEMA = { - "images": OutputParam("images", type_hint=Union[Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput], description="The final generated images") -}