diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 770093438ed5..bb2c847f8aff 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -93,6 +93,26 @@ - local: hybrid_inference/api_reference title: API Reference title: Hybrid Inference +- sections: + - local: modular_diffusers/overview + title: Overview + - local: modular_diffusers/modular_pipeline + title: Modular Pipeline + - local: modular_diffusers/components_manager + title: Components Manager + - local: modular_diffusers/modular_diffusers_states + title: Modular Diffusers States + - local: modular_diffusers/pipeline_block + title: Pipeline Block + - local: modular_diffusers/sequential_pipeline_blocks + title: Sequential Pipeline Blocks + - local: modular_diffusers/loop_sequential_pipeline_blocks + title: Loop Sequential Pipeline Blocks + - local: modular_diffusers/auto_pipeline_blocks + title: Auto Pipeline Blocks + - local: modular_diffusers/end_to_end_guide + title: End-to-End Example + title: Modular Diffusers - sections: - local: using-diffusers/consisid title: ConsisID diff --git a/docs/source/en/modular_diffusers/auto_pipeline_blocks.md b/docs/source/en/modular_diffusers/auto_pipeline_blocks.md new file mode 100644 index 000000000000..50c3250512d1 --- /dev/null +++ b/docs/source/en/modular_diffusers/auto_pipeline_blocks.md @@ -0,0 +1,316 @@ + + +# AutoPipelineBlocks + + + +🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes. + + + +`AutoPipelineBlocks` is a subclass of `ModularPipelineBlocks`. It is a multi-block that automatically selects which sub-blocks to run based on the inputs provided at runtime, creating conditional workflows that adapt to different scenarios. The main purpose is convenience and portability - for developers, you can package everything into one workflow, making it easier to share and use. + +In this tutorial, we will show you how to create an `AutoPipelineBlocks` and learn more about how the conditional selection works. + + + +Other types of multi-blocks include [SequentialPipelineBlocks](sequential_pipeline_blocks.md) (for linear workflows) and [LoopSequentialPipelineBlocks](loop_sequential_pipeline_blocks.md) (for iterative workflows). For information on creating individual blocks, see the [PipelineBlock guide](pipeline_block.md). + +Additionally, like all `ModularPipelineBlocks`, `AutoPipelineBlocks` are definitions/specifications, not runnable pipelines. You need to convert them into a `ModularPipeline` to actually execute them. For information on creating and running pipelines, see the [Modular Pipeline guide](modular_pipeline.md). + + + +For example, you might want to support text-to-image and image-to-image tasks. Instead of creating two separate pipelines, you can create an `AutoPipelineBlocks` that automatically chooses the workflow based on whether an `image` input is provided. + +Let's see an example. 
We'll use the helper function from the [PipelineBlock guide](./pipeline_block.md) to create our blocks: + +**Helper Function** + +```py +from diffusers.modular_pipelines import PipelineBlock, InputParam, OutputParam +import torch + +def make_block(inputs=[], intermediate_inputs=[], intermediate_outputs=[], block_fn=None, description=None): + class TestBlock(PipelineBlock): + model_name = "test" + + @property + def inputs(self): + return inputs + + @property + def intermediate_inputs(self): + return intermediate_inputs + + @property + def intermediate_outputs(self): + return intermediate_outputs + + @property + def description(self): + return description if description is not None else "" + + def __call__(self, components, state): + block_state = self.get_block_state(state) + if block_fn is not None: + block_state = block_fn(block_state, state) + self.set_block_state(state, block_state) + return components, state + + return TestBlock +``` + +Now let's create a dummy `AutoPipelineBlocks` that includes dummy text-to-image, image-to-image, and inpaint pipelines. + + +```py +from diffusers.modular_pipelines import AutoPipelineBlocks + +# These are dummy blocks and we only focus on "inputs" for our purpose +inputs = [InputParam(name="prompt")] +# block_fn prints out which workflow is running so we can see the execution order at runtime +block_fn = lambda x, y: print("running the text-to-image workflow") +block_t2i_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a text-to-image workflow!") + +inputs = [InputParam(name="prompt"), InputParam(name="image")] +block_fn = lambda x, y: print("running the image-to-image workflow") +block_i2i_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a image-to-image workflow!") + +inputs = [InputParam(name="prompt"), InputParam(name="image"), InputParam(name="mask")] +block_fn = lambda x, y: print("running the inpaint workflow") +block_inpaint_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a inpaint workflow!") + +class AutoImageBlocks(AutoPipelineBlocks): + # List of sub-block classes to choose from + block_classes = [block_inpaint_cls, block_i2i_cls, block_t2i_cls] + # Names for each block in the same order + block_names = ["inpaint", "img2img", "text2img"] + # Trigger inputs that determine which block to run + # - "mask" triggers inpaint workflow + # - "image" triggers img2img workflow (but only if mask is not provided) + # - if none of above, runs the text2img workflow (default) + block_trigger_inputs = ["mask", "image", None] + # Description is extremely important for AutoPipelineBlocks + @property + def description(self): + return ( + "Pipeline generates images given different types of conditions!\n" + + "This is an auto pipeline block that works for text2img, img2img and inpainting tasks.\n" + + " - inpaint workflow is run when `mask` is provided.\n" + + " - img2img workflow is run when `image` is provided (but only when `mask` is not provided).\n" + + " - text2img workflow is run when neither `image` nor `mask` is provided.\n" + ) + +# Create the blocks +auto_blocks = AutoImageBlocks() +# convert to pipeline +auto_pipeline = auto_blocks.init_pipeline() +``` + +Now we have created an `AutoPipelineBlocks` that contains 3 sub-blocks. Notice the warning message at the top - this automatically appears in every `ModularPipelineBlocks` that contains `AutoPipelineBlocks` to remind end users that dynamic block selection happens at runtime. 
+ +```py +AutoImageBlocks( + Class: AutoPipelineBlocks + + ==================================================================================================== + This pipeline contains blocks that are selected at runtime based on inputs. + Trigger Inputs: ['mask', 'image'] + ==================================================================================================== + + + Description: Pipeline generates images given different types of conditions! + This is an auto pipeline block that works for text2img, img2img and inpainting tasks. + - inpaint workflow is run when `mask` is provided. + - img2img workflow is run when `image` is provided (but only when `mask` is not provided). + - text2img workflow is run when neither `image` nor `mask` is provided. + + + + Sub-Blocks: + • inpaint [trigger: mask] (TestBlock) + Description: I'm a inpaint workflow! + + • img2img [trigger: image] (TestBlock) + Description: I'm a image-to-image workflow! + + • text2img [default] (TestBlock) + Description: I'm a text-to-image workflow! + +) +``` + +Check out the documentation with `print(auto_pipeline.doc)`: + +```py +>>> print(auto_pipeline.doc) +class AutoImageBlocks + + Pipeline generates images given different types of conditions! + This is an auto pipeline block that works for text2img, img2img and inpainting tasks. + - inpaint workflow is run when `mask` is provided. + - img2img workflow is run when `image` is provided (but only when `mask` is not provided). + - text2img workflow is run when neither `image` nor `mask` is provided. + + Inputs: + + prompt (`None`, *optional*): + + image (`None`, *optional*): + + mask (`None`, *optional*): +``` + +There is a fundamental trade-off of AutoPipelineBlocks: it trades clarity for convenience. While it is really easy for packaging multiple workflows, it can become confusing without proper documentation. e.g. if we just throw a pipeline at you and tell you that it contains 3 sub-blocks and takes 3 inputs `prompt`, `image` and `mask`, and ask you to run an image-to-image workflow: if you don't have any prior knowledge on how these pipelines work, you would be pretty clueless, right? + +This pipeline we just made though, has a docstring that shows all available inputs and workflows and explains how to use each with different inputs. So it's really helpful for users. For example, it's clear that you need to pass `image` to run img2img. This is why the description field is absolutely critical for AutoPipelineBlocks. We highly recommend you to explain the conditional logic very well for each `AutoPipelineBlocks` you would make. We also recommend to always test individual pipelines first before packaging them into AutoPipelineBlocks. + +Let's run this auto pipeline with different inputs to see if the conditional logic works as described. Remember that we have added `print` in each `PipelineBlock`'s `__call__` method to print out its workflow name, so it should be easy to tell which one is running: + +```py +>>> _ = auto_pipeline(image="image", mask="mask") +running the inpaint workflow +>>> _ = auto_pipeline(image="image") +running the image-to-image workflow +>>> _ = auto_pipeline(prompt="prompt") +running the text-to-image workflow +>>> _ = auto_pipeline(image="prompt", mask="mask") +running the inpaint workflow +``` + +However, even with documentation, it can become very confusing when AutoPipelineBlocks are combined with other blocks. The complexity grows quickly when you have nested AutoPipelineBlocks or use them as sub-blocks in larger pipelines. 
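
Under the hood, the runtime selection is a simple first-match rule over `block_trigger_inputs`: the sub-blocks are checked in the order they are listed, the first block whose trigger input was actually provided is run, and a `None` trigger marks the default block. The snippet below is a simplified sketch of that rule for illustration only - it is not the actual diffusers implementation:

```py
# Simplified sketch of how an AutoPipelineBlocks picks a sub-block at runtime
# (illustration only, not the actual diffusers implementation)
def select_block(block_names, block_trigger_inputs, user_inputs):
    for name, trigger in zip(block_names, block_trigger_inputs):
        if trigger is None:
            return name  # `None` marks the default block
        if user_inputs.get(trigger) is not None:
            return name  # first trigger input that was provided wins
    return None  # no default block -> the whole AutoPipelineBlocks is skipped

names = ["inpaint", "img2img", "text2img"]
triggers = ["mask", "image", None]

print(select_block(names, triggers, {"image": "image", "mask": "mask"}))  # inpaint
print(select_block(names, triggers, {"image": "image"}))                  # img2img
print(select_block(names, triggers, {"prompt": "prompt"}))                # text2img
```
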
+ +Let's make another `AutoPipelineBlocks` - this one only contains one block, and it does not include `None` in its `block_trigger_inputs` (which corresponds to the default block to run when none of the trigger inputs are provided). This means this block will be skipped if the trigger input (`ip_adapter_image`) is not provided at runtime. + +```py +from diffusers.modular_pipelines import SequentialPipelineBlocks, InsertableDict +inputs = [InputParam(name="ip_adapter_image")] +block_fn = lambda x, y: print("running the ip-adapter workflow") +block_ipa_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a IP-adapter workflow!") + +class AutoIPAdapter(AutoPipelineBlocks): + block_classes = [block_ipa_cls] + block_names = ["ip-adapter"] + block_trigger_inputs = ["ip_adapter_image"] + @property + def description(self): + return "Run IP Adapter step if `ip_adapter_image` is provided." +``` + +Now let's combine these 2 auto blocks together into a `SequentialPipelineBlocks`: + +```py +auto_ipa_blocks = AutoIPAdapter() +blocks_dict = InsertableDict() +blocks_dict["ip-adapter"] = auto_ipa_blocks +blocks_dict["image-generation"] = auto_blocks +all_blocks = SequentialPipelineBlocks.from_blocks_dict(blocks_dict) +pipeline = all_blocks.init_pipeline() +``` + +Let's take a look: now things get more confusing. In this particular example, you could still try to explain the conditional logic in the `description` field here - there are only 4 possible execution paths so it's doable. However, since this is a `SequentialPipelineBlocks` that could contain many more blocks, the complexity can quickly get out of hand as the number of blocks increases. + +```py +>>> all_blocks +SequentialPipelineBlocks( + Class: ModularPipelineBlocks + + ==================================================================================================== + This pipeline contains blocks that are selected at runtime based on inputs. + Trigger Inputs: ['image', 'mask', 'ip_adapter_image'] + Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('image')`). + ==================================================================================================== + + + Description: + + + Sub-Blocks: + [0] ip-adapter (AutoIPAdapter) + Description: Run IP Adapter step if `ip_adapter_image` is provided. + + + [1] image-generation (AutoImageBlocks) + Description: Pipeline generates images given different types of conditions! + This is an auto pipeline block that works for text2img, img2img and inpainting tasks. + - inpaint workflow is run when `mask` is provided. + - img2img workflow is run when `image` is provided (but only when `mask` is not provided). + - text2img workflow is run when neither `image` nor `mask` is provided. + + +) + +``` + +This is when the `get_execution_blocks()` method comes in handy - it basically extracts a `SequentialPipelineBlocks` that only contains the blocks that are actually run based on your inputs. + +Let's try some examples: + +`mask`: we expect it to skip the first ip-adapter since `ip_adapter_image` is not provided, and then run the inpaint for the second block. + +```py +>>> all_blocks.get_execution_blocks('mask') +SequentialPipelineBlocks( + Class: ModularPipelineBlocks + + Description: + + + Sub-Blocks: + [0] image-generation (TestBlock) + Description: I'm a inpaint workflow! 
+ +) +``` + +Let's also actually run the pipeline to confirm: + +```py +>>> _ = pipeline(mask="mask") +skipping auto block: AutoIPAdapter +running the inpaint workflow +``` + +Try a few more: + +```py +print(f"inputs: ip_adapter_image:") +blocks_select = all_blocks.get_execution_blocks('ip_adapter_image') +print(f"expected_execution_blocks: {blocks_select}") +print(f"actual execution blocks:") +_ = pipeline(ip_adapter_image="ip_adapter_image", prompt="prompt") +# expect to see ip-adapter + text2img + +print(f"inputs: image:") +blocks_select = all_blocks.get_execution_blocks('image') +print(f"expected_execution_blocks: {blocks_select}") +print(f"actual execution blocks:") +_ = pipeline(image="image", prompt="prompt") +# expect to see img2img + +print(f"inputs: prompt:") +blocks_select = all_blocks.get_execution_blocks('prompt') +print(f"expected_execution_blocks: {blocks_select}") +print(f"actual execution blocks:") +_ = pipeline(prompt="prompt") +# expect to see text2img (prompt is not a trigger input so fallback to default) + +print(f"inputs: mask + ip_adapter_image:") +blocks_select = all_blocks.get_execution_blocks('mask','ip_adapter_image') +print(f"expected_execution_blocks: {blocks_select}") +print(f"actual execution blocks:") +_ = pipeline(mask="mask", ip_adapter_image="ip_adapter_image") +# expect to see ip-adapter + inpaint +``` + +In summary, `AutoPipelineBlocks` is a good tool for packaging multiple workflows into a single, convenient interface and it can greatly simplify the user experience. However, always provide clear descriptions explaining the conditional logic, test individual pipelines first before combining them, and use `get_execution_blocks()` to understand runtime behavior in complex compositions. \ No newline at end of file diff --git a/docs/source/en/modular_diffusers/components_manager.md b/docs/source/en/modular_diffusers/components_manager.md new file mode 100644 index 000000000000..15b6c66b9b06 --- /dev/null +++ b/docs/source/en/modular_diffusers/components_manager.md @@ -0,0 +1,514 @@ + + +# Components Manager + + + +🧪 **Experimental Feature**: This is an experimental feature we are actively developing. The API may be subject to breaking changes. + + + +The Components Manager is a central model registry and management system in diffusers. It lets you add models then reuse them across multiple pipelines and workflows. It tracks all models in one place with useful metadata such as model size, device placement and loaded adapters (LoRA, IP-Adapter). It has mechanisms in place to prevent duplicate model instances, enables memory-efficient sharing. Most significantly, it offers offloading that works across pipelines — unlike regular DiffusionPipeline offloading (i.e. `enable_model_cpu_offload` and `enable_sequential_cpu_offload`) which is limited to one pipeline with predefined sequences, the Components Manager automatically manages your device memory across all your models and workflows. + + +## Basic Operations + +Let's start with the most basic operations. First, create a Components Manager: + +```py +from diffusers import ComponentsManager +comp = ComponentsManager() +``` + +Use the `add(name, component)` method to register a component. 
It returns a unique ID that combines the component name with the object's unique identifier (using Python's `id()` function): + +```py +from diffusers import AutoModel +text_encoder = AutoModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder") +# Returns component_id like 'text_encoder_139917733042864' +component_id = comp.add("text_encoder", text_encoder) +``` + +You can view all registered components and their metadata: + +```py +>>> comp +Components: +=============================================================================================================================================== +Models: +----------------------------------------------------------------------------------------------------------------------------------------------- +Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection +----------------------------------------------------------------------------------------------------------------------------------------------- +text_encoder_139917733042864 | CLIPTextModel | cpu | torch.float32 | 0.46 | N/A | N/A +----------------------------------------------------------------------------------------------------------------------------------------------- + +Additional Component Info: +================================================== +``` + +And remove components using their unique ID: + +```py +comp.remove("text_encoder_139917733042864") +``` + +## Duplicate Detection + +The Components Manager automatically detects and prevents duplicate model instances to save memory and avoid confusion. Let's walk through how this works in practice. + +When you try to add the same object twice, the manager will warn you and return the existing ID: + +```py +>>> comp.add("text_encoder", text_encoder) +'text_encoder_139917733042864' +>>> comp.add("text_encoder", text_encoder) +ComponentsManager: component 'text_encoder' already exists as 'text_encoder_139917733042864' +'text_encoder_139917733042864' +``` + +Even if you add the same object under a different name, it will still be detected as a duplicate: + +```py +>>> comp.add("clip", text_encoder) +ComponentsManager: adding component 'clip' as 'clip_139917733042864', but it is duplicate of 'text_encoder_139917733042864' +To remove a duplicate, call `components_manager.remove('')`. +'clip_139917733042864' +``` + +However, there's a more subtle case where duplicate detection becomes tricky. When you load the same model into different objects, the manager can't detect duplicates unless you use `ComponentSpec`. 
For example: + +```py +>>> text_encoder_2 = AutoModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder") +>>> comp.add("text_encoder", text_encoder_2) +'text_encoder_139917732983664' +``` + +This creates a problem - you now have two copies of the same model consuming double the memory: + +```py +>>> comp +Components: +=============================================================================================================================================== +Models: +----------------------------------------------------------------------------------------------------------------------------------------------- +Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection +----------------------------------------------------------------------------------------------------------------------------------------------- +text_encoder_139917733042864 | CLIPTextModel | cpu | torch.float32 | 0.46 | N/A | N/A +clip_139917733042864 | CLIPTextModel | cpu | torch.float32 | 0.46 | N/A | N/A +text_encoder_139917732983664 | CLIPTextModel | cpu | torch.float32 | 0.46 | N/A | N/A +----------------------------------------------------------------------------------------------------------------------------------------------- + +Additional Component Info: +================================================== +``` + +We recommend using `ComponentSpec` to load your models. Models loaded with `ComponentSpec` get tagged with a unique ID that encodes their loading parameters, allowing the Components Manager to detect when different objects represent the same underlying checkpoint: + +```py +from diffusers import ComponentSpec, ComponentsManager +from transformers import CLIPTextModel +comp = ComponentsManager() + +# Create ComponentSpec for the first text encoder +spec = ComponentSpec(name="text_encoder", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", type_hint=AutoModel) +# Create ComponentSpec for a duplicate text encoder (it is same checkpoint, from same repo/subfolder) +spec_duplicated = ComponentSpec(name="text_encoder_duplicated", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", type_hint=CLIPTextModel) + +# Load and add both components - the manager will detect they're the same model +comp.add("text_encoder", spec.load()) +comp.add("text_encoder_duplicated", spec_duplicated.load()) +``` + +Now the manager detects the duplicate and warns you: + +```out +ComponentsManager: adding component 'text_encoder_duplicated_139917580682672', but it has duplicate load_id 'stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null' with existing components: text_encoder_139918506246832. To remove a duplicate, call `components_manager.remove('')`. 
+'text_encoder_duplicated_139917580682672' +``` + +Both models now show the same `load_id`, making it clear they're the same model: + +```py +>>> comp +Components: +====================================================================================================================================================================================================== +Models: +------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ +Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection +------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ +text_encoder_139918506246832 | CLIPTextModel | cpu | torch.float32 | 0.46 | stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null | N/A +text_encoder_duplicated_139917580682672 | CLIPTextModel | cpu | torch.float32 | 0.46 | stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null | N/A +------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ + +Additional Component Info: +================================================== +``` + +## Collections + +Collections are labels you can assign to components for better organization and management. You add a component under a collection by passing the `collection=` parameter when you add the component to the manager, i.e. `add(name, component, collection=...)`. Within each collection, only one component per name is allowed - if you add a second component with the same name, the first one is automatically removed. 
+ +Here's how collections work in practice: + +```py +comp = ComponentsManager() +# Create ComponentSpec for the first UNet (SDXL base) +spec = ComponentSpec(name="unet", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", type_hint=AutoModel) +# Create ComponentSpec for a different UNet (Juggernaut-XL) +spec2 = ComponentSpec(name="unet", repo="RunDiffusion/Juggernaut-XL-v9", subfolder="unet", type_hint=AutoModel, variant="fp16") + +# Add both UNets to the same collection - the second one will replace the first +comp.add("unet", spec.load(), collection="sdxl") +comp.add("unet", spec2.load(), collection="sdxl") +``` + +The manager automatically removes the old UNet and adds the new one: + +```out +ComponentsManager: removing existing unet from collection 'sdxl': unet_139917723891888 +'unet_139917723893136' +``` + +Only one UNet remains in the collection: + +```py +>>> comp +Components: +==================================================================================================================================================================== +Models: +-------------------------------------------------------------------------------------------------------------------------------------------------------------------- +Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection +-------------------------------------------------------------------------------------------------------------------------------------------------------------------- +unet_139917723893136 | UNet2DConditionModel | cpu | torch.float32 | 9.56 | RunDiffusion/Juggernaut-XL-v9|unet|fp16|null | sdxl +-------------------------------------------------------------------------------------------------------------------------------------------------------------------- + +Additional Component Info: +================================================== +``` + +For example, in node-based systems, you can mark all models loaded from one node with the same collection label, automatically replace models when user loads new checkpoints under same name, batch delete all models in a collection when a node is removed. + +## Retrieving Components + +The Components Manager provides several methods to retrieve registered components. + +The `get_one()` method returns a single component and supports pattern matching for the `name` parameter. You can use: +- exact matches like `comp.get_one(name="unet")` +- wildcards like `comp.get_one(name="unet*")` for components starting with "unet" +- exclusion patterns like `comp.get_one(name="!unet")` to exclude components named "unet" +- OR patterns like `comp.get_one(name="unet|vae")` to match either "unet" OR "vae". + +Optionally, You can add collection and load_id as filters e.g. `comp.get_one(name="unet", collection="sdxl")`. If multiple components match, `get_one()` throws an error. + +Another useful method is `get_components_by_names()`, which takes a list of names and returns a dictionary mapping names to components. This is particularly helpful with modular pipelines since they provide lists of required component names, and the returned dictionary can be directly passed to `pipeline.update_components()`. + +```py +# Get components by name list +component_dict = comp.get_components_by_names(names=["text_encoder", "unet", "vae"]) +# Returns: {"text_encoder": component1, "unet": component2, "vae": component3} +``` + +## Using Components Manager with Modular Pipelines + +The Components Manager integrates seamlessly with Modular Pipelines. 
All you need to do is pass a Components Manager instance to `from_pretrained()` or `init_pipeline()` with an optional `collection` parameter: + +```py +from diffusers import ModularPipeline, ComponentsManager +comp = ComponentsManager() +pipe = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test1") +``` + +By default, modular pipelines don't load components immediately, so both the pipeline and Components Manager start empty: + +```py +>>> comp +Components: +================================================== +No components registered. +================================================== +``` + +When you load components on the pipeline, they are automatically registered in the Components Manager: + +```py +>>> pipe.load_components(names="unet") +>>> comp +Components: +============================================================================================================================================================== +Models: +-------------------------------------------------------------------------------------------------------------------------------------------------------------- +Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection +-------------------------------------------------------------------------------------------------------------------------------------------------------------- +unet_139917726686304 | UNet2DConditionModel | cpu | torch.float32 | 9.56 | SG161222/RealVisXL_V4.0|unet|null|null | test1 +-------------------------------------------------------------------------------------------------------------------------------------------------------------- + +Additional Component Info: +================================================== +``` + +Now let's load all default components and then create a second pipeline that reuses all components from the first one. We pass the same Components Manager to the second pipeline but with a different collection: + +```py +# Load all default components +>>> pipe.load_default_components() + +# Create a second pipeline using the same Components Manager but with a different collection +>>> pipe2 = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test2") +``` + +As mentioned earlier, `ModularPipeline` has a property `null_component_names` that returns a list of component names it needs to load. 
We can conveniently use this list with the `get_components_by_names` method on the Components Manager: + +```py +# Get the list of components that pipe2 needs to load +>>> pipe2.null_component_names +['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'image_encoder', 'unet', 'vae', 'scheduler', 'controlnet'] + +# Retrieve all required components from the Components Manager +>>> comp_dict = comp.get_components_by_names(names=pipe2.null_component_names) + +# Update the pipeline with the retrieved components +>>> pipe2.update_components(**comp_dict) +``` + +The warnings that follow are expected and indicate that the Components Manager is correctly identifying that these components already exist and will be reused rather than creating duplicates: + +```out +ComponentsManager: component 'text_encoder' already exists as 'text_encoder_139917586016400' +ComponentsManager: component 'text_encoder_2' already exists as 'text_encoder_2_139917699973424' +ComponentsManager: component 'tokenizer' already exists as 'tokenizer_139917580599504' +ComponentsManager: component 'tokenizer_2' already exists as 'tokenizer_2_139915763443904' +ComponentsManager: component 'image_encoder' already exists as 'image_encoder_139917722468304' +ComponentsManager: component 'unet' already exists as 'unet_139917580609632' +ComponentsManager: component 'vae' already exists as 'vae_139917722459040' +ComponentsManager: component 'scheduler' already exists as 'scheduler_139916266559408' +ComponentsManager: component 'controlnet' already exists as 'controlnet_139917722454432' +``` + + +The pipeline is now fully loaded: + +```py +# null_component_names return empty list, meaning everything are loaded +>>> pipe2.null_component_names +[] +``` + +No new components were added to the Components Manager - we're reusing everything. 
All models are now associated with both `test1` and `test2` collections, showing that these components are shared across multiple pipelines: +```py +>>> comp +Components: +======================================================================================================================================================================================== +Models: +---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- +Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection +---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- +text_encoder_139917586016400 | CLIPTextModel | cpu | torch.float32 | 0.46 | SG161222/RealVisXL_V4.0|text_encoder|null|null | test1 + | | | | | | test2 +text_encoder_2_139917699973424 | CLIPTextModelWithProjection | cpu | torch.float32 | 2.59 | SG161222/RealVisXL_V4.0|text_encoder_2|null|null | test1 + | | | | | | test2 +unet_139917580609632 | UNet2DConditionModel | cpu | torch.float32 | 9.56 | SG161222/RealVisXL_V4.0|unet|null|null | test1 + | | | | | | test2 +controlnet_139917722454432 | ControlNetModel | cpu | torch.float32 | 4.66 | diffusers/controlnet-canny-sdxl-1.0|null|null|null | test1 + | | | | | | test2 +vae_139917722459040 | AutoencoderKL | cpu | torch.float32 | 0.31 | SG161222/RealVisXL_V4.0|vae|null|null | test1 + | | | | | | test2 +image_encoder_139917722468304 | CLIPVisionModelWithProjection | cpu | torch.float32 | 6.87 | h94/IP-Adapter|sdxl_models/image_encoder|null|null | test1 + | | | | | | test2 +---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + +Other Components: +---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- +ID | Class | Collection +---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- +tokenizer_139917580599504 | CLIPTokenizer | test1 + | | test2 +scheduler_139916266559408 | EulerDiscreteScheduler | test1 + | | test2 +tokenizer_2_139915763443904 | CLIPTokenizer | test1 + | | test2 +---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + +Additional Component Info: +================================================== +``` + + +## Automatic Memory Management + +The Components Manager provides a global offloading strategy across all models, regardless of which pipeline is using them: + +```py +comp.enable_auto_cpu_offload(device="cuda") +``` + +When enabled, all models start on CPU. The manager moves models to the device right before they're used and moves other models back to CPU when GPU memory runs low. You can set your own rules for which models to offload first. This works smoothly as you add or remove components. Once it's on, you don't need to worry about device placement - you can focus on your workflow. 
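
Here is a minimal sketch of that workflow, using a checkpoint from earlier in this guide as an example:

```py
import torch
from diffusers import AutoModel, ComponentsManager

comp = ComponentsManager()
comp.add("unet", AutoModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16))
comp.add("vae", AutoModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae", torch_dtype=torch.float16))

# Registered models stay on CPU until they are actually needed; the manager moves
# them to "cuda" right before use and moves other models back when memory runs low.
comp.enable_auto_cpu_offload(device="cuda")
```
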
+ + + +## Practical Example: Building Modular Workflows with Component Reuse + +Now that we've covered the basics of the Components Manager, let's walk through a practical example that shows how to build workflows in a modular setting and use the Components Manager to reuse components across multiple pipelines. This example demonstrates the true power of Modular Diffusers by working with multiple pipelines that can share components. + +In this example, we'll generate latents from a text-to-image pipeline, then refine them with an image-to-image pipeline. + +Let's create a modular text-to-image workflow by separating it into three workflows: `text_blocks` for encoding prompts, `t2i_blocks` for generating latents, and `decoder_blocks` for creating final images. + +```py +import torch +from diffusers.modular_pipelines import SequentialPipelineBlocks +from diffusers.modular_pipelines.stable_diffusion_xl import ALL_BLOCKS + +# Create modular blocks and separate text encoding and decoding steps +t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS["text2img"]) +text_blocks = t2i_blocks.sub_blocks.pop("text_encoder") +decoder_blocks = t2i_blocks.sub_blocks.pop("decode") +``` + +Now we will convert them into runnalbe pipelines and set up the Components Manager with auto offloading and organize components under a "t2i" collection + +Since we now have 3 different workflows that share components, we create a separate pipeline that serves as a dedicated loader to load all the components, register them to the component manager, and then reuse them across different workflows. + +```py +from diffusers import ComponentsManager, ModularPipeline + +# Set up Components Manager with auto offloading +components = ComponentsManager() +components.enable_auto_cpu_offload(device="cuda") + +# Create a new pipeline to load the components +t2i_repo = "YiYiXu/modular-demo-auto" +t2i_loader_pipe = ModularPipeline.from_pretrained(t2i_repo, components_manager=components, collection="t2i") + +# convert the 3 blocks into pipelines and attach the same components manager to all 3 +text_node = text_blocks.init_pipeline(t2i_repo, components_manager=components) +decoder_node = decoder_blocks.init_pipeline(t2i_repo, components_manager=components) +t2i_pipe = t2i_blocks.init_pipeline(t2i_repo, components_manager=components) +``` + +Load all components into the loader pipeline, they should all be automatically registered to Components Manager under the "t2i" collection: + +```py +# Load all components (including IP-Adapter and ControlNet for later use) +t2i_loader_pipe.load_default_components(torch_dtype=torch.float16) +``` + +Now distribute the loaded components to each pipeline: + +```py +# Get VAE for decoder (using get_one since there's only one) +vae = components.get_one(load_id="SG161222/RealVisXL_V4.0|vae|null|null") +decoder_node.update_components(vae=vae) + +# Get text components for text node (using get_components_by_names for multiple components) +text_components = components.get_components_by_names(text_node.null_component_names) +text_node.update_components(**text_components) + +# Get remaining components for t2i pipeline +t2i_components = components.get_components_by_names(t2i_pipe.null_component_names) +t2i_pipe.update_components(**t2i_components) +``` + +Now we can generate images using our modular workflow: + +```py +# Generate text embeddings +prompt = "an astronaut" +text_embeddings = text_node(prompt=prompt, output=["prompt_embeds","negative_prompt_embeds", "pooled_prompt_embeds", 
"negative_pooled_prompt_embeds"]) + +# Generate latents and decode to image +generator = torch.Generator(device="cuda").manual_seed(0) +latents_t2i = t2i_pipe(**text_embeddings, num_inference_steps=25, generator=generator, output="latents") +image = decoder_node(latents=latents_t2i, output="images")[0] +image.save("modular_part2_t2i.png") +``` + +Let's add a LoRA: + +```py +# Load LoRA weights +>>> t2i_loader_pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy_face") +>>> components +Components: +============================================================================================================================================================ +... +Additional Component Info: +================================================== + +unet: + Adapters: ['toy_face'] +``` + +You can see that the Components Manager tracks adapters metadata for all models it manages, and in our case, only Unet has lora loaded. This means we can reuse existing text embeddings. + +```py +# Generate with LoRA (reusing existing text embeddings) +generator = torch.Generator(device="cuda").manual_seed(0) +latents_lora = t2i_pipe(**text_embeddings, num_inference_steps=25, generator=generator, output="latents") +image = decoder_node(latents=latents_lora, output="images")[0] +image.save("modular_part2_lora.png") +``` + + +Now let's create a refiner pipeline that reuses components from our text-to-image workflow: + +```py +# Create refiner blocks (removing image_encoder and decode since we work with latents) +refiner_blocks = SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS["img2img"]) +refiner_blocks.sub_blocks.pop("image_encoder") +refiner_blocks.sub_blocks.pop("decode") + +# Create refiner pipeline with different repo and collection, +# Attach the same component manager to it +refiner_repo = "YiYiXu/modular_refiner" +refiner_pipe = refiner_blocks.init_pipeline(refiner_repo, components_manager=components, collection="refiner") +``` + +We pass the **same Components Manager** (`components`) to the refiner pipeline, but with a **different collection** (`"refiner"`). This allows the refiner to access and reuse components from the "t2i" collection while organizing its own components (like the refiner UNet) under the "refiner" collection. + +```py +# Load only the refiner UNet (different from t2i UNet) +refiner_pipe.load_components(names="unet", torch_dtype=torch.float16) + +# Reuse components from t2i pipeline using pattern matching +reuse_components = components.search_components("text_encoder_2|scheduler|vae|tokenizer_2") +refiner_pipe.update_components(**reuse_components) +``` + +When we reuse components from the "t2i" collection, they automatically get added to the "refiner" collection as well. You can verify this by checking the Components Manager - you'll see components like `vae`, `scheduler`, etc. listed under both collections, indicating they're shared between workflows. 
+ +Now we can refine any of our generated latents: + +```py +# Refine all our different latents +refined_latents = refiner_pipe(image_latents=latents_t2i, prompt=prompt, num_inference_steps=10, output="latents") +refined_image = decoder_node(latents=refined_latents, output="images")[0] +refined_image.save("modular_part2_t2i_refine_out.png") + +refined_latents = refiner_pipe(image_latents=latents_lora, prompt=prompt, num_inference_steps=10, output="latents") +refined_image = decoder_node(latents=refined_latents, output="images")[0] +refined_image.save("modular_part2_lora_refine_out.png") +``` + + +Here are the results from our modular pipeline examples. + +#### Base Text-to-Image Generation +| Base Text-to-Image | Base Text-to-Image (Refined) | +|-------------------|------------------------------| +| ![Base T2I](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_t2i.png) | ![Base T2I Refined](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_t2i_refine_out.png) | + +#### LoRA +| LoRA | LoRA (Refined) | +|-------------------|------------------------------| +| ![LoRA](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_lora.png) | ![LoRA Refined](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/modular_part2_lora_refine_out.png) | + diff --git a/docs/source/en/modular_diffusers/end_to_end_guide.md b/docs/source/en/modular_diffusers/end_to_end_guide.md new file mode 100644 index 000000000000..cb7b87552a37 --- /dev/null +++ b/docs/source/en/modular_diffusers/end_to_end_guide.md @@ -0,0 +1,648 @@ + + +# End-to-End Developer Guide: Building with Modular Diffusers + + + +🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes. + + + + +In this tutorial we will walk through the process of adding a new pipeline to the modular framework using differential diffusion as our example. We'll cover the complete workflow from implementation to deployment: implementing the new pipeline, ensuring compatibility with existing tools, sharing the code on Hugging Face Hub, and deploying it as a UI node. + +We'll also demonstrate the 4-step framework process we use for implementing new basic pipelines in the modular system. + +1. **Start with an existing pipeline as a base** + - Identify which existing pipeline is most similar to the one you want to implement + - Determine what part of the pipeline needs modification + +2. **Build a working pipeline structure first** + - Assemble the complete pipeline structure + - Use existing blocks wherever possible + - For new blocks, create placeholders (e.g. you can copy from similar blocks and change the name) without implementing custom logic just yet + +3. **Set up an example** + - Create a simple inference script with expected inputs/outputs + +4. **Implement your custom logic and test incrementally** + - Add the custom logics the blocks you want to change + - Test incrementally, and inspect pipeline states and debug as needed + +Let's see how this works with the Differential Diffusion example. 
+ + +## Differential Diffusion Pipeline + +### Start with an existing pipeline + +Differential diffusion (https://differential-diffusion.github.io/) is an image-to-image workflow, so it makes sense for us to start with the preset of pipeline blocks used to build img2img pipeline (`IMAGE2IMAGE_BLOCKS`) and see how we can build this new pipeline with them. + +```py +>>> from diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS +>>> IMAGE2IMAGE_BLOCKS = InsertableDict([ +... ("text_encoder", StableDiffusionXLTextEncoderStep), +... ("image_encoder", StableDiffusionXLVaeEncoderStep), +... ("input", StableDiffusionXLInputStep), +... ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), +... ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep), +... ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), +... ("denoise", StableDiffusionXLDenoiseStep), +... ("decode", StableDiffusionXLDecodeStep) +... ]) +``` + +Note that "denoise" (`StableDiffusionXLDenoiseStep`) is a `LoopSequentialPipelineBlocks` that contains 3 loop blocks (more on LoopSequentialPipelineBlocks [here](https://huggingface.co/docs/diffusers/modular_diffusers/write_own_pipeline_block#loopsequentialpipelineblocks)) + +```py +>>> denoise_blocks = IMAGE2IMAGE_BLOCKS["denoise"]() +>>> print(denoise_blocks) +``` + +```out +StableDiffusionXLDenoiseStep( + Class: StableDiffusionXLDenoiseLoopWrapper + + Description: Denoise step that iteratively denoise the latents. + Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method + At each iteration, it runs blocks defined in `sub_blocks` sequencially: + - `StableDiffusionXLLoopBeforeDenoiser` + - `StableDiffusionXLLoopDenoiser` + - `StableDiffusionXLLoopAfterDenoiser` + This block supports both text2img and img2img tasks. + + + Components: + scheduler (`EulerDiscreteScheduler`) + guider (`ClassifierFreeGuidance`) + unet (`UNet2DConditionModel`) + + Sub-Blocks: + [0] before_denoiser (StableDiffusionXLLoopBeforeDenoiser) + Description: step within the denoising loop that prepare the latent input for the denoiser. This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`) + + [1] denoiser (StableDiffusionXLLoopDenoiser) + Description: Step within the denoising loop that denoise the latents with guidance. This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`) + + [2] after_denoiser (StableDiffusionXLLoopAfterDenoiser) + Description: step within the denoising loop that update the latents. This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`) + +) +``` + +Let's compare standard image-to-image and differential diffusion! The key difference in algorithm is that standard image-to-image diffusion applies uniform noise across all pixels based on a single `strength` parameter, but differential diffusion uses a change map where each pixel value determines when that region starts denoising. Regions with lower values get "frozen" earlier by replacing them with noised original latents, preserving more of the original image. + +Therefore, the key differences when it comes to pipeline implementation would be: +1. The `prepare_latents` step (which prepares the change map and pre-computes noised latents for all timesteps) +2. 
The `denoise` step (which selectively applies denoising based on the change map)
3. Since differential diffusion doesn't use the `strength` parameter, we'll use the text-to-image `set_timesteps` step instead of the image-to-image version

To implement differential diffusion, we can reuse most blocks from the image-to-image and text-to-image workflows, only modifying the `prepare_latents` step and the first part of the `denoise` step (i.e. `before_denoiser (StableDiffusionXLLoopBeforeDenoiser)`).

Here's a flowchart showing the pipeline structure and the changes we need to make:


![DiffDiff Pipeline Structure](https://mermaid.ink/img/pako:eNqVVO9r4kAQ_VeWLQWFKEk00eRDwZpa7U-ucPfpYpE1mdWlcTdsVmpb-7_fZk1tTCl3J0Sy8968N5kZ9g0nIgUc4pUk-Rr9iuYc6d_Ibs14vlXoQYpNrtqo07lAo1jBTi2AlynysWIa6DJmG7KCBnZpsHHMSqkqNjaxKC5ALRTbQKEgLyosMthVnEvIiYRFRhRwVaBoNpmUT0W7MrTJkUbSdJEInlbwxMDXcQpcsAKq6OH_2mDTODIY4yt0J0ReUaYGnLXiJVChdSsB-enfPhBnhnjT-rCQj-1K_8Ygt62YUAVy8Ykf4FvU6XYu9rpuIGqPpvXSzs_RVEj2KrgiGUp02zNQTHBEM_FcK3BfQbBHd7qAst-PxvW-9WOrypnNylG0G9oRUMYBFeolg-IQTTJSFDqOUkZp-fwsQURZloVnlPpLf2kVSoonCM-SwCUuqY6dZ5aqddjLd1YiMiFLNrWorrxj9EOmP4El37lsl_9p5PzFqIqwVwgdN981fDM94bphH5I06R8NXZ_4QcPQPTFs6JltPrS6JssFhw9N817l27bdyM-lSKAo6iVBAAnQY0n9wLO9wbcluY7ruUFDtdguH74K0yENKDkK-8nAG6TfNrfy_bf-HjdrlOfZS7VYSAlU5JAwyhLE9WrWVw1dWdPTXauDsy8LUkdHtnX_pfMnBOvSGluRNbGurbuTHtdZN9Zts1MljC19_7EUh0puwcIbkBtSHvFbic6xWsMG5jjUrymRT3M85-86Jyf8txCbjzQptqs1DinJCn3a5qm-viJG9M26OUYlcH0_jsWWKxwGttHA4Rve4dD1el3H8_yh49hD3_X7roVfcNhx-l3b14PxvGHQ0xMa9t4t_Gp8na7tDvu-4w08HXecweD9D4X54ZI)


### Build a Working Pipeline Structure

Now that we've identified the blocks to modify, let's build the pipeline skeleton first - at this stage, our goal is to get the pipeline structure working end-to-end (even though it's still just doing the img2img behavior). We can simply create placeholder blocks by copying from existing ones:

```py
>>> # Copy existing blocks as placeholders
>>> class SDXLDiffDiffPrepareLatentsStep(PipelineBlock):
...     """Copied from StableDiffusionXLImg2ImgPrepareLatentsStep - will modify later"""
...     # ... same implementation as StableDiffusionXLImg2ImgPrepareLatentsStep
...
>>> class SDXLDiffDiffLoopBeforeDenoiser(PipelineBlock):
...     """Copied from StableDiffusionXLLoopBeforeDenoiser - will modify later"""
...     # ... same implementation as StableDiffusionXLLoopBeforeDenoiser
```

`SDXLDiffDiffLoopBeforeDenoiser` is the part of the denoise loop we need to change. Let's use it to assemble a `SDXLDiffDiffDenoiseStep`.

```py
>>> class SDXLDiffDiffDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
...     block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLLoopDenoiser, StableDiffusionXLLoopAfterDenoiser]
...     block_names = ["before_denoiser", "denoiser", "after_denoiser"]
```

Now we can put together our differential diffusion pipeline.

```py
>>> DIFFDIFF_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
>>> DIFFDIFF_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]
>>> DIFFDIFF_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
>>> DIFFDIFF_BLOCKS["denoise"] = SDXLDiffDiffDenoiseStep
>>>
>>> dd_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_BLOCKS)
>>> print(dd_blocks)
>>> # At this point, the pipeline works exactly like img2img since our blocks are just copies
```

### Set up an example

Now that our blocks compile without errors, we can move on to the next step. Let's set up a simple example so we can run the pipeline as we build it.
Diff-diff uses the same model checkpoints as SDXL, so we can fetch the models from a regular SDXL repo.

```py
>>> dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
>>> dd_pipeline.load_default_components(torch_dtype=torch.float16)
>>> dd_pipeline.to("cuda")
```

We will use this example script:

```py
>>> image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
>>> mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
>>>
>>> prompt = "a green pear"
>>> negative_prompt = "blurry"
>>>
>>> image = dd_pipeline(
...     prompt=prompt,
...     negative_prompt=negative_prompt,
...     num_inference_steps=25,
...     diffdiff_map=mask,
...     image=image,
...     output="images"
... )[0]
>>>
>>> image.save("diffdiff_out.png")
```

If you run the script right now, the pipeline will complain about the unexpected input `diffdiff_map`, and you will get the same result as the original img2img pipeline.

### Implement your custom logic and test incrementally

Let's modify the pipeline so that we can get the expected result with this example script.

We'll start with the `prepare_latents` step. The main changes are:
- Requires a new user input `diffdiff_map`
- Requires a new component `mask_processor` to process the `diffdiff_map`
- Requires new intermediate inputs:
  - Need `timesteps` instead of `latent_timestep` to precompute all the latents
  - Need `num_inference_steps` to create the `diffdiff_masks`
- Creates two new outputs: `diffdiff_masks` and `original_latents`



💡 Use `print(dd_pipeline.doc)` to check the compiled inputs and outputs of the built pipeline.

For example, after we add `diffdiff_map` as an input in this step, we can run `print(dd_pipeline.doc)` to verify that it shows up in the docstring as a user input.



Once we make sure all the variables we need are available in the block state, we can implement the diff-diff logic inside `__call__`. We create two new variables: the change map `diffdiff_masks` and the pre-computed noised latents for all timesteps, `original_latents`.



💡 Implement incrementally! Run the example script as you go, and insert `print(state)` and `print(block_state)` everywhere inside the `__call__` method to inspect the intermediate results. This helps you understand what's going on and what each line you just added does.



Here are the key changes we made to implement differential diffusion:

**1. 
Modified `prepare_latents` step:** +```diff +class SDXLDiffDiffPrepareLatentsStep(PipelineBlock): + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec("scheduler", EulerDiscreteScheduler), ++ ComponentSpec("mask_processor", VaeImageProcessor, config=FrozenDict({"do_normalize": False, "do_convert_grayscale": True})) + ] + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ ++ InputParam("diffdiff_map", required=True), + ] + + @property + def intermediate_inputs(self) -> List[InputParam]: + return [ + InputParam("generator"), +- InputParam("latent_timestep", required=True, type_hint=torch.Tensor), ++ InputParam("timesteps", type_hint=torch.Tensor), ++ InputParam("num_inference_steps", type_hint=int), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ ++ OutputParam("original_latents", type_hint=torch.Tensor), ++ OutputParam("diffdiff_masks", type_hint=torch.Tensor), + ] + + def __call__(self, components, state: PipelineState): + # ... existing logic ... ++ # Process change map and create masks ++ diffdiff_map = components.mask_processor.preprocess(block_state.diffdiff_map, height=latent_height, width=latent_width) ++ thresholds = torch.arange(block_state.num_inference_steps, dtype=diffdiff_map.dtype) / block_state.num_inference_steps ++ block_state.diffdiff_masks = diffdiff_map > (thresholds + (block_state.denoising_start or 0)) ++ block_state.original_latents = block_state.latents +``` + +**2. Modified `before_denoiser` step:** +```diff +class SDXLDiffDiffLoopBeforeDenoiser(PipelineBlock): + @property + def description(self) -> str: + return ( + "Step within the denoising loop for differential diffusion that prepare the latent input for the denoiser" + ) + ++ @property ++ def inputs(self) -> List[Tuple[str, Any]]: ++ return [ ++ InputParam("denoising_start"), ++ ] + + @property + def intermediate_inputs(self) -> List[str]: + return [ + InputParam("latents", required=True, type_hint=torch.Tensor), ++ InputParam("original_latents", type_hint=torch.Tensor), ++ InputParam("diffdiff_masks", type_hint=torch.Tensor), + ] + + def __call__(self, components, block_state, i, t): ++ # Apply differential diffusion logic ++ if i == 0 and block_state.denoising_start is None: ++ block_state.latents = block_state.original_latents[:1] ++ else: ++ block_state.mask = block_state.diffdiff_masks[i].unsqueeze(0).unsqueeze(1) ++ block_state.latents = block_state.original_latents[i] * block_state.mask + block_state.latents * (1 - block_state.mask) + + # ... rest of existing logic ... +``` + +That's all there is to it! We've just created a simple sequential pipeline by mix-and-match some existing and new pipeline blocks. + +Now we use the process we've prepred in step2 to build the pipeline and inspect it. 
+ + +```py +>> dd_pipeline +SequentialPipelineBlocks( + Class: ModularPipelineBlocks + + Description: + + + Components: + text_encoder (`CLIPTextModel`) + text_encoder_2 (`CLIPTextModelWithProjection`) + tokenizer (`CLIPTokenizer`) + tokenizer_2 (`CLIPTokenizer`) + guider (`ClassifierFreeGuidance`) + vae (`AutoencoderKL`) + image_processor (`VaeImageProcessor`) + scheduler (`EulerDiscreteScheduler`) + mask_processor (`VaeImageProcessor`) + unet (`UNet2DConditionModel`) + + Configs: + force_zeros_for_empty_prompt (default: True) + requires_aesthetics_score (default: False) + + Blocks: + [0] text_encoder (StableDiffusionXLTextEncoderStep) + Description: Text Encoder step that generate text_embeddings to guide the image generation + + [1] image_encoder (StableDiffusionXLVaeEncoderStep) + Description: Vae Encoder step that encode the input image into a latent representation + + [2] input (StableDiffusionXLInputStep) + Description: Input processing step that: + 1. Determines `batch_size` and `dtype` based on `prompt_embeds` + 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt` + + All input tensors are expected to have either batch_size=1 or match the batch_size + of prompt_embeds. The tensors will be duplicated across the batch dimension to + have a final batch_size of batch_size * num_images_per_prompt. + + [3] set_timesteps (StableDiffusionXLSetTimestepsStep) + Description: Step that sets the scheduler's timesteps for inference + + [4] prepare_latents (SDXLDiffDiffPrepareLatentsStep) + Description: Step that prepares the latents for the differential diffusion generation process + + [5] prepare_add_cond (StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep) + Description: Step that prepares the additional conditioning for the image-to-image/inpainting generation process + + [6] denoise (SDXLDiffDiffDenoiseStep) + Description: Pipeline block that iteratively denoise the latents over `timesteps`. The specific steps with each iteration can be customized with `sub_blocks` attributes + + [7] decode (StableDiffusionXLDecodeStep) + Description: Step that decodes the denoised latents into images + +) +``` + +Run the example now, you should see an apple with its right half transformed into a green pear. + +![Image description](https://cdn-uploads.huggingface.co/production/uploads/624ef9ba9d608e459387b34e/4zqJOz-35Q0i6jyUW3liL.png) + + +## Adding IP-adapter + +We provide an auto IP-adapter block that you can plug-and-play into your modular workflow. It's an `AutoPipelineBlocks`, so it will only run when the user passes an IP adapter image. In this tutorial, we'll focus on how to package it into your differential diffusion workflow. To learn more about `AutoPipelineBlocks`, see [here](./auto_pipeline_blocks.md) + +We talked about how to add IP-adapter into your workflow in the [Modular Pipeline Guide](./modular_pipeline.md). Let's just go ahead to create the IP-adapter block. + +```py +>>> from diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLAutoIPAdapterStep +>>> ip_adapter_block = StableDiffusionXLAutoIPAdapterStep() +``` + +We can directly add the ip-adapter block instance to the `diffdiff_blocks` that we created before. The `sub_blocks` attribute is a `InsertableDict`, so we're able to insert the it at specific position (index `0` here). + +```py +>>> dd_blocks.sub_blocks.insert("ip_adapter", ip_adapter_block, 0) +``` + +Take a look at the new diff-diff pipeline with ip-adapter! 
+ +```py +>>> print(dd_blocks) +``` + +The pipeline now lists ip-adapter as its first block, and tells you that it will run only if `ip_adapter_image` is provided. It also includes the two new components from ip-adpater: `image_encoder` and `feature_extractor` + +```out +SequentialPipelineBlocks( + Class: ModularPipelineBlocks + + ==================================================================================================== + This pipeline contains blocks that are selected at runtime based on inputs. + Trigger Inputs: {'ip_adapter_image'} + Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('ip_adapter_image')`). + ==================================================================================================== + + + Description: + + + Components: + image_encoder (`CLIPVisionModelWithProjection`) + feature_extractor (`CLIPImageProcessor`) + unet (`UNet2DConditionModel`) + guider (`ClassifierFreeGuidance`) + text_encoder (`CLIPTextModel`) + text_encoder_2 (`CLIPTextModelWithProjection`) + tokenizer (`CLIPTokenizer`) + tokenizer_2 (`CLIPTokenizer`) + vae (`AutoencoderKL`) + image_processor (`VaeImageProcessor`) + scheduler (`EulerDiscreteScheduler`) + mask_processor (`VaeImageProcessor`) + + Configs: + force_zeros_for_empty_prompt (default: True) + requires_aesthetics_score (default: False) + + Blocks: + [0] ip_adapter (StableDiffusionXLAutoIPAdapterStep) + Description: Run IP Adapter step if `ip_adapter_image` is provided. + + [1] text_encoder (StableDiffusionXLTextEncoderStep) + Description: Text Encoder step that generate text_embeddings to guide the image generation + + [2] image_encoder (StableDiffusionXLVaeEncoderStep) + Description: Vae Encoder step that encode the input image into a latent representation + + [3] input (StableDiffusionXLInputStep) + Description: Input processing step that: + 1. Determines `batch_size` and `dtype` based on `prompt_embeds` + 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt` + + All input tensors are expected to have either batch_size=1 or match the batch_size + of prompt_embeds. The tensors will be duplicated across the batch dimension to + have a final batch_size of batch_size * num_images_per_prompt. + + [4] set_timesteps (StableDiffusionXLSetTimestepsStep) + Description: Step that sets the scheduler's timesteps for inference + + [5] prepare_latents (SDXLDiffDiffPrepareLatentsStep) + Description: Step that prepares the latents for the differential diffusion generation process + + [6] prepare_add_cond (StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep) + Description: Step that prepares the additional conditioning for the image-to-image/inpainting generation process + + [7] denoise (SDXLDiffDiffDenoiseStep) + Description: Pipeline block that iteratively denoise the latents over `timesteps`. The specific steps with each iteration can be customized with `sub_blocks` attributes + + [8] decode (StableDiffusionXLDecodeStep) + Description: Step that decodes the denoised latents into images + +) +``` + +Let's test it out. We used an orange image to condition the generation via ip-addapter and we can see a slight orange color and texture in the final output. 
```py
>>> ip_adapter_block = StableDiffusionXLAutoIPAdapterStep()
>>> dd_blocks.sub_blocks.insert("ip_adapter", ip_adapter_block, 0)
>>>
>>> dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
>>> dd_pipeline.load_default_components(torch_dtype=torch.float16)
>>> dd_pipeline.loader.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
>>> dd_pipeline.loader.set_ip_adapter_scale(0.6)
>>> dd_pipeline = dd_pipeline.to(device)
>>>
>>> ip_adapter_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_orange.jpeg")
>>> image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
>>> mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
>>>
>>> prompt = "a green pear"
>>> negative_prompt = "blurry"
>>> generator = torch.Generator(device=device).manual_seed(42)
>>>
>>> image = dd_pipeline(
...     prompt=prompt,
...     negative_prompt=negative_prompt,
...     num_inference_steps=25,
...     generator=generator,
...     ip_adapter_image=ip_adapter_image,
...     diffdiff_map=mask,
...     image=image,
...     output="images"
... )[0]
```

## Working with ControlNets

What about ControlNet? Can differential diffusion work with it? The key differences between a regular pipeline and a ControlNet pipeline are:
1. A ControlNet input step that prepares the control condition
2. Inside the denoising loop, a modified denoiser step where the control image is first processed through ControlNet, then the control information is injected into the UNet

From looking at the code workflow: differential diffusion only modifies the "before denoiser" step, while ControlNet operates within the "denoiser" itself. Since they intervene at different points in the pipeline, they should work together without conflicts.

Intuitively, these two techniques are orthogonal and should combine naturally: differential diffusion controls how much the inference process can deviate from the original in each region, while ControlNet controls in what direction that change occurs.

With this understanding, let's assemble the diffdiff-controlnet loop by combining the diffdiff before-denoiser step with the ControlNet denoiser step.

```py
>>> class SDXLDiffDiffControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
...     block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLControlNetLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser]
...     block_names = ["before_denoiser", "denoiser", "after_denoiser"]
>>>
>>> controlnet_denoise_block = SDXLDiffDiffControlNetDenoiseStep()
>>> # print(controlnet_denoise_block)
```

We provide an auto ControlNet input block that you can directly put into your workflow to process the `control_image`. Similar to the auto IP-Adapter block, this step only runs if the `control_image` input is passed by the user. It works with both ControlNet and ControlNet Union. 
+ + +```py +>>> from diffusers.modular_pipelines.stable_diffusion_xl.modular_blocks import StableDiffusionXLAutoControlNetInputStep +>>> control_input_block = StableDiffusionXLAutoControlNetInputStep() +>>> print(control_input_block) +``` + +```out +StableDiffusionXLAutoControlNetInputStep( + Class: AutoPipelineBlocks + + ==================================================================================================== + This pipeline contains blocks that are selected at runtime based on inputs. + Trigger Inputs: ['control_image', 'control_mode'] + ==================================================================================================== + + + Description: Controlnet Input step that prepare the controlnet input. + This is an auto pipeline block that works for both controlnet and controlnet_union. + (it should be called right before the denoise step) - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided. + - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided. - if neither `control_mode` nor `control_image` is provided, step will be skipped. + + + Components: + controlnet (`ControlNetUnionModel`) + control_image_processor (`VaeImageProcessor`) + + Sub-Blocks: + • controlnet_union [trigger: control_mode] (StableDiffusionXLControlNetUnionInputStep) + Description: step that prepares inputs for the ControlNetUnion model + + • controlnet [trigger: control_image] (StableDiffusionXLControlNetInputStep) + Description: step that prepare inputs for controlnet + +) + +``` + +Let's assemble the blocks and run an example using controlnet + differential diffusion. We used a tomato as `control_image`, so you can see that in the output, the right half that transformed into a pear had a tomato-like shape. + +```py +>>> dd_blocks.sub_blocks.insert("controlnet_input", control_input_block, 7) +>>> dd_blocks.sub_blocks["denoise"] = controlnet_denoise_block +>>> +>>> dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff") +>>> dd_pipeline.load_default_components(torch_dtype=torch.float16) +>>> dd_pipeline = dd_pipeline.to(device) +>>> +>>> control_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_tomato_canny.jpeg") +>>> image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true") +>>> mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true") +>>> +>>> prompt = "a green pear" +>>> negative_prompt = "blurry" +>>> generator = torch.Generator(device=device).manual_seed(42) +>>> +>>> image = dd_pipeline( +... prompt=prompt, +... negative_prompt=negative_prompt, +... num_inference_steps=25, +... generator=generator, +... control_image=control_image, +... controlnet_conditioning_scale=0.5, +... diffdiff_map=mask, +... image=image, +... output="images" +... )[0] +``` + +Optionally, We can combine `SDXLDiffDiffControlNetDenoiseStep` and `SDXLDiffDiffDenoiseStep` into a `AutoPipelineBlocks` so that same workflow can work with or without controlnet. + + +```py +>>> class SDXLDiffDiffAutoDenoiseStep(AutoPipelineBlocks): +... block_classes = [SDXLDiffDiffControlNetDenoiseStep, SDXLDiffDiffDenoiseStep] +... block_names = ["controlnet_denoise", "denoise"] +... 
block_trigger_inputs = ["controlnet_cond", None]
```

`SDXLDiffDiffAutoDenoiseStep` will run the ControlNet denoise step if the `control_image` input is provided; otherwise it will run the regular denoise step.

Note that it's perfectly fine not to use `AutoPipelineBlocks`. In fact, we recommend only using `AutoPipelineBlocks` to package your workflow at the end, once you've verified all your pipelines work as expected.

Now you can create the differential diffusion preset that works with IP-Adapter and ControlNet.

```py
>>> DIFFDIFF_AUTO_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
>>> DIFFDIFF_AUTO_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
>>> DIFFDIFF_AUTO_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]
>>> DIFFDIFF_AUTO_BLOCKS["denoise"] = SDXLDiffDiffAutoDenoiseStep
>>> DIFFDIFF_AUTO_BLOCKS.insert("ip_adapter", StableDiffusionXLAutoIPAdapterStep, 0)
>>> DIFFDIFF_AUTO_BLOCKS.insert("controlnet_input", StableDiffusionXLAutoControlNetInputStep, 7)
>>>
>>> print(DIFFDIFF_AUTO_BLOCKS)
```

To use it:

```py
>>> dd_auto_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_AUTO_BLOCKS)
>>> dd_pipeline = dd_auto_blocks.init_pipeline(...)
```

## Creating a Modular Repo

You can easily share your differential diffusion workflow on the Hub by creating a modular repo. Here is one created using the code we just wrote together: https://huggingface.co/YiYiXu/modular-diffdiff

To create a modular repo and share it on the Hub, you just need to call `save_pretrained()` with the `push_to_hub=True` flag. Note that if your pipeline contains custom blocks, you need to manually upload the code to the Hub; we are working on a command line tool to make this easier.

```py
dd_pipeline.save_pretrained("YiYiXu/test_modular_doc", push_to_hub=True)
```

With a modular repo, it is very easy for the community to use the workflow you just created! Here is an example of using the differential diffusion pipeline we just created and shared.

```py
>>> from diffusers.modular_pipelines import ModularPipeline, ComponentsManager
>>> import torch
>>> from diffusers.utils import load_image
>>>
>>> repo_id = "YiYiXu/modular-diffdiff-0704"
>>>
>>> components = ComponentsManager()
>>>
>>> diffdiff_pipeline = ModularPipeline.from_pretrained(repo_id, trust_remote_code=True, components_manager=components, collection="diffdiff")
>>> diffdiff_pipeline.load_default_components(torch_dtype=torch.float16)
>>> components.enable_auto_cpu_offload()
```

See more usage examples on the model card.

## Deploy a Mellon node

[YIYI TODO: for now, here is an example of a Mellon node: https://huggingface.co/YiYiXu/diff-diff-mellon]
diff --git a/docs/source/en/modular_diffusers/loop_sequential_pipeline_blocks.md b/docs/source/en/modular_diffusers/loop_sequential_pipeline_blocks.md
new file mode 100644
index 000000000000..e95cdc7163b4
--- /dev/null
+++ b/docs/source/en/modular_diffusers/loop_sequential_pipeline_blocks.md
@@ -0,0 +1,194 @@

# LoopSequentialPipelineBlocks

🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.

`LoopSequentialPipelineBlocks` is a subclass of `ModularPipelineBlocks`. It is a multi-block that composes other blocks together in a loop, creating iterative workflows where blocks run multiple times with evolving state. It's particularly useful for denoising loops that require repeated execution of the same blocks. 
+ + + +Other types of multi-blocks include [SequentialPipelineBlocks](./sequential_pipeline_blocks.md) (for linear workflows) and [AutoPipelineBlocks](./auto_pipeline_blocks.md) (for conditional block selection). For information on creating individual blocks, see the [PipelineBlock guide](./pipeline_block.md). + +Additionally, like all `ModularPipelineBlocks`, `LoopSequentialPipelineBlocks` are definitions/specifications, not runnable pipelines. You need to convert them into a `ModularPipeline` to actually execute them. For information on creating and running pipelines, see the [Modular Pipeline guide](modular_pipeline.md). + + + +You could create a loop using `PipelineBlock` like this: + +```python +class DenoiseLoop(PipelineBlock): + def __call__(self, components, state): + block_state = self.get_block_state(state) + for t in range(block_state.num_inference_steps): + # ... loop logic here + pass + self.set_block_state(state, block_state) + return components, state +``` + +But in this tutorial, we will focus on how to use `LoopSequentialPipelineBlocks` to create a "composable" denoising loop where you can add or remove blocks within the loop or reuse the same loop structure with different block combinations. + +It involves two parts: a **loop wrapper** and **loop blocks** + +* The **loop wrapper** (`LoopSequentialPipelineBlocks`) defines the loop structure, e.g. it defines the iteration variables, and loop configurations such as progress bar. + +* The **loop blocks** are basically standard pipeline blocks you add to the loop wrapper. + - they run sequentially for each iteration of the loop + - they receive the current iteration index as an additional parameter + - they share the same block_state throughout the entire loop + +Unlike regular `SequentialPipelineBlocks` where each block gets its own state, loop blocks share a single state that persists and evolves across iterations. + +We will build a simple loop block to demonstrate these concepts. Creating a loop block involves three steps: +1. defining the loop wrapper class +2. creating the loop blocks +3. adding the loop blocks to the loop wrapper class to create the loop wrapper instance + +**Step 1: Define the Loop Wrapper** + +To create a `LoopSequentialPipelineBlocks` class, you need to define: + +* `loop_inputs`: User input variables (equivalent to `PipelineBlock.inputs`) +* `loop_intermediate_inputs`: Intermediate variables needed from the mutable pipeline state (equivalent to `PipelineBlock.intermediates_inputs`) +* `loop_intermediate_outputs`: New intermediate variables this block will add to the mutable pipeline state (equivalent to `PipelineBlock.intermediates_outputs`) +* `__call__` method: Defines the loop structure and iteration logic + +Here is an example of a loop wrapper: + +```py +import torch +from diffusers.modular_pipelines import LoopSequentialPipelineBlocks, PipelineBlock, InputParam, OutputParam + +class LoopWrapper(LoopSequentialPipelineBlocks): + model_name = "test" + @property + def description(self): + return "I'm a loop!!" 
+ @property + def loop_inputs(self): + return [InputParam(name="num_steps")] + @torch.no_grad() + def __call__(self, components, state): + block_state = self.get_block_state(state) + # Loop structure - can be customized to your needs + for i in range(block_state.num_steps): + # loop_step executes all registered blocks in sequence + components, block_state = self.loop_step(components, block_state, i=i) + self.set_block_state(state, block_state) + return components, state +``` + +**Step 2: Create Loop Blocks** + +Loop blocks are standard `PipelineBlock`s, but their `__call__` method works differently: +* It receives the iteration variable (e.g., `i`) passed by the loop wrapper +* It works directly with `block_state` instead of pipeline state +* No need to call `self.get_block_state()` or `self.set_block_state()` + +```py +class LoopBlock(PipelineBlock): + # this is used to identify the model family, we won't worry about it in this example + model_name = "test" + @property + def inputs(self): + return [InputParam(name="x")] + @property + def intermediate_outputs(self): + # outputs produced by this block + return [OutputParam(name="x")] + @property + def description(self): + return "I'm a block used inside the `LoopWrapper` class" + def __call__(self, components, block_state, i: int): + block_state.x += 1 + return components, block_state +``` + +**Step 3: Combine Everything** + +Finally, assemble your loop by adding the block(s) to the wrapper: + +```py +loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock}) +``` + +Now you've created a loop with one step: + +```py +>>> loop +LoopWrapper( + Class: LoopSequentialPipelineBlocks + + Description: I'm a loop!! + + Sub-Blocks: + [0] block1 (LoopBlock) + Description: I'm a block used inside the `LoopWrapper` class + +) +``` + +It has two inputs: `x` (used at each step within the loop) and `num_steps` used to define the loop. + +```py +>>> print(loop.doc) +class LoopWrapper + + I'm a loop!! + + Inputs: + + x (`None`, *optional*): + + num_steps (`None`, *optional*): + + Outputs: + + x (`None`): +``` + +**Running the Loop:** + +```py +# run the loop +loop_pipeline = loop.init_pipeline() +x = loop_pipeline(num_steps=10, x=0, output="x") +assert x == 10 +``` + +**Adding Multiple Blocks:** + +We can add multiple blocks to run within each iteration. Let's run the loop block twice within each iteration: + +```py +loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock(), "block2": LoopBlock}) +loop_pipeline = loop.init_pipeline() +x = loop_pipeline(num_steps=10, x=0, output="x") +assert x == 20 # Each iteration runs 2 blocks, so 10 iterations * 2 = 20 +``` + +**Key Differences from SequentialPipelineBlocks:** + +The main difference is that loop blocks share the same `block_state` across all iterations, allowing values to accumulate and evolve throughout the loop. Loop blocks could receive additional arguments (like the current iteration index) depending on the loop wrapper's implementation, since the wrapper defines how loop blocks are called. You can easily add, remove, or reorder blocks within the loop without changing the loop logic itself. + +The officially supported denoising loops in Modular Diffusers are implemented using `LoopSequentialPipelineBlocks`. 
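Before looking at the real thing, here is a small sketch of extending the toy loop from above without touching its loop logic. The `sub_blocks.insert(...)` usage here is an assumption, based on how other multi-blocks are manipulated elsewhere in these docs.

```py
# Rebuild the toy loop with one block, then add a second LoopBlock instance so each
# iteration runs two blocks (assumes sub_blocks supports insert like other multi-blocks).
loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock})
loop.sub_blocks.insert("block2", LoopBlock(), 1)

loop_pipeline = loop.init_pipeline()
x = loop_pipeline(num_steps=5, x=0, output="x")
assert x == 10  # 5 iterations x 2 blocks, each adding 1
```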
You can explore the actual implementation to see how these concepts work in practice:

```py
from diffusers.modular_pipelines.stable_diffusion_xl.denoise import StableDiffusionXLDenoiseStep
StableDiffusionXLDenoiseStep()
```
\ No newline at end of file
diff --git a/docs/source/en/modular_diffusers/modular_diffusers_states.md b/docs/source/en/modular_diffusers/modular_diffusers_states.md
new file mode 100644
index 000000000000..744089fcf676
--- /dev/null
+++ b/docs/source/en/modular_diffusers/modular_diffusers_states.md
@@ -0,0 +1,59 @@

# PipelineState and BlockState

🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.

In Modular Diffusers, `PipelineState` and `BlockState` are the core data structures that enable blocks to communicate and share data. These concepts are fundamental to understanding how blocks interact with each other and with the pipeline system.

In the Modular Diffusers system, `PipelineState` acts as the global state container that all pipeline blocks operate on. It maintains the complete runtime state of the pipeline and provides a structured way for blocks to read from and write to shared data.

A `PipelineState` consists of two distinct states:

- **The immutable state** (i.e. the `inputs` dict) contains a copy of values provided by users. Once a value is added to the immutable state, it cannot be changed. Blocks can read from the immutable state but cannot write to it.

- **The mutable state** (i.e. the `intermediates` dict) contains variables that are passed between blocks and can be modified by them.

Here's an example of what a `PipelineState` looks like:

```py
PipelineState(
    inputs={
        'prompt': 'a cat'
        'guidance_scale': 7.0
        'num_inference_steps': 25
    },
    intermediates={
        'prompt_embeds': Tensor(dtype=torch.float32, shape=torch.Size([1, 1, 1, 1]))
        'negative_prompt_embeds': None
    },
)
```

Each pipeline block defines what parts of that state it can read from and write to through its `inputs`, `intermediate_inputs`, and `intermediate_outputs` properties. At runtime, it gets a local view (`BlockState`) of the relevant variables it needs from `PipelineState`, performs its operations, and then updates `PipelineState` with any changes.

For example, if a block defines an input `image`, inside the block's `__call__` method, the `BlockState` would contain:

```py
BlockState(
    image: <PIL.Image.Image image mode=RGB size=512x512>
)
```

You can access the variables directly as attributes: `block_state.image`.

To explore more about how blocks interact with the pipeline state through their `inputs`, `intermediate_inputs`, and `intermediate_outputs` properties, see the [PipelineBlock guide](./pipeline_block.md).
\ No newline at end of file
diff --git a/docs/source/en/modular_diffusers/modular_pipeline.md b/docs/source/en/modular_diffusers/modular_pipeline.md
new file mode 100644
index 000000000000..55182b921fdb
--- /dev/null
+++ b/docs/source/en/modular_diffusers/modular_pipeline.md
@@ -0,0 +1,1237 @@

# ModularPipeline

🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.

`ModularPipeline` is the main interface for end users to run pipelines in Modular Diffusers. It takes pipeline blocks and converts them into a runnable pipeline that can load models and execute the computation steps. 
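To give a quick sense of the end-to-end flow before diving into the details, here is a minimal sketch. The repository id and the final generation call mirror examples used later in this guide and elsewhere in these docs, and are assumptions here rather than the only supported usage.

```py
import torch
from diffusers import ComponentsManager
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS

# 1. assemble pipeline blocks (definitions only, nothing is loaded yet)
blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)

# 2. convert them into a runnable ModularPipeline backed by a modular repo
components = ComponentsManager()
pipeline = blocks.init_pipeline("YiYiXu/modular-loader-t2i-0704", components_manager=components)

# 3. load the actual models, then run (call signature assumed from other examples in these docs)
pipeline.load_default_components(torch_dtype=torch.float16)
pipeline.to("cuda")
image = pipeline(prompt="a majestic lion", output="images")[0]
```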
+ +In this guide, we will focus on how to build pipelines using the blocks we officially support at diffusers 🧨. We'll cover how to use predefined blocks and convert them into a `ModularPipeline` for execution. + + + +This guide shows you how to use predefined blocks. If you want to learn how to create your own pipeline blocks, see the [PipelineBlock guide](pipeline_block.md) for creating individual blocks, and the multi-block guides for connecting them together: +- [SequentialPipelineBlocks](sequential_pipeline_blocks.md) (for linear workflows) +- [LoopSequentialPipelineBlocks](loop_sequential_pipeline_blocks.md) (for iterative workflows) +- [AutoPipelineBlocks](auto_pipeline_blocks.md) (for conditional workflows) + +For information on how data flows through pipelines, see the [PipelineState and BlockState guide](modular_diffusers_states.md). + + + + +## Create ModularPipelineBlocks + +In Modular Diffusers system, you build pipelines using Pipeline blocks. Pipeline Blocks are fundamental building blocks - they define what components, inputs/outputs, and computation logics are needed. They are designed to be assembled into workflows for tasks such as image generation, video creation, and inpainting. But they are just definitions and don't actually run anything. To execute blocks, you need to put them into a `ModularPipeline`. We'll first learn how to create predefined blocks here before talking about how to run them using `ModularPipeline`. + +All pipeline blocks inherit from the base class `ModularPipelineBlocks`, including: + +- [`PipelineBlock`]: The most granular block - you define the input/output/components requirements and computation logic. +- [`SequentialPipelineBlocks`]: A multi-block composed of multiple blocks that run sequentially, passing outputs as inputs to the next block. +- [`LoopSequentialPipelineBlocks`]: A special type of `SequentialPipelineBlocks` that runs the same sequence of blocks multiple times (loops), typically used for iterative processes like denoising steps in diffusion models. +- [`AutoPipelineBlocks`]: A multi-block composed of multiple blocks that are selected at runtime based on the inputs. + +It is very easy to use a `ModularPipelineBlocks` officially supported in 🧨 Diffusers + +```py +from diffusers.modular_pipelines.stable_diffusion_xl import StableDiffusionXLTextEncoderStep + +text_encoder_block = StableDiffusionXLTextEncoderStep() +``` + +This is a single `PipelineBlock`. You'll see that this text encoder block uses 2 text_encoders, 2 tokenizers as well as a guider component. It takes user inputs such as `prompt` and `negative_prompt`, and return text embeddings outputs such as `prompt_embeds` and `negative_prompt_embeds`. + +```py +>>> text_encoder_block +StableDiffusionXLTextEncoderStep( + Class: PipelineBlock + Description: Text Encoder step that generate text_embeddings to guide the image generation + Components: + text_encoder (`CLIPTextModel`) + text_encoder_2 (`CLIPTextModelWithProjection`) + tokenizer (`CLIPTokenizer`) + tokenizer_2 (`CLIPTokenizer`) + guider (`ClassifierFreeGuidance`) + Configs: + force_zeros_for_empty_prompt (default: True) + Inputs: + prompt=None, prompt_2=None, negative_prompt=None, negative_prompt_2=None, cross_attention_kwargs=None, clip_skip=None + Intermediates: + - outputs: prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds +) +``` + +More commonly, you need multiple blocks to build your workflow. 
You can create a `SequentialPipelineBlocks` using block class presets from 🧨 Diffusers. `TEXT2IMAGE_BLOCKS` is a dict containing all the blocks needed for text-to-image generation. + +```py +from diffusers.modular_pipelines import SequentialPipelineBlocks +from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS +t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS) +``` + +This creates a `SequentialPipelineBlocks`. Unlike the `text_encoder_block` we saw earlier, this is a multi-block and its `sub_blocks` attribute contains a list of other blocks (text_encoder, input, set_timesteps, prepare_latents, prepare_added_con, denoise, decode). Its requirements for components, inputs, and intermediate inputs are combined from these blocks that compose it. At runtime, it executes its sub-blocks sequentially and passes the pipeline state from one block to another. + +```py +>>> t2i_blocks +SequentialPipelineBlocks( + Class: ModularPipelineBlocks + + Description: + + + Components: + text_encoder (`CLIPTextModel`) + text_encoder_2 (`CLIPTextModelWithProjection`) + tokenizer (`CLIPTokenizer`) + tokenizer_2 (`CLIPTokenizer`) + guider (`ClassifierFreeGuidance`) + scheduler (`EulerDiscreteScheduler`) + unet (`UNet2DConditionModel`) + vae (`AutoencoderKL`) + image_processor (`VaeImageProcessor`) + + Configs: + force_zeros_for_empty_prompt (default: True) + + Sub-Blocks: + [0] text_encoder (StableDiffusionXLTextEncoderStep) + Description: Text Encoder step that generate text_embeddings to guide the image generation + + [1] input (StableDiffusionXLInputStep) + Description: Input processing step that: + 1. Determines `batch_size` and `dtype` based on `prompt_embeds` + 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt` + + All input tensors are expected to have either batch_size=1 or match the batch_size + of prompt_embeds. The tensors will be duplicated across the batch dimension to + have a final batch_size of batch_size * num_images_per_prompt. + + [2] set_timesteps (StableDiffusionXLSetTimestepsStep) + Description: Step that sets the scheduler's timesteps for inference + + [3] prepare_latents (StableDiffusionXLPrepareLatentsStep) + Description: Prepare latents step that prepares the latents for the text-to-image generation process + + [4] prepare_add_cond (StableDiffusionXLPrepareAdditionalConditioningStep) + Description: Step that prepares the additional conditioning for the text-to-image generation process + + [5] denoise (StableDiffusionXLDenoiseStep) + Description: Denoise step that iteratively denoise the latents. + Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method + At each iteration, it runs blocks defined in `sub_blocks` sequencially: + - `StableDiffusionXLLoopBeforeDenoiser` + - `StableDiffusionXLLoopDenoiser` + - `StableDiffusionXLLoopAfterDenoiser` + This block supports both text2img and img2img tasks. 
+ + [6] decode (StableDiffusionXLDecodeStep) + Description: Step that decodes the denoised latents into images + +) +``` + +This is the block classes preset (`TEXT2IMAGE_BLOCKS`) we used: It is just a dictionary that maps names to ModularPipelineBlocks classes + +```py +>>> TEXT2IMAGE_BLOCKS +InsertableDict([ + 0: ('text_encoder', ), + 1: ('input', ), + 2: ('set_timesteps', ), + 3: ('prepare_latents', ), + 4: ('prepare_add_cond', ), + 5: ('denoise', ), + 6: ('decode', ) +]) +``` + +When we create a `SequentialPipelineBlocks` from this preset, it instantiates each block class into actual block objects. Its `sub_blocks` attribute now contains these instantiated objects: + +```py +>>> t2i_blocks.sub_blocks +InsertableDict([ + 0: ('text_encoder', ), + 1: ('input', ), + 2: ('set_timesteps', ), + 3: ('prepare_latents', ), + 4: ('prepare_add_cond', ), + 5: ('denoise', ), + 6: ('decode', ) +]) +``` + +Note that both the block classes preset and the `sub_blocks` attribute are `InsertableDict` objects. This is a custom dictionary that extends `OrderedDict` with the ability to insert items at specific positions. You can perform all standard dictionary operations (get, set, delete) plus insert items at any index, which is particularly useful for reordering or inserting blocks in the middle of a pipeline. + +**Add a block:** +```py +# BLOCKS is dict of block classes, you need to add class to it +BLOCKS.insert("block_name", BlockClass, index) +# sub_blocks attribute contains instance, add a block instance to the attribute +t2i_blocks.sub_blocks.insert("block_name", block_instance, index) +``` + +**Remove a block:** +```py +# remove a block class from preset +BLOCKS.pop("text_encoder") +# split out a block instance on its own +text_encoder_block = t2i_blocks.sub_blocks.pop("text_encoder") +``` + +**Swap block:** +```py +# Replace block class in preset +BLOCKS["prepare_latents"] = CustomPrepareLatents +# Replace in sub_blocks attribute using an block instance +t2i_blocks.sub_blocks["prepare_latents"] = CustomPrepareLatents() +``` + +This means you can mix-and-match blocks in very flexible ways. Let's see some real examples: + +**Example 1: Adding IP-Adapter to the Block Classes Preset** +Let's make a new block classes preset by insert IP-Adapter at index 0 (before the text_encoder block), and create a text-to-image pipeline with IP-Adapter support: + +```py +from diffusers.modular_pipelines.stable_diffusion_xl import StableDiffusionXLAutoIPAdapterStep +CUSTOM_BLOCKS = TEXT2IMAGE_BLOCKS.copy() +# CUSTOM_BLOCKS is now a preset including ip_adapter +CUSTOM_BLOCKS.insert("ip_adapter", StableDiffusionXLAutoIPAdapterStep, 0) +# create a blocks isntance from the preset +custom_blocks = SequentialPipelineBlocks.from_blocks_dict(CUSTOM_BLOCKS) +``` + +**Example 2: Extracting a block from a multi-block** +You can extract a block instance from the multi-block to use it independently. A common pattern is to use text_encoder to process prompts once, then reuse the text embeddings outputs to generate multiple images with different settings (schedulers, seeds, inference steps). We can do this by simply extracting the text_encoder block from the pipeline. + +```py +# this gives you StableDiffusionXLTextEncoderStep() +>>> text_encoder_blocks = t2i_blocks.sub_blocks.pop("text_encoder") +>>> text_encoder_blocks +``` + +The multi-block now has fewer components and no longer has the `text_encoder` block. 
If you check its docstring `t2i_blocks.doc`, you will see that it no longer accepts `prompt` as input - you will need to pass the embeddings instead. + +```py +>>> t2i_blocks +SequentialPipelineBlocks( + Class: ModularPipelineBlocks + + Description: + + Components: + scheduler (`EulerDiscreteScheduler`) + guider (`ClassifierFreeGuidance`) + unet (`UNet2DConditionModel`) + vae (`AutoencoderKL`) + image_processor (`VaeImageProcessor`) + + Blocks: + [0] input (StableDiffusionXLInputStep) + Description: Input processing step that: + 1. Determines `batch_size` and `dtype` based on `prompt_embeds` + 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt` + + All input tensors are expected to have either batch_size=1 or match the batch_size + of prompt_embeds. The tensors will be duplicated across the batch dimension to + have a final batch_size of batch_size * num_images_per_prompt. + + [1] set_timesteps (StableDiffusionXLSetTimestepsStep) + Description: Step that sets the scheduler's timesteps for inference + + [2] prepare_latents (StableDiffusionXLPrepareLatentsStep) + Description: Prepare latents step that prepares the latents for the text-to-image generation process + + [3] prepare_add_cond (StableDiffusionXLPrepareAdditionalConditioningStep) + Description: Step that prepares the additional conditioning for the text-to-image generation process + + [4] denoise (StableDiffusionXLDenoiseLoop) + Description: Denoise step that iteratively denoise the latents. + Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method + At each iteration, it runs blocks defined in `blocks` sequencially: + - `StableDiffusionXLLoopBeforeDenoiser` + - `StableDiffusionXLLoopDenoiser` + - `StableDiffusionXLLoopAfterDenoiser` + + + [5] decode (StableDiffusionXLDecodeStep) + Description: Step that decodes the denoised latents into images + +) +``` + + + +💡 You can find all the block classes presets we support for each model in `ALL_BLOCKS`. + +```py +# For Stable Diffusion XL +from diffusers.modular_pipelines.stable_diffusion_xl import ALL_BLOCKS +ALL_BLOCKS +# For other models... +from diffusers.modular_pipelines. import ALL_BLOCKS +``` + +Each model provides a dictionary that maps all supported tasks/techniques to their corresponding block classes presets. For SDXL, it is + +```py +ALL_BLOCKS = { + "text2img": TEXT2IMAGE_BLOCKS, + "img2img": IMAGE2IMAGE_BLOCKS, + "inpaint": INPAINT_BLOCKS, + "controlnet": CONTROLNET_BLOCKS, + "ip_adapter": IP_ADAPTER_BLOCKS, + "auto": AUTO_BLOCKS, +} +``` + + + +This covers the essentials of pipeline blocks! Like we have already mentioned, **pipeline blocks are not runnable by themselves**. They are essentially **"definitions"** - they define the specifications and computational steps for a pipeline, but they do not contain any model states. To actually run them, you need to convert them into a `ModularPipeline` object. + + +## Modular Repo + +To convert blocks into a runnable pipeline, you may need a repository if your blocks contain **pretrained components** (models with checkpoints that need to be loaded from the Hub). Pipeline blocks define what components they need (like a UNet, text encoder, etc.), as well as how to create them: components can be either created using **from_pretrained** method (with checkpoints) or **from_config** (initialized from scratch with default configuration, usually stateless like a guider or scheduler). 
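For intuition, here is a hedged sketch of what the two kinds of component specifications look like. The field names mirror the `ComponentSpec` examples shown later in this guide; the top-level `ClassifierFreeGuidance` import and the plain-dict `config` argument are assumptions.

```py
from diffusers import ComponentSpec, UNet2DConditionModel, ClassifierFreeGuidance  # guider import path assumed

# A pretrained component: carries a loading spec so its weights can be fetched from the Hub
unet_spec = ComponentSpec(
    name="unet",
    type_hint=UNet2DConditionModel,
    repo="stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="unet",
    variant="fp16",
    default_creation_method="from_pretrained",
)

# A config component: stateless, created from a default config, no repo needed
guider_spec = ComponentSpec(
    name="guider",
    type_hint=ClassifierFreeGuidance,
    config={"guidance_scale": 7.5},  # plain dict assumed; the pipeline stores it as a FrozenDict
    default_creation_method="from_config",
)
```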
+ +If your pipeline contains **pretrained components**, you typically need to use a repository to provide the loading specifications and metadata. + +`ModularPipeline` works specifically with modular repositories, which offer more flexibility in component loading compared to traditional repositories. You can find an example modular repo [here](https://huggingface.co/YiYiXu/modular-diffdiff). + +A `DiffusionPipeline` defines `model_index.json` to configure its components. However, repositories for Modular Diffusers work with `modular_model_index.json`. Let's walk through the differences here. + +In standard `model_index.json`, each component entry is a `(library, class)` tuple: +```py +"text_encoder": [ + "transformers", + "CLIPTextModel" +], +``` + +In `modular_model_index.json`, each component entry contains 3 elements: `(library, class, loading_specs_dict)` + +- `library` and `class`: Information about the actual component loaded in the pipeline at the time of saving (will be `null` if not loaded) +- `loading_specs_dict`: A dictionary containing all information required to load this component, including `repo`, `revision`, `subfolder`, `variant`, and `type_hint`. + +```py +"text_encoder": [ + null, # library of actual loaded component (same as in model_index.json) + null, # class of actual loaded componenet (same as in model_index.json) + { # loading specs map (unique to modular_model_index.json) + "repo": "stabilityai/stable-diffusion-xl-base-1.0", # can be a different repo + "revision": null, + "subfolder": "text_encoder", + "type_hint": [ # (library, class) for the expected component + "transformers", + "CLIPTextModel" + ], + "variant": null + } +], +``` + +Unlike standard repositories where components must be in subfolders within the same repo, modular repositories can fetch components from different repositories based on the `loading_specs_dict`. e.g. the `text_encoder` component will be fetched from the "text_encoder" folder in `stabilityai/stable-diffusion-xl-base-1.0` while other components come from different repositories. + + +## Creating a `ModularPipeline` from `ModularPipelineBlocks` + +Each `ModularPipelineBlocks` has an `init_pipeline` method that can initialize a `ModularPipeline` object based on its component and configuration specifications. + +Let's convert our `t2i_blocks` (which we created earlier) into a runnable `ModularPipeline`. We'll use a `ComponentsManager` to handle device placement, memory management, and component reuse automatically: + +```py +# We already have this from earlier +t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS) + +# Now convert it to a ModularPipeline +from diffusers import ComponentsManager +modular_repo_id = "YiYiXu/modular-loader-t2i-0704" +components = ComponentsManager() +t2i_pipeline = t2i_blocks.init_pipeline(modular_repo_id, components_manager=components) +``` + + + +💡 **ComponentsManager** is the model registry and management system in diffusers, it track all the models in one place and let you add, remove and reuse them across different workflows in most efficient way. Without it, you'd need to manually manage GPU memory, device placement, and component sharing between workflows. See the [Components Manager guide](components_manager.md) for detailed information. + + + +The `init_pipeline()` method creates a ModularPipeline and loads component specifications from the repository's `modular_model_index.json` file, but doesn't load the actual models yet. 
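You can confirm that nothing has been loaded at this point. A small sketch, using the `null_component_names` property described later in this guide (the exact list shown is an assumption):

```py
# Right after init_pipeline(): the loading specs are known, but no weights are in memory yet.
print(t2i_pipeline.null_component_names)
# e.g. ['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'scheduler', 'unet', 'vae']
```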
+ + +## Creating a `ModularPipeline` with `from_pretrained` + +You can create a `ModularPipeline` from a HuggingFace Hub repository with `from_pretrained` method, as long as it's a modular repo: + +```py +from diffusers import ModularPipeline, ComponentsManager +components = ComponentsManager() +pipeline = ModularPipeline.from_pretrained("YiYiXu/modular-loader-t2i-0704", components_manager=components) +``` + +Loading custom code is also supported: + +```py +from diffusers import ModularPipeline, ComponentsManager +components = ComponentsManager() +modular_repo_id = "YiYiXu/modular-diffdiff-0704" +diffdiff_pipeline = ModularPipeline.from_pretrained(modular_repo_id, trust_remote_code=True, components_manager=components) +``` + +This modular repository contains custom code. The folder contains these files: + +``` +modular-diffdiff-0704/ +├── block.py # Custom pipeline blocks implementation +├── config.json # Pipeline configuration and auto_map +└── modular_model_index.json # Component loading specifications +``` + +The [`config.json`](https://huggingface.co/YiYiXu/modular-diffdiff-0704/blob/main/config.json) file defines a custom `DiffDiffBlocks` class and points to its implementation: + +```json +{ + "_class_name": "DiffDiffBlocks", + "auto_map": { + "ModularPipelineBlocks": "block.DiffDiffBlocks" + } +} +``` + +The `auto_map` tells the pipeline where to find the custom blocks definition - in this case, it's looking for `DiffDiffBlocks` in the `block.py` file. The actual `DiffDiffBlocks` class is defined in [`block.py`](https://huggingface.co/YiYiXu/modular-diffdiff-0704/blob/main/block.py) within the repository. + +When `diffdiff_pipeline.blocks` is created, it's based on the `DiffDiffBlocks` definition from the custom code in the repository, allowing you to use specialized blocks that aren't part of the standard diffusers library. + +## Loading components into a `ModularPipeline` + +Unlike `DiffusionPipeline`, when you create a `ModularPipeline` instance (whether using `from_pretrained` or converting from pipeline blocks), its components aren't loaded automatically. You need to explicitly load model components using `load_default_components` or `load_components(names=..,)`: + +```py +# This will load ALL the expected components into pipeline +import torch +t2i_pipeline.load_default_components(torch_dtype=torch.float16) +t2i_pipeline.to("cuda") +``` + +All expected components are now loaded into the pipeline. You can also partially load specific components using the `names` argument. For example, to only load unet and vae: + +```py +>>> t2i_pipeline.load_components(names=["unet", "vae"], torch_dtype=torch.float16) +``` + +You can inspect the pipeline's loading status by simply printing the pipeline itself. It helps you understand what components are expected to load, which ones are already loaded, how they were loaded, and what loading specs are available. 
Let's print out the `t2i_pipeline`: + +```py +>>> t2i_pipeline +StableDiffusionXLModularPipeline { + "_blocks_class_name": "SequentialPipelineBlocks", + "_class_name": "StableDiffusionXLModularPipeline", + "_diffusers_version": "0.35.0.dev0", + "force_zeros_for_empty_prompt": true, + "scheduler": [ + null, + null, + { + "repo": "stabilityai/stable-diffusion-xl-base-1.0", + "revision": null, + "subfolder": "scheduler", + "type_hint": [ + "diffusers", + "EulerDiscreteScheduler" + ], + "variant": null + } + ], + "text_encoder": [ + null, + null, + { + "repo": "stabilityai/stable-diffusion-xl-base-1.0", + "revision": null, + "subfolder": "text_encoder", + "type_hint": [ + "transformers", + "CLIPTextModel" + ], + "variant": null + } + ], + "text_encoder_2": [ + null, + null, + { + "repo": "stabilityai/stable-diffusion-xl-base-1.0", + "revision": null, + "subfolder": "text_encoder_2", + "type_hint": [ + "transformers", + "CLIPTextModelWithProjection" + ], + "variant": null + } + ], + "tokenizer": [ + null, + null, + { + "repo": "stabilityai/stable-diffusion-xl-base-1.0", + "revision": null, + "subfolder": "tokenizer", + "type_hint": [ + "transformers", + "CLIPTokenizer" + ], + "variant": null + } + ], + "tokenizer_2": [ + null, + null, + { + "repo": "stabilityai/stable-diffusion-xl-base-1.0", + "revision": null, + "subfolder": "tokenizer_2", + "type_hint": [ + "transformers", + "CLIPTokenizer" + ], + "variant": null + } + ], + "unet": [ + "diffusers", + "UNet2DConditionModel", + { + "repo": "RunDiffusion/Juggernaut-XL-v9", + "revision": null, + "subfolder": "unet", + "type_hint": [ + "diffusers", + "UNet2DConditionModel" + ], + "variant": "fp16" + } + ], + "vae": [ + "diffusers", + "AutoencoderKL", + { + "repo": "madebyollin/sdxl-vae-fp16-fix", + "revision": null, + "subfolder": null, + "type_hint": [ + "diffusers", + "AutoencoderKL" + ], + "variant": null + } + ] +} +``` + +You can see all the **pretrained components** that will be loaded using `from_pretrained` method are listed as entries. Each entry contains 3 elements: `(library, class, loading_specs_dict)`: + +- **`library` and `class`**: Show the actual loaded component info. If `null`, the component is not loaded yet. +- **`loading_specs_dict`**: Contains all the information needed to load the component (repo, subfolder, variant, etc.) + +In this example: +- **Loaded components**: `vae` and `unet` (their `library` and `class` fields show the actual loaded models) +- **Not loaded yet**: `scheduler`, `text_encoder`, `text_encoder_2`, `tokenizer`, `tokenizer_2` (their `library` and `class` fields are `null`, but you can see their loading specs to know where they'll be loaded from when you call `load_components()`) + +You're looking at essentailly the pipeline's config dict that's synced with the `modular_model_index.json` from the repository you used during `init_pipeline()` - it takes the loading specs that match the pipeline's component requirements. + +For example, if your pipeline needs a `text_encoder` component, it will include the loading spec for `text_encoder` from the modular repo during the `init_pipeline`. If the pipeline doesn't need a component (like `controlnet` in a basic text-to-image pipeline), that component won't be included even if it exists in the modular repo. 
+ +There are also a few properties that can provide a quick summary of component loading status: + +```py +# All components expected by the pipeline +>>> t2i_pipeline.component_names +['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'guider', 'scheduler', 'unet', 'vae', 'image_processor'] + +# Components that are not loaded yet (will be loaded with from_pretrained) +>>> t2i_pipeline.null_component_names +['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'scheduler'] + +# Components that will be loaded from pretrained models +>>> t2i_pipeline.pretrained_component_names +['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'scheduler', 'unet', 'vae'] + +# Components that are created with default config (no repo needed) +>>> t2i_pipeline.config_component_names +['guider', 'image_processor'] +``` + +From config components (like `guider` and `image_processor`) are not included in the pipeline output above because they don't need loading specs - they're already initialized during pipeline creation. You can see this because they're not listed in `null_component_names`. + +## Modifying Loading Specs + +When you call `pipeline.load_components(names=)` or `pipeline.load_default_components()`, it uses the loading specs from the modular repository's `modular_model_index.json`. You can change where components are loaded from by modifying the `modular_model_index.json` in the repository. Just find the file on the Hub and click edit - you can change any field in the loading specs: `repo`, `subfolder`, `variant`, `revision`, etc. + +```py +# Original spec in modular_model_index.json +"unet": [ + null, null, + { + "repo": "stabilityai/stable-diffusion-xl-base-1.0", + "subfolder": "unet", + "variant": "fp16" + } +] + +# Modified spec - changed repo, subfolder, and variant +"unet": [ + null, null, + { + "repo": "RunDiffusion/Juggernaut-XL-v9", + "subfolder": "unet", + "variant": "fp16" + } +] +``` + +Now if you create a pipeline using the same blocks and updated repository, it will by default load from the new repository. + +```py +pipeline = ModularPipeline.from_pretrained("YiYiXu/modular-loader-t2i-0704", components_manager=components) +pipeline.load_components(names="unet") +``` + + +## Updating components in a `ModularPipeline` + +Similar to `DiffusionPipeline`, you can load components separately to replace the default ones in the pipeline. In Modular Diffusers, the approach depends on the component type: + +- **Pretrained components** (`default_creation_method='from_pretrained'`): Must use `ComponentSpec` to load them to update the existing one. +- **Config components** (`default_creation_method='from_config'`): These are components that don't need loading specs - they're created during pipeline initialization with default config. To update them, you can either pass the object directly or pass a ComponentSpec directly. + + + +💡 **Component Type Changes**: The component type (pretrained vs config-based) can change when you update components. These types are initially defined in pipeline blocks' `expected_components` field using `ComponentSpec` with `default_creation_method`. See the [Customizing Guidance Techniques](#customizing-guidance-techniques) section for examples of how this works in practice. + + + +`ComponentSpec` defines how to create or load components and can actually create them using its `create()` method (for ConfigMixin objects) or `load()` method (wrapper around `from_pretrained()`). 
When a component is loaded with a ComponentSpec, it gets tagged with a unique ID that encodes its creation parameters, allowing you to always extract the original specification using `ComponentSpec.from_component()`. + +Now let's look at how to update pretrained components in practice: + +So instead of + +```py +from diffusers import UNet2DConditionModel +import torch +unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", variant="fp16", torch_dtype=torch.float16) +``` +You should load your model like this + +```py +from diffusers import ComponentSpec, UNet2DConditionModel +unet_spec = ComponentSpec(name="unet",type_hint=UNet2DConditionModel, repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", variant="fp16") +unet2 = unet_spec.load(torch_dtype=torch.float16) +``` + +The key difference is that the second unet retains its loading specs, so you can extract the spec and recreate the unet: + +```py +# component -> spec +>>> spec = ComponentSpec.from_component("unet", unet2) +>>> spec +ComponentSpec(name='unet', type_hint=, description=None, config=None, repo='stabilityai/stable-diffusion-xl-base-1.0', subfolder='unet', variant='fp16', revision=None, default_creation_method='from_pretrained') +# spec -> component +>>> unet2_recreatd = spec.load(torch_dtype=torch.float16) +``` + +To replace the unet in the pipeline + +``` +t2i_pipeline.update_components(unet=unet2) +``` + +Not only is the `unet` component swapped, but its loading specs are also updated from "RunDiffusion/Juggernaut-XL-v9" to "stabilityai/stable-diffusion-xl-base-1.0" in pipeline config. This means that if you save the pipeline now and load it back with `from_pretrained`, the new pipeline will by default load the SDXL original unet. + +``` +>>> t2i_pipeline +StableDiffusionXLModularPipeline { + ... + "unet": [ + "diffusers", + "UNet2DConditionModel", + { + "repo": "stabilityai/stable-diffusion-xl-base-1.0", + "revision": null, + "subfolder": "unet", + "type_hint": [ + "diffusers", + "UNet2DConditionModel" + ], + "variant": "fp16" + } + ], + ... +} +``` + + +💡 **Modifying Component Specs**: You can get a copy of the current component spec from the pipeline using `get_component_spec()`. This makes it easy to modify the spec and updating components. + +```py +>>> unet_spec = t2i_pipeline.get_component_spec("unet") +>>> unet_spec +ComponentSpec( + name='unet', + type_hint=, + repo='RunDiffusion/Juggernaut-XL-v9', + subfolder='unet', + variant='fp16', + default_creation_method='from_pretrained' +) + +# Modify the spec to load from a different repository +>>> unet_spec.repo = "stabilityai/stable-diffusion-xl-base-1.0" + +# Load the component with the modified spec +>>> unet = unet_spec.load(torch_dtype=torch.float16) +``` + + + +## Customizing Guidance Techniques + +Guiders are implementations of different [classifier-free guidance](https://huggingface.co/papers/2207.12598) techniques that can be applied during the denoising process to improve generation quality, control, and adherence to prompts. They work by steering the model predictions towards desired directions and away from undesired directions. In diffusers, guiders are implemented as subclasses of `BaseGuidance`. They can easily be integrated into modular pipelines and provide a flexible way to enhance generation quality without modifying the underlying diffusion models. + +**ClassifierFreeGuidance (CFG)** is the first and most common guidance technique, used in all our standard pipelines. 
We also offer many other guidance techniques from the latest research in this area - **PerturbedAttentionGuidance (PAG)**, **SkipLayerGuidance (SLG)**, **SmoothedEnergyGuidance (SEG)**, and others that can provide better results for specific use cases. + +This section demonstrates how to use guiders using the component updating methods we just learned. Since `BaseGuidance` components are stateless (similar to schedulers), they are typically created with default configurations during pipeline initialization using `default_creation_method='from_config'`. This means they don't require loading specs from the repository - you won't see guider listed in `modular_model_index.json` files. + +Let's take a look at the default guider configuration: + +```py +>>> t2i_pipeline.get_component_spec("guider") +ComponentSpec(name='guider', type_hint=, description=None, config=FrozenDict([('guidance_scale', 7.5), ('guidance_rescale', 0.0), ('use_original_formulation', False), ('start', 0.0), ('stop', 1.0), ('_use_default_values', ['start', 'guidance_rescale', 'stop', 'use_original_formulation'])]), repo=None, subfolder=None, variant=None, revision=None, default_creation_method='from_config') +``` + +As you can see, the guider is configured to use `ClassifierFreeGuidance` with default parameters and `default_creation_method='from_config'`, meaning it's created during pipeline initialization rather than loaded from a repository. Let's verify this, here we run `init_pipeline()` without a modular repo, and there it is, a guider with the default configuration we just saw + + +```py +>>> pipeline = t2i_blocks.init_pipeline() +>>> pipeline.guider +ClassifierFreeGuidance { + "_class_name": "ClassifierFreeGuidance", + "_diffusers_version": "0.35.0.dev0", + "guidance_rescale": 0.0, + "guidance_scale": 7.5, + "start": 0.0, + "stop": 1.0, + "use_original_formulation": false +} +``` + +#### Modify Parameters of the Same Guider Type + +To change parameters of the same guider type (e.g., adjusting the `guidance_scale` for CFG), you have two options: + +**Option 1: Use ComponentSpec.create() method** + +You just need to pass the parameter with the new value to override the default one. + +```python +>>> guider_spec = t2i_pipeline.get_component_spec("guider") +>>> guider = guider_spec.create(guidance_scale=10) +>>> t2i_pipeline.update_components(guider=guider) +``` + +**Option 2: Pass ComponentSpec directly** + +Update the spec directly and pass it to `update_components()`. 
+ +```python +>>> guider_spec = t2i_pipeline.get_component_spec("guider") +>>> guider_spec.config["guidance_scale"] = 10 +>>> t2i_pipeline.update_components(guider=guider_spec) +``` + +Both approaches produce the same result: +```python +>>> t2i_pipeline.guider +ClassifierFreeGuidance { + "_class_name": "ClassifierFreeGuidance", + "_diffusers_version": "0.35.0.dev0", + "guidance_rescale": 0.0, + "guidance_scale": 10, + "start": 0.0, + "stop": 1.0, + "use_original_formulation": false +} +``` + +#### Switch to a Different Guider Type + +Switching between guidance techniques is as simple as passing a guider object of that technique: + +```py +from diffusers import LayerSkipConfig, PerturbedAttentionGuidance +config = LayerSkipConfig(indices=[2, 9], fqn="mid_block.attentions.0.transformer_blocks", skip_attention=False, skip_attention_scores=True, skip_ff=False) +guider = PerturbedAttentionGuidance( + guidance_scale=5.0, perturbed_guidance_scale=2.5, perturbed_guidance_config=config +) +t2i_pipeline.update_components(guider=guider) +``` + +Note that you will get a warning about changing the guider type, which is expected: + +``` +ModularPipeline.update_components: adding guider with new type: PerturbedAttentionGuidance, previous type: ClassifierFreeGuidance +``` + + + +- For `from_config` components (like guiders, schedulers): You can pass an object of required type OR pass a ComponentSpec directly (which calls `create()` under the hood) +- For `from_pretrained` components (like models): You must use ComponentSpec to ensure proper tagging and loading + + + +Let's verify that the guider has been updated: + +```py +>>> t2i_pipeline.guider +PerturbedAttentionGuidance { + "_class_name": "PerturbedAttentionGuidance", + "_diffusers_version": "0.35.0.dev0", + "guidance_rescale": 0.0, + "guidance_scale": 5.0, + "perturbed_guidance_config": { + "dropout": 1.0, + "fqn": "mid_block.attentions.0.transformer_blocks", + "indices": [ + 2, + 9 + ], + "skip_attention": false, + "skip_attention_scores": true, + "skip_ff": false + }, + "perturbed_guidance_layers": null, + "perturbed_guidance_scale": 2.5, + "perturbed_guidance_start": 0.01, + "perturbed_guidance_stop": 0.2, + "start": 0.0, + "stop": 1.0, + "use_original_formulation": false +} + +``` + +The component spec has also been updated to reflect the new guider type: + +```py +>>> t2i_pipeline.get_component_spec("guider") +ComponentSpec(name='guider', type_hint=, description=None, config=FrozenDict([('guidance_scale', 5.0), ('perturbed_guidance_scale', 2.5), ('perturbed_guidance_start', 0.01), ('perturbed_guidance_stop', 0.2), ('perturbed_guidance_layers', None), ('perturbed_guidance_config', LayerSkipConfig(indices=[2, 9], fqn='mid_block.attentions.0.transformer_blocks', skip_attention=False, skip_attention_scores=True, skip_ff=False, dropout=1.0)), ('guidance_rescale', 0.0), ('use_original_formulation', False), ('start', 0.0), ('stop', 1.0), ('_use_default_values', ['perturbed_guidance_start', 'use_original_formulation', 'perturbed_guidance_layers', 'stop', 'start', 'guidance_rescale', 'perturbed_guidance_stop']), ('_class_name', 'PerturbedAttentionGuidance'), ('_diffusers_version', '0.35.0.dev0')]), repo=None, subfolder=None, variant=None, revision=None, default_creation_method='from_config') +``` + +The "guider" is still a `from_config` component: is still not included in the pipeline config and will not be saved into the `modular_model_index.json`. 
+ +```py +>>> assert "guider" not in t2i_pipeline.config +``` + +However, you can change it to a `from_pretrained` component, which allows you to upload your customized guider to the Hub and load it into your pipeline. + +#### Loading Custom Guiders from Hub + +If you already have a guider saved on the Hub and a `modular_model_index.json` with the loading spec for that guider, it will automatically be changed to a `from_pretrained` component during pipeline initialization. + +For example, this `modular_model_index.json` includes loading specs for the guider: + +```json +{ + "guider": [ + null, + null, + { + "repo": "YiYiXu/modular-loader-t2i-guider", + "revision": null, + "subfolder": "pag_guider", + "type_hint": [ + "diffusers", + "PerturbedAttentionGuidance" + ], + "variant": null + } + ] +} +``` + +When you use this repository to create a pipeline with the same blocks (that originally configured guider as a `from_config` component), the guider becomes a `from_pretrained` component. This means it doesn't get created during initialization, and after you call `load_default_components()`, it loads based on the spec - resulting in the PAG guider instead of the default CFG. + +```py +t2i_pipeline = t2i_blocks.init_pipeline("YiYiXu/modular-doc-guider") +assert t2i_pipeline.guider is None # Not created during init +t2i_pipeline.load_default_components() +t2i_pipeline.guider # Now loaded as PAG guider +``` + +#### Upload Custom Guider to Hub for Easy Loading & Sharing + +Now let's see how we can share the guider on the Hub and change it to a `from_pretrained` component. + +```py +guider.push_to_hub("YiYiXu/modular-loader-t2i-guider", subfolder="pag_guider") +``` + +Voilà! Now you have a subfolder called `pag_guider` on that repository. + +You have a few options to make this guider available in your pipeline: + +1. **Directly modify the `modular_model_index.json`** to add a loading spec for the guider by pointing to a folder containing the desired guider config. + +2. **Use the `update_components` method** to change it to a `from_pretrained` component for your pipeline. This is easier if you just want to try it out with different repositories. + +Let's use the second approach and change our guider_spec to use `from_pretrained` as the default creation method and update the loading spec to use this subfolder we just created: + +```python +guider_spec = t2i_pipeline.get_component_spec("guider") +guider_spec.default_creation_method="from_pretrained" +guider_spec.repo="YiYiXu/modular-loader-t2i-guider" +guider_spec.subfolder="pag_guider" +pag_guider = guider_spec.load() +t2i_pipeline.update_components(guider=pag_guider) +``` + +You will get a warning about changing the creation method: + +``` +ModularPipeline.update_components: changing the default_creation_method of guider from from_config to from_pretrained. +``` + +Now not only the `guider` component and its component_spec are updated, but so is the pipeline config. + +If you want to change the default behavior for future pipelines, you can push the updated pipeline to the Hub. This way, when others use your repository, they'll get the PAG guider by default. However, this is optional - you don't have to do this if you just want to experiment locally. + +```py +t2i_pipeline.push_to_hub("YiYiXu/modular-doc-guider") +``` + + + + +Experiment with different techniques and parameters to find what works best for your specific use case! 
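When comparing guidance techniques or parameters, it helps to keep everything else fixed. Below is a minimal sketch, assuming the modular text-to-image pipeline accepts a `generator` input like the standard SDXL pipelines do (check `t2i_pipeline.doc` to confirm the exact input names):

```py
import torch
from diffusers import ClassifierFreeGuidance

# generate with the current (PAG) guider, seeding so only the guidance setup changes between runs
generator = torch.Generator("cuda").manual_seed(0)
image_pag = t2i_pipeline(prompt="a cat", generator=generator, output="images")[0]

# switch back to plain CFG and regenerate with the same seed for a side-by-side comparison
t2i_pipeline.update_components(guider=ClassifierFreeGuidance(guidance_scale=7.5))
generator = torch.Generator("cuda").manual_seed(0)
image_cfg = t2i_pipeline(prompt="a cat", generator=generator, output="images")[0]
```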
You can find all the guider classes we support [here](TODO: API doc).

Additionally, you can write your own guider implementations, for example, CFG Zero* combined with Skip Layer Guidance, and they should be compatible out-of-the-box with Modular Diffusers!

## Running a `ModularPipeline`

The API to run a `ModularPipeline` is very similar to how you would run a regular `DiffusionPipeline`:

```py
>>> image = pipeline(prompt="a cat", num_inference_steps=15, output="images")[0]
```

There are a few key differences though:
1. You can also pass a `PipelineState` object directly to the pipeline instead of individual arguments
2. If you do not specify the `output` argument, the pipeline returns the `PipelineState` object
3. You can pass a list as `output`, e.g. `pipeline(..., output=["images", "latents"])` will return a dictionary containing both the generated images and the final denoised latents

Under the hood, `ModularPipeline`'s `__call__` method is a wrapper around the pipeline blocks' `__call__` method: it creates a `PipelineState` object and populates it with user inputs, then returns the output to the user based on the `output` argument. It also ensures that all pipeline-level configs and components are exposed to all pipeline blocks by preparing and passing a `components` input.

You can inspect the docstring of a `ModularPipeline` to check what arguments the pipeline accepts and how to specify the `output` you want. It lists all available outputs (basically everything in the intermediate pipeline state) so you can choose from the list.

```py
t2i_pipeline.doc
```

**Important**: Always check the docstring because arguments can differ from the standard pipelines you're familiar with. For example, in Modular Diffusers we standardized the controlnet image input as `control_image`, but regular pipelines are inconsistent about the name, e.g. controlnet text-to-image uses `image` while SDXL controlnet img2img uses `control_image`.

**Note**: The `output` list might be longer than you expect - it includes everything in the intermediate state that you can choose to return. Most of the time, you'll just want `output="images"` or `output="latents"`.

#### Text-to-Image, Image-to-Image, and Inpainting

These are minimal inference examples for the basic tasks: text-to-image, image-to-image, and inpainting. The process to create the different pipelines is the same - the only difference is the block classes preset. The inference is also more or less the same as with standard pipelines, but please always check `.doc` for the correct input names and remember to pass `output="images"`.
+ + + + + +```py +import torch +from diffusers.modular_pipelines import SequentialPipelineBlocks +from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS + +# create pipeline from official blocks preset +blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS) + +modular_repo_id = "YiYiXu/modular-loader-t2i-0704" +pipeline = blocks.init_pipeline(modular_repo_id) + +pipeline.load_default_components(torch_dtype=torch.float16) +pipeline.to("cuda") + +# run pipeline, need to pass a "output=images" argument +image = pipeline(prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", output="images")[0] +image.save("modular_t2i_out.png") +``` + + + + +```py +import torch +from diffusers.modular_pipelines import SequentialPipelineBlocks +from diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS + +# create pipeline from blocks preset +blocks = SequentialPipelineBlocks.from_blocks_dict(IMAGE2IMAGE_BLOCKS) + +modular_repo_id = "YiYiXu/modular-loader-t2i-0704" +pipeline = blocks.init_pipeline(modular_repo_id) + +pipeline.load_default_components(torch_dtype=torch.float16) +pipeline.to("cuda") + +url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png" +init_image = load_image(url) +prompt = "a dog catching a frisbee in the jungle" +image = pipeline(prompt=prompt, image=init_image, strength=0.8, output="images")[0] +image.save("modular_i2i_out.png") +``` + + + + +```py +import torch +from diffusers.modular_pipelines import SequentialPipelineBlocks +from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS +from diffusers.utils import load_image + +# create pipeline from blocks preset +blocks = SequentialPipelineBlocks.from_blocks_dict(INPAINT_BLOCKS) + +modular_repo_id = "YiYiXu/modular-loader-t2i-0704" +pipeline = blocks.init_pipeline(modular_repo_id) + +pipeline.load_default_components(torch_dtype=torch.float16) +pipeline.to("cuda") + +img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png" +mask_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-inpaint-mask.png" + +init_image = load_image(img_url) +mask_image = load_image(mask_url) + +prompt = "A deep sea diver floating" +image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.85, output="images")[0] +image.save("moduar_inpaint_out.png") +``` + + + + +#### ControlNet + +For ControlNet, we provide one auto block you can place at the `denoise` step. Let's create it and inspect it to see what it tells us. + + + +💡 **How to explore new tasks**: When you want to figure out how to do a specific task in Modular Diffusers, it is a good idea to start by checking what block classes presets we offer in `ALL_BLOCKS`. Then create the block instance and inspect it - it will show you the required components, description, and sub-blocks. This is crucial for understanding what each block does and what it needs. 
+ + + +```py +>>> from diffusers.modular_pipelines.stable_diffusion_xl import ALL_BLOCKS +>>> ALL_BLOCKS["controlnet"] +InsertableDict([ + 0: ('denoise', ) +]) +>>> controlnet_blocks = ALL_BLOCKS["controlnet"]["denoise"]() +>>> controlnet_blocks +StableDiffusionXLAutoControlnetStep( + Class: SequentialPipelineBlocks + + ==================================================================================================== + This pipeline contains blocks that are selected at runtime based on inputs. + Trigger Inputs: {'mask', 'control_mode', 'control_image', 'controlnet_cond'} + Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('mask')`). + ==================================================================================================== + + + Description: Controlnet auto step that prepare the controlnet input and denoise the latents. It works for both controlnet and controlnet_union and supports text2img, img2img and inpainting tasks. (it should be replace at 'denoise' step) + + + Components: + controlnet (`ControlNetUnionModel`) + control_image_processor (`VaeImageProcessor`) + scheduler (`EulerDiscreteScheduler`) + unet (`UNet2DConditionModel`) + guider (`ClassifierFreeGuidance`) + + Sub-Blocks: + [0] controlnet_input (StableDiffusionXLAutoControlNetInputStep) + Description: Controlnet Input step that prepare the controlnet input. + This is an auto pipeline block that works for both controlnet and controlnet_union. + (it should be called right before the denoise step) - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided. + - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided. - if neither `control_mode` nor `control_image` is provided, step will be skipped. + + [1] controlnet_denoise (StableDiffusionXLAutoControlNetDenoiseStep) + Description: Denoise step that iteratively denoise the latents with controlnet. This is a auto pipeline block that using controlnet for text2img, img2img and inpainting tasks.This block should not be used without a controlnet_cond input - `StableDiffusionXLInpaintControlNetDenoiseStep` (inpaint_controlnet_denoise) is used when mask is provided. - `StableDiffusionXLControlNetDenoiseStep` (controlnet_denoise) is used when mask is not provided but controlnet_cond is provided. - If neither mask nor controlnet_cond are provided, step will be skipped. + +) +``` + + + +💡 **Auto Blocks**: This is first time we meet a Auto Blocks! `AutoPipelineBlocks` automatically adapt to your inputs by combining multiple workflows with conditional logic. This is why one convenient block can work for all tasks and controlnet types. See the [Auto Blocks Guide](./auto_pipeline_blocks.md) for more details. + + + +The block shows us it has two steps (prepare inputs + denoise) and supports all tasks with both controlnet and controlnet union. Most importantly, it tells us to place it at the 'denoise' step. 
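Since this is an auto block, you can also preview which sub-blocks would actually run for a given combination of inputs. The sketch below follows the hint in the printout above (`get_execution_blocks()` with an input name); the exact output format may vary:

```py
# preview the execution path when only a control image is passed
# (controlnet path without inpainting, since no mask is provided)
controlnet_blocks.get_execution_blocks("control_image")

# preview the execution path for inpainting, which is triggered by the mask input
controlnet_blocks.get_execution_blocks("mask")
```

With the execution paths confirmed, we can place the block at the 'denoise' step as instructed.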
Let's do exactly that: + +```py +import torch +from diffusers.modular_pipelines import SequentialPipelineBlocks +from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS, StableDiffusionXLAutoControlnetStep +from diffusers.utils import load_image + +# create pipeline from blocks preset +blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS) + +# these two lines applies controlnet +controlnet_blocks = StableDiffusionXLAutoControlnetStep() +blocks.sub_blocks["denoise"] = controlnet_blocks +``` + +Before we convert the blocks into a pipeline and load its components, let's inspect the blocks and its docs again to make sure it was assembled correctly. You should be able to see that `controlnet` and `control_image_processor` are now listed as `Components`, so we should initialize the pipeline with a repo that contains desired loading specs for these 2 components. + +```py +# make sure to a modular_repo including controlnet +modular_repo_id = "YiYiXu/modular-demo-auto" +pipeline = blocks.init_pipeline(modular_repo_id) +pipeline.load_default_components(torch_dtype=torch.float16) +pipeline.to("cuda") + +# generate +canny_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" +) +image = pipeline( + prompt="a bird", controlnet_conditioning_scale=0.5, control_image=canny_image, output="images" +)[0] +image.save("modular_control_out.png") +``` + +#### IP-Adapter + +**Challenge time!** Before we show you how to apply IP-adapter, try doing it yourself! Use the same process we just walked you through with ControlNet: check the official blocks preset, inspect the block instance and docstring `.doc`, and adapt a regular IP-adapter example to modular. + +Let's walk through the steps: + +1. Check blocks preset + +```py +>>> from diffusers.modular_pipelines.stable_diffusion_xl import ALL_BLOCKS +>>> ALL_BLOCKS["ip_adapter"] +InsertableDict([ + 0: ('ip_adapter', ) +]) +``` + +2. inspect the block & doc + +``` +>>> from diffusers.modular_pipelines.stable_diffusion_xl import StableDiffusionXLAutoIPAdapterStep +>>> ip_adapter_blocks = StableDiffusionXLAutoIPAdapterStep() +>>> ip_adapter_blocks +StableDiffusionXLAutoIPAdapterStep( + Class: AutoPipelineBlocks + + ==================================================================================================== + This pipeline contains blocks that are selected at runtime based on inputs. + Trigger Inputs: {'ip_adapter_image'} + Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('ip_adapter_image')`). + ==================================================================================================== + + + Description: Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step. + + + + Components: + image_encoder (`CLIPVisionModelWithProjection`) + feature_extractor (`CLIPImageProcessor`) + unet (`UNet2DConditionModel`) + guider (`ClassifierFreeGuidance`) + + Sub-Blocks: + • ip_adapter [trigger: ip_adapter_image] (StableDiffusionXLIPAdapterStep) + Description: IP Adapter step that prepares ip adapter image embeddings. + Note that this step only prepares the embeddings - in order for it to work correctly, you need to load ip adapter weights into unet via ModularPipeline.load_ip_adapter() and pipeline.set_ip_adapter_scale(). 
+ See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin) for more details + +) +``` +3. follow the instruction to build + +```py +import torch +from diffusers.modular_pipelines import SequentialPipelineBlocks +from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS + +# create pipeline from official blocks preset +blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS) + +# insert ip_adapter_blocks before the input step as instructed +blocks.sub_blocks.insert("ip_adapter", ip_adapter_blocks, 1) + +# inspec the blocks before you convert it into pipelines, +# and make sure to use a repo that contains the loading spec for all components +# for ip-adapter, you need image_encoder & feature_extractor +modular_repo_id = "YiYiXu/modular-demo-auto" +pipeline = blocks.init_pipeline(modular_repo_id) + +pipeline.load_default_components(torch_dtype=torch.float16) +pipeline.load_ip_adapter( + "h94/IP-Adapter", + subfolder="sdxl_models", + weight_name="ip-adapter_sdxl.bin" +) +pipeline.set_ip_adapter_scale(0.8) +pipeline.to("cuda") +``` + +4. adapt an example to modular + +We are using [this one](https://huggingface.co/docs/diffusers/using-diffusers/ip_adapter?ipadapter-variants=IP-Adapter+Plus#ip-adapter) from our IP-Adapter doc! + + +```py +from diffusers.utils import load_image +image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner.png") +image = pipeline( + prompt="a polar bear sitting in a chair drinking a milkshake", + ip_adapter_image=image, + negative_prompt="deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality", + output="images" +)[0] +image.save("modular_ipa_out.png") +``` + + diff --git a/docs/source/en/modular_diffusers/overview.md b/docs/source/en/modular_diffusers/overview.md new file mode 100644 index 000000000000..9702cea0633d --- /dev/null +++ b/docs/source/en/modular_diffusers/overview.md @@ -0,0 +1,42 @@ + + +# Getting Started with Modular Diffusers + + + +🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes. + + + +With Modular Diffusers, we introduce a unified pipeline system that simplifies how you work with diffusion models. Instead of creating separate pipelines for each task, Modular Diffusers lets you: + +**Write Only What's New**: You won't need to write an entire pipeline from scratch every time you have a new use case. You can create pipeline blocks just for your new workflow's unique aspects and reuse existing blocks for existing functionalities. + +**Assemble Like LEGO®**: You can mix and match between blocks in flexible ways. This allows you to write dedicated blocks unique to specific workflows, and then assemble different blocks into a pipeline that can be used more conveniently for multiple workflows. 
+ + +Here's how our guides are organized to help you navigate the Modular Diffusers documentation: + +### 🚀 Running Pipelines +- **[Modular Pipeline Guide](./modular_pipeline.md)** - How to use predefined blocks to build a pipeline and run it +- **[Components Manager Guide](./components_manager.md)** - How to manage and reuse components across multiple pipelines + +### 📚 Creating PipelineBlocks +- **[Pipeline and Block States](./modular_diffusers_states.md)** - Understanding PipelineState and BlockState +- **[Pipeline Block](./pipeline_block.md)** - How to write custom PipelineBlocks +- **[SequentialPipelineBlocks](sequential_pipeline_blocks.md)** - Connecting blocks in sequence +- **[LoopSequentialPipelineBlocks](./loop_sequential_pipeline_blocks.md)** - Creating iterative workflows +- **[AutoPipelineBlocks](./auto_pipeline_blocks.md)** - Conditional block selection + +### 🎯 Practical Examples +- **[End-to-End Example](./end_to_end_guide.md)** - Complete end-to-end examples including sharing your workflow in huggingface hub and deplying UI nodes diff --git a/docs/source/en/modular_diffusers/pipeline_block.md b/docs/source/en/modular_diffusers/pipeline_block.md new file mode 100644 index 000000000000..17a819732fd0 --- /dev/null +++ b/docs/source/en/modular_diffusers/pipeline_block.md @@ -0,0 +1,292 @@ + + +# PipelineBlock + + + +🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes. + + + +In Modular Diffusers, you build your workflow using `ModularPipelineBlocks`. We support 4 different types of blocks: `PipelineBlock`, `SequentialPipelineBlocks`, `LoopSequentialPipelineBlocks`, and `AutoPipelineBlocks`. Among them, `PipelineBlock` is the most fundamental building block of the whole system - it's like a brick in a Lego system. These blocks are designed to easily connect with each other, allowing for modular construction of creative and potentially very complex workflows. + + + +**Important**: `PipelineBlock`s are definitions/specifications, not runnable pipelines. They define what a block should do and what data it needs, but you need to convert them into a `ModularPipeline` to actually execute them. For information on creating and running pipelines, see the [Modular Pipeline guide](./modular_pipeline.md). + + + +In this tutorial, we will focus on how to write a basic `PipelineBlock` and how it interacts with the pipeline state. + +## PipelineState + +Before we dive into creating `PipelineBlock`s, make sure you have a basic understanding of `PipelineState`. It acts as the global state container that all blocks operate on - each block gets a local view (`BlockState`) of the relevant variables it needs from `PipelineState`, performs its operations, and then updates `PipelineState` with any changes. See the [PipelineState and BlockState guide](./modular_diffusers_states.md) for more details. + +## Define a `PipelineBlock` + +To write a `PipelineBlock` class, you need to define a few properties that determine how your block interacts with the pipeline state. Understanding these properties is crucial - they define what data your block can access and what it can produce. 
The three main properties you need to define are:
- `inputs`: Immutable values from the user that cannot be modified
- `intermediate_inputs`: Mutable values from previous blocks (or the user) that can be read and modified
- `intermediate_outputs`: New values your block creates for subsequent blocks and user access

Let's explore each one and understand how they work with the pipeline state.

**Inputs: Immutable User Values**

Inputs are variables your block needs from the immutable pipeline state - these are user-provided values that cannot be modified by any block. You define them using `InputParam`:

```py
user_inputs = [
    InputParam(name="image", type_hint="PIL.Image", description="raw input image to process")
]
```

When you list something as an input, you're saying "I need this value directly from the end user, and I will talk to them directly, telling them what I need in the 'description' field. They will provide it and it will come to me unchanged."

This is especially useful for raw values that serve as the "source of truth" in your workflow. Take a raw image, for example: many workflows require preprocessing steps like resizing, which a previous block might have performed, but in many cases you also want the untouched PIL image. In some inpainting workflows, you need the original image to overlay with the generated result for better control and consistency.

**Intermediate Inputs: Mutable Values from Previous Blocks or Users**

Intermediate inputs are variables your block needs from the mutable pipeline state - these are values that can be read and modified. They're typically created by previous blocks, but they can also be provided directly by the user when no earlier block produces them:

```py
user_intermediate_inputs = [
    InputParam(name="processed_image", type_hint="torch.Tensor", description="image that has been preprocessed and normalized"),
]
```

When you list something as an intermediate input, you're saying "I need this value, but I want to work with a different block that has already created it. I already know for sure that I can get it from this other block, but it's okay if other developers want to use something different."

**Intermediate Outputs: New Values for Subsequent Blocks and User Access**

Intermediate outputs are new variables your block creates and adds to the mutable pipeline state. They serve two purposes:

1. **For subsequent blocks**: They can be used as intermediate inputs by other blocks in the pipeline
2. **For users**: They become available as final outputs that users can access when running the pipeline

```py
user_intermediate_outputs = [
    OutputParam(name="image_latents", description="latents representing the image")
]
```

Intermediate inputs and intermediate outputs work together like Lego studs and anti-studs - they're the connection points that make blocks modular. When one block produces an intermediate output, it becomes available as an intermediate input for subsequent blocks. This is where the "modular" nature of the system really shines - blocks can be connected and reconnected in different ways as long as their inputs and outputs match.

Additionally, all intermediate outputs are accessible to users when they run the pipeline. Typically you only need the final images, but users can also access intermediate results like latents, embeddings, or other processing steps.
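Putting the three declarations together, this is roughly what a block definition looks like when you subclass `PipelineBlock` directly (the class name and values here are illustrative; the `__call__` method is covered next):

```py
from diffusers.modular_pipelines import PipelineBlock, InputParam, OutputParam

class ImageEncoderBlock(PipelineBlock):
    model_name = "test"

    @property
    def description(self):
        return "Encode raw image into its latent representation"

    @property
    def inputs(self):
        # immutable, user-provided values
        return [InputParam(name="image", type_hint="PIL.Image", description="raw input image to process")]

    @property
    def intermediate_inputs(self):
        # mutable values, typically produced by a previous block
        return [InputParam(name="batch_size", type_hint=int)]

    @property
    def intermediate_outputs(self):
        # new values this block adds to the pipeline state
        return [OutputParam(name="image_latents", description="latents representing the image")]

    def __call__(self, components, state):
        block_state = self.get_block_state(state)
        # computation goes here (see the next section)
        self.set_block_state(state, block_state)
        return components, state
```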
**The `__call__` Method Structure**

Your `PipelineBlock`'s `__call__` method should follow this structure:

```py
def __call__(self, components, state):
    # Get a local view of the state variables this block needs
    block_state = self.get_block_state(state)

    # Your computation logic here
    # block_state contains all your inputs and intermediate_inputs
    # You can access them like: block_state.image, block_state.processed_image

    # Update the pipeline state with your updated block_state
    self.set_block_state(state, block_state)
    return components, state
```

The `block_state` object contains all the variables you defined in `inputs` and `intermediate_inputs`, making them easily accessible for your computation.

**Components and Configs**

You can define the components and pipeline-level configs your block needs using `ComponentSpec` and `ConfigSpec`:

```py
from diffusers import ComponentSpec, ConfigSpec, UNet2DConditionModel, EulerDiscreteScheduler

# Define components your block needs
expected_components = [
    ComponentSpec(name="unet", type_hint=UNet2DConditionModel),
    ComponentSpec(name="scheduler", type_hint=EulerDiscreteScheduler)
]

# Define pipeline-level configs
expected_config = [
    ConfigSpec("force_zeros_for_empty_prompt", True)
]
```

**Components**: In the `ComponentSpec`, you must provide a `name` and ideally a `type_hint`. You can also specify a `default_creation_method` to indicate whether the component should be loaded from a pretrained model or created with a default configuration. The actual loading details (the `repo`, `subfolder`, `variant` and `revision` fields) are typically specified when creating the pipeline, as we covered in the [Modular Pipeline Guide](./modular_pipeline.md).

**Configs**: Pipeline-level settings that control behavior across all blocks.

When you convert your blocks into a pipeline using `blocks.init_pipeline()`, the pipeline collects all component requirements from the blocks and fetches the loading specs from the modular repository. The components are then made available to your block as the first argument of the `__call__` method. You can access any component you need using dot notation:

```py
def __call__(self, components, state):
    # Access components using dot notation
    unet = components.unet
    vae = components.vae
    scheduler = components.scheduler
```

That's all you need to define in order to create a `PipelineBlock` - there is no hidden complexity. In fact, we are going to create a helper function that takes exactly these variables as input and returns a pipeline block. We will use this helper function throughout the tutorial to create test blocks.

Note that for the `__call__` method, the only part you should implement differently is the part between `self.get_block_state()` and `self.set_block_state()`, which can be abstracted into a simple function that takes `block_state` and returns the updated state. Our helper function accepts a `block_fn` that does exactly that.
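For example, such a `block_fn` is just a plain function over the block state (this one is illustrative; it fakes the encoding step with a random tensor):

```py
import torch

def encode_image_fn(block_state, pipeline_state):
    # read declared inputs/intermediate inputs from block_state and write new values back to it
    block_state.image_latents = torch.randn(1, 4, 64, 64)  # placeholder for a real VAE encode
    return block_state
```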
+ +**Helper Function** + +```py +from diffusers.modular_pipelines import PipelineBlock, InputParam, OutputParam +import torch + +def make_block(inputs=[], intermediate_inputs=[], intermediate_outputs=[], block_fn=None, description=None): + class TestBlock(PipelineBlock): + model_name = "test" + + @property + def inputs(self): + return inputs + + @property + def intermediate_inputs(self): + return intermediate_inputs + + @property + def intermediate_outputs(self): + return intermediate_outputs + + @property + def description(self): + return description if description is not None else "" + + def __call__(self, components, state): + block_state = self.get_block_state(state) + if block_fn is not None: + block_state = block_fn(block_state, state) + self.set_block_state(state, block_state) + return components, state + + return TestBlock +``` + +## Example: Creating a Simple Pipeline Block + +Let's create a simple block to see how these definitions interact with the pipeline state. To better understand what's happening, we'll print out the states before and after updates to inspect them: + +```py +inputs = [ + InputParam(name="image", type_hint="PIL.Image", description="raw input image to process") +] + +intermediate_inputs = [InputParam(name="batch_size", type_hint=int)] + +intermediate_outputs = [ + OutputParam(name="image_latents", description="latents representing the image") +] + +def image_encoder_block_fn(block_state, pipeline_state): + print(f"pipeline_state (before update): {pipeline_state}") + print(f"block_state (before update): {block_state}") + + # Simulate processing the image + block_state.image = torch.randn(1, 3, 512, 512) + block_state.batch_size = block_state.batch_size * 2 + block_state.processed_image = [torch.randn(1, 3, 512, 512)] * block_state.batch_size + block_state.image_latents = torch.randn(1, 4, 64, 64) + + print(f"block_state (after update): {block_state}") + return block_state + +# Create a block with our definitions +image_encoder_block_cls = make_block( + inputs=inputs, + intermediate_inputs=intermediate_inputs, + intermediate_outputs=intermediate_outputs, + block_fn=image_encoder_block_fn, + description="Encode raw image into its latent presentation" +) +image_encoder_block = image_encoder_block_cls() +pipe = image_encoder_block.init_pipeline() +``` + +Let's check the pipeline's docstring to see what inputs it expects: +```py +>>> print(pipe.doc) +class TestBlock + + Encode raw image into its latent presentation + + Inputs: + + image (`PIL.Image`, *optional*): + raw input image to process + + batch_size (`int`, *optional*): + + Outputs: + + image_latents (`None`): + latents representing the image +``` + +Notice that `batch_size` appears as an input even though we defined it as an intermediate input. This happens because no previous block provided it, so the pipeline makes it available as a user input. However, unlike regular inputs, this value goes directly into the mutable intermediate state. 
+ +Now let's run the pipeline: + +```py +from diffusers.utils import load_image + +image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/image_of_squirrel_painting.png") +state = pipe(image=image, batch_size=2) +print(f"pipeline_state (after update): {state}") +``` +```out +pipeline_state (before update): PipelineState( + inputs={ + image: + }, + intermediates={ + batch_size: 2 + }, +) +block_state (before update): BlockState( + image: + batch_size: 2 +) + +block_state (after update): BlockState( + image: Tensor(dtype=torch.float32, shape=torch.Size([1, 3, 512, 512])) + batch_size: 4 + processed_image: List[4] of Tensors with shapes [torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512])] + image_latents: Tensor(dtype=torch.float32, shape=torch.Size([1, 4, 64, 64])) +) +pipeline_state (after update): PipelineState( + inputs={ + image: + }, + intermediates={ + batch_size: 4 + image_latents: Tensor(dtype=torch.float32, shape=torch.Size([1, 4, 64, 64])) + }, +) +``` + +**Key Observations:** + +1. **Before the update**: `image` (the input) goes to the immutable inputs dict, while `batch_size` (the intermediate_input) goes to the mutable intermediates dict, and both are available in `block_state`. + +2. **After the update**: + - **`image` (inputs)** changed in `block_state` but not in `pipeline_state` - this change is local to the block only. + - **`batch_size (intermediate_inputs)`** was updated in both `block_state` and `pipeline_state` - this change affects subsequent blocks (we didn't need to declare it as an intermediate output since it was already in the intermediates dict) + - **`image_latents (intermediate_outputs)`** was added to `pipeline_state` because it was declared as an intermediate output + - **`processed_image`** was not added to `pipeline_state` because it wasn't declared as an intermediate output \ No newline at end of file diff --git a/docs/source/en/modular_diffusers/sequential_pipeline_blocks.md b/docs/source/en/modular_diffusers/sequential_pipeline_blocks.md new file mode 100644 index 000000000000..a683f0d0659a --- /dev/null +++ b/docs/source/en/modular_diffusers/sequential_pipeline_blocks.md @@ -0,0 +1,189 @@ + + +# SequentialPipelineBlocks + + + +🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes. + + + +`SequentialPipelineBlocks` is a subclass of `ModularPipelineBlocks`. Unlike `PipelineBlock`, it is a multi-block that composes other blocks together in sequence, creating modular workflows where data flows from one block to the next. It's one of the most common ways to build complex pipelines by combining simpler building blocks. + + + +Other types of multi-blocks include [AutoPipelineBlocks](auto_pipeline_blocks.md) (for conditional block selection) and [LoopSequentialPipelineBlocks](loop_sequential_pipeline_blocks.md) (for iterative workflows). For information on creating individual blocks, see the [PipelineBlock guide](pipeline_block.md). + +Additionally, like all `ModularPipelineBlocks`, `SequentialPipelineBlocks` are definitions/specifications, not runnable pipelines. You need to convert them into a `ModularPipeline` to actually execute them. For information on creating and running pipelines, see the [Modular Pipeline guide](modular_pipeline.md). 
+ + + +In this tutorial, we will focus on how to create `SequentialPipelineBlocks` and how blocks connect and work together. + +The key insight is that blocks connect through their intermediate inputs and outputs - the "studs and anti-studs" we discussed in the [PipelineBlock guide](pipeline_block.md). When one block produces an intermediate output, it becomes available as an intermediate input for subsequent blocks. + +Let's explore this through an example. We will use the same helper function from the PipelineBlock guide to create blocks. + +```py +from diffusers.modular_pipelines import PipelineBlock, InputParam, OutputParam +import torch + +def make_block(inputs=[], intermediate_inputs=[], intermediate_outputs=[], block_fn=None, description=None): + class TestBlock(PipelineBlock): + model_name = "test" + + @property + def inputs(self): + return inputs + + @property + def intermediate_inputs(self): + return intermediate_inputs + + @property + def intermediate_outputs(self): + return intermediate_outputs + + @property + def description(self): + return description if description is not None else "" + + def __call__(self, components, state): + block_state = self.get_block_state(state) + if block_fn is not None: + block_state = block_fn(block_state, state) + self.set_block_state(state, block_state) + return components, state + + return TestBlock +``` + +Let's create a block that produces `batch_size`, which we'll call "input_block": + +```py +def input_block_fn(block_state, pipeline_state): + + batch_size = len(block_state.prompt) + block_state.batch_size = batch_size * block_state.num_images_per_prompt + + return block_state + +input_block_cls = make_block( + inputs=[ + InputParam(name="prompt", type_hint=list, description="list of text prompts"), + InputParam(name="num_images_per_prompt", type_hint=int, description="number of images per prompt") + ], + intermediate_outputs=[ + OutputParam(name="batch_size", description="calculated batch size") + ], + block_fn=input_block_fn, + description="A block that determines batch_size based on the number of prompts and num_images_per_prompt argument." 
+) +input_block = input_block_cls() +``` + +Now let's create a second block that uses the `batch_size` from the first block: + +```py +def image_encoder_block_fn(block_state, pipeline_state): + # Simulate processing the image + block_state.image = torch.randn(1, 3, 512, 512) + block_state.batch_size = block_state.batch_size * 2 + block_state.image_latents = torch.randn(1, 4, 64, 64) + return block_state + +image_encoder_block_cls = make_block( + inputs=[ + InputParam(name="image", type_hint="PIL.Image", description="raw input image to process") + ], + intermediate_inputs=[ + InputParam(name="batch_size", type_hint=int) + ], + intermediate_outputs=[ + OutputParam(name="image_latents", description="latents representing the image") + ], + block_fn=image_encoder_block_fn, + description="Encode raw image into its latent presentation" +) +image_encoder_block = image_encoder_block_cls() +``` + +Now let's connect these blocks to create a `SequentialPipelineBlocks`: + +```py +from diffusers.modular_pipelines import SequentialPipelineBlocks, InsertableDict + +# Define a dict mapping block names to block instances +blocks_dict = InsertableDict() +blocks_dict["input"] = input_block +blocks_dict["image_encoder"] = image_encoder_block + +# Create the SequentialPipelineBlocks +blocks = SequentialPipelineBlocks.from_blocks_dict(blocks_dict) +``` + +Now you have a `SequentialPipelineBlocks` with 2 blocks: + +```py +>>> blocks +SequentialPipelineBlocks( + Class: ModularPipelineBlocks + + Description: + + + Sub-Blocks: + [0] input (TestBlock) + Description: A block that determines batch_size based on the number of prompts and num_images_per_prompt argument. + + [1] image_encoder (TestBlock) + Description: Encode raw image into its latent presentation + +) +``` + +When you inspect `blocks.doc`, you can see that `batch_size` is not listed as an input. The pipeline automatically detects that the `input_block` can produce `batch_size` for the `image_encoder_block`, so it doesn't ask the user to provide it. + +```py +>>> print(blocks.doc) +class SequentialPipelineBlocks + + Inputs: + + prompt (`None`, *optional*): + + num_images_per_prompt (`None`, *optional*): + + image (`PIL.Image`, *optional*): + raw input image to process + + Outputs: + + batch_size (`None`): + + image_latents (`None`): + latents representing the image +``` + +At runtime, you have data flow like this: + +![Data Flow Diagram](https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/modular_quicktour/Editor%20_%20Mermaid%20Chart-2025-06-30-092631.png) + +**How SequentialPipelineBlocks Works:** + +1. Blocks are executed in the order they're registered in the `blocks_dict` +2. Outputs from one block become available as intermediate inputs to all subsequent blocks +3. The pipeline automatically figures out which values need to be provided by the user and which will be generated by previous blocks +4. Each block maintains its own behavior and operates through its defined interface, while collectively these interfaces determine what the entire pipeline accepts and produces + +What happens within each block follows the same pattern we described earlier: each block gets its own `block_state` with the relevant inputs and intermediate inputs, performs its computation, and updates the pipeline state with its intermediate outputs. 
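To see this end to end, you can convert the assembled blocks into a pipeline and run it, just like in the `PipelineBlock` example (this assumes the two test blocks defined above; no pretrained components are needed for these dummy blocks):

```py
from diffusers.utils import load_image

# convert the definition into a runnable pipeline
pipe = blocks.init_pipeline()

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/image_of_squirrel_painting.png")
state = pipe(prompt=["a cat"], num_images_per_prompt=2, image=image)
print(state)
# batch_size is first computed by the "input" block (len(prompt) * num_images_per_prompt = 2),
# then doubled by the "image_encoder" block (4), and image_latents is added as an intermediate output
```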
\ No newline at end of file diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 713472b4a517..ab80ddffec50 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -34,9 +34,11 @@ _import_structure = { "configuration_utils": ["ConfigMixin"], + "guiders": [], "hooks": [], "loaders": ["FromOriginalModelMixin"], "models": [], + "modular_pipelines": [], "pipelines": [], "quantizers.quantization_config": [], "schedulers": [], @@ -130,14 +132,29 @@ _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] else: + _import_structure["guiders"].extend( + [ + "AdaptiveProjectedGuidance", + "AutoGuidance", + "ClassifierFreeGuidance", + "ClassifierFreeZeroStarGuidance", + "PerturbedAttentionGuidance", + "SkipLayerGuidance", + "SmoothedEnergyGuidance", + "TangentialClassifierFreeGuidance", + ] + ) _import_structure["hooks"].extend( [ "FasterCacheConfig", "FirstBlockCacheConfig", "HookRegistry", + "LayerSkipConfig", "PyramidAttentionBroadcastConfig", + "SmoothedEnergyGuidanceConfig", "apply_faster_cache", "apply_first_block_cache", + "apply_layer_skip", "apply_pyramid_attention_broadcast", ] ) @@ -221,6 +238,14 @@ "WanVACETransformer3DModel", ] ) + _import_structure["modular_pipelines"].extend( + [ + "ComponentsManager", + "ComponentSpec", + "ModularPipeline", + "ModularPipelineBlocks", + ] + ) _import_structure["optimization"] = [ "get_constant_schedule", "get_constant_schedule_with_warmup", @@ -333,6 +358,12 @@ ] else: + _import_structure["modular_pipelines"].extend( + [ + "StableDiffusionXLAutoBlocks", + "StableDiffusionXLModularPipeline", + ] + ) _import_structure["pipelines"].extend( [ "AllegroPipeline", @@ -545,6 +576,7 @@ ] ) + try: if not (is_torch_available() and is_transformers_available() and is_opencv_available()): raise OptionalDependencyNotAvailable() @@ -751,13 +783,26 @@ except OptionalDependencyNotAvailable: from .utils.dummy_pt_objects import * # noqa F403 else: + from .guiders import ( + AdaptiveProjectedGuidance, + AutoGuidance, + ClassifierFreeGuidance, + ClassifierFreeZeroStarGuidance, + PerturbedAttentionGuidance, + SkipLayerGuidance, + SmoothedEnergyGuidance, + TangentialClassifierFreeGuidance, + ) from .hooks import ( FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, + LayerSkipConfig, PyramidAttentionBroadcastConfig, + SmoothedEnergyGuidanceConfig, apply_faster_cache, apply_first_block_cache, + apply_layer_skip, apply_pyramid_attention_broadcast, ) from .models import ( @@ -837,6 +882,12 @@ WanTransformer3DModel, WanVACETransformer3DModel, ) + from .modular_pipelines import ( + ComponentsManager, + ComponentSpec, + ModularPipeline, + ModularPipelineBlocks, + ) from .optimization import ( get_constant_schedule, get_constant_schedule_with_warmup, @@ -933,6 +984,10 @@ except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_transformers_objects import * # noqa F403 else: + from .modular_pipelines import ( + StableDiffusionXLAutoBlocks, + StableDiffusionXLModularPipeline, + ) from .pipelines import ( AllegroPipeline, AltDiffusionImg2ImgPipeline, diff --git a/src/diffusers/commands/custom_blocks.py b/src/diffusers/commands/custom_blocks.py new file mode 100644 index 000000000000..43d9ea88577a --- /dev/null +++ b/src/diffusers/commands/custom_blocks.py @@ -0,0 +1,134 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage example: + TODO +""" + +import ast +import importlib.util +import os +from argparse import ArgumentParser, Namespace +from pathlib import Path + +from ..utils import logging +from . import BaseDiffusersCLICommand + + +EXPECTED_PARENT_CLASSES = ["ModularPipelineBlocks"] +CONFIG = "config.json" + + +def conversion_command_factory(args: Namespace): + return CustomBlocksCommand(args.block_module_name, args.block_class_name) + + +class CustomBlocksCommand(BaseDiffusersCLICommand): + @staticmethod + def register_subcommand(parser: ArgumentParser): + conversion_parser = parser.add_parser("custom_blocks") + conversion_parser.add_argument( + "--block_module_name", + type=str, + default="block.py", + help="Module filename in which the custom block will be implemented.", + ) + conversion_parser.add_argument( + "--block_class_name", + type=str, + default=None, + help="Name of the custom block. If provided None, we will try to infer it.", + ) + conversion_parser.set_defaults(func=conversion_command_factory) + + def __init__(self, block_module_name: str = "block.py", block_class_name: str = None): + self.logger = logging.get_logger("diffusers-cli/custom_blocks") + self.block_module_name = Path(block_module_name) + self.block_class_name = block_class_name + + def run(self): + # determine the block to be saved. + out = self._get_class_names(self.block_module_name) + classes_found = list({cls for cls, _ in out}) + + if self.block_class_name is not None: + child_class, parent_class = self._choose_block(out, self.block_class_name) + if child_class is None and parent_class is None: + raise ValueError( + "`block_class_name` could not be retrieved. Available classes from " + f"{self.block_module_name}:\n{classes_found}" + ) + else: + self.logger.info( + f"Found classes: {classes_found} will be using {classes_found[0]}. " + "If this needs to be changed, re-run the command specifying `block_class_name`." + ) + child_class, parent_class = out[0][0], out[0][1] + + # dynamically get the custom block and initialize it to call `save_pretrained` in the current directory. + # the user is responsible for running it, so I guess that is safe? + module_name = f"__dynamic__{self.block_module_name.stem}" + spec = importlib.util.spec_from_file_location(module_name, str(self.block_module_name)) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + getattr(module, child_class)().save_pretrained(os.getcwd()) + + # or, we could create it manually. 
+ # automap = self._create_automap(parent_class=parent_class, child_class=child_class) + # with open(CONFIG, "w") as f: + # json.dump(automap, f) + with open("requirements.txt", "w") as f: + f.write("") + + def _choose_block(self, candidates, chosen=None): + for cls, base in candidates: + if cls == chosen: + return cls, base + return None, None + + def _get_class_names(self, file_path): + source = file_path.read_text(encoding="utf-8") + try: + tree = ast.parse(source, filename=file_path) + except SyntaxError as e: + raise ValueError(f"Could not parse {file_path!r}: {e}") from e + + results: list[tuple[str, str]] = [] + for node in tree.body: + if not isinstance(node, ast.ClassDef): + continue + + # extract all base names for this class + base_names = [bname for b in node.bases if (bname := self._get_base_name(b)) is not None] + + # for each allowed base that appears in the class's bases, emit a tuple + for allowed in EXPECTED_PARENT_CLASSES: + if allowed in base_names: + results.append((node.name, allowed)) + + return results + + def _get_base_name(self, node: ast.expr): + if isinstance(node, ast.Name): + return node.id + elif isinstance(node, ast.Attribute): + val = self._get_base_name(node.value) + return f"{val}.{node.attr}" if val else node.attr + return None + + def _create_automap(self, parent_class, child_class): + module = str(self.block_module_name).replace(".py", "").rsplit(".", 1)[-1] + auto_map = {f"{parent_class}": f"{module}.{child_class}"} + return {"auto_map": auto_map} diff --git a/src/diffusers/commands/diffusers_cli.py b/src/diffusers/commands/diffusers_cli.py index 3c744c5c4c95..a27ac24f2a3e 100644 --- a/src/diffusers/commands/diffusers_cli.py +++ b/src/diffusers/commands/diffusers_cli.py @@ -15,6 +15,7 @@ from argparse import ArgumentParser +from .custom_blocks import CustomBlocksCommand from .env import EnvironmentCommand from .fp16_safetensors import FP16SafetensorsCommand @@ -26,6 +27,7 @@ def main(): # Register commands EnvironmentCommand.register_subcommand(commands_parser) FP16SafetensorsCommand.register_subcommand(commands_parser) + CustomBlocksCommand.register_subcommand(commands_parser) # Let's go args = parser.parse_args() diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index f9b652bbc021..048ddcae32f9 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -176,6 +176,7 @@ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool token = kwargs.pop("token", None) repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id + subfolder = kwargs.pop("subfolder", None) self._upload_folder( save_directory, @@ -183,6 +184,7 @@ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool token=token, commit_message=commit_message, create_pr=create_pr, + subfolder=subfolder, ) @classmethod @@ -601,6 +603,10 @@ def to_json_saveable(value): value = value.tolist() elif isinstance(value, Path): value = value.as_posix() + elif hasattr(value, "to_dict") and callable(value.to_dict): + value = value.to_dict() + elif isinstance(value, list): + value = [to_json_saveable(v) for v in value] return value if "quantization_config" in config_dict: diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py new file mode 100644 index 000000000000..1c288f00f084 --- /dev/null +++ b/src/diffusers/guiders/__init__.py @@ -0,0 +1,39 @@ +# Copyright 
2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union + +from ..utils import is_torch_available + + +if is_torch_available(): + from .adaptive_projected_guidance import AdaptiveProjectedGuidance + from .auto_guidance import AutoGuidance + from .classifier_free_guidance import ClassifierFreeGuidance + from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance + from .perturbed_attention_guidance import PerturbedAttentionGuidance + from .skip_layer_guidance import SkipLayerGuidance + from .smoothed_energy_guidance import SmoothedEnergyGuidance + from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance + + GuiderType = Union[ + AdaptiveProjectedGuidance, + AutoGuidance, + ClassifierFreeGuidance, + ClassifierFreeZeroStarGuidance, + PerturbedAttentionGuidance, + SkipLayerGuidance, + SmoothedEnergyGuidance, + TangentialClassifierFreeGuidance, + ] diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py new file mode 100644 index 000000000000..81137db106a0 --- /dev/null +++ b/src/diffusers/guiders/adaptive_projected_guidance.py @@ -0,0 +1,188 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import torch + +from ..configuration_utils import register_to_config +from .guider_utils import BaseGuidance, rescale_noise_cfg + + +if TYPE_CHECKING: + from ..modular_pipelines.modular_pipeline import BlockState + + +class AdaptiveProjectedGuidance(BaseGuidance): + """ + Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416 + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + adaptive_projected_guidance_momentum (`float`, defaults to `None`): + The momentum parameter for the adaptive projected guidance. Disabled if set to `None`. + adaptive_projected_guidance_rescale (`float`, defaults to `15.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. 
This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + @register_to_config + def __init__( + self, + guidance_scale: float = 7.5, + adaptive_projected_guidance_momentum: Optional[float] = None, + adaptive_projected_guidance_rescale: float = 15.0, + eta: float = 1.0, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum + self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale + self.eta = eta + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + self.momentum_buffer = None + + def prepare_inputs( + self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None + ) -> List["BlockState"]: + if input_fields is None: + input_fields = self._input_fields + + if self._step == 0: + if self.adaptive_projected_guidance_momentum is not None: + self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum) + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batches.append(data_batch) + return data_batches + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if not self._is_apg_enabled(): + pred = pred_cond + else: + pred = normalized_guidance( + pred_cond, + pred_uncond, + self.guidance_scale, + self.momentum_buffer, + self.eta, + self.adaptive_projected_guidance_rescale, + self.use_original_formulation, + ) + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred, {} + + @property + def is_conditional(self) -> bool: + return self._count_prepared == 1 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_apg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_apg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + +class MomentumBuffer: + def 
__init__(self, momentum: float): + self.momentum = momentum + self.running_average = 0 + + def update(self, update_value: torch.Tensor): + new_average = self.momentum * self.running_average + self.running_average = update_value + new_average + + +def normalized_guidance( + pred_cond: torch.Tensor, + pred_uncond: torch.Tensor, + guidance_scale: float, + momentum_buffer: Optional[MomentumBuffer] = None, + eta: float = 1.0, + norm_threshold: float = 0.0, + use_original_formulation: bool = False, +): + diff = pred_cond - pred_uncond + dim = [-i for i in range(1, len(diff.shape))] + + if momentum_buffer is not None: + momentum_buffer.update(diff) + diff = momentum_buffer.running_average + + if norm_threshold > 0: + ones = torch.ones_like(diff) + diff_norm = diff.norm(p=2, dim=dim, keepdim=True) + scale_factor = torch.minimum(ones, norm_threshold / diff_norm) + diff = diff * scale_factor + + v0, v1 = diff.double(), pred_cond.double() + v1 = torch.nn.functional.normalize(v1, dim=dim) + v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1 + v0_orthogonal = v0 - v0_parallel + diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff) + normalized_update = diff_orthogonal + eta * diff_parallel + + pred = pred_cond if use_original_formulation else pred_uncond + pred = pred + guidance_scale * normalized_update + + return pred diff --git a/src/diffusers/guiders/auto_guidance.py b/src/diffusers/guiders/auto_guidance.py new file mode 100644 index 000000000000..e1642211d393 --- /dev/null +++ b/src/diffusers/guiders/auto_guidance.py @@ -0,0 +1,190 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +import torch + +from ..configuration_utils import register_to_config +from ..hooks import HookRegistry, LayerSkipConfig +from ..hooks.layer_skip import _apply_layer_skip_hook +from .guider_utils import BaseGuidance, rescale_noise_cfg + + +if TYPE_CHECKING: + from ..modular_pipelines.modular_pipeline import BlockState + + +class AutoGuidance(BaseGuidance): + """ + AutoGuidance: https://huggingface.co/papers/2406.02507 + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + auto_guidance_layers (`int` or `List[int]`, *optional*): + The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not + provided, `skip_layer_config` must be provided. + auto_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*): + The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of + `LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided. 
+ dropout (`float`, *optional*): + The dropout probability for autoguidance on the enabled skip layers (either with `auto_guidance_layers` or + `auto_guidance_config`). If not provided, the dropout probability will be set to 1.0. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + @register_to_config + def __init__( + self, + guidance_scale: float = 7.5, + auto_guidance_layers: Optional[Union[int, List[int]]] = None, + auto_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None, + dropout: Optional[float] = None, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.auto_guidance_layers = auto_guidance_layers + self.auto_guidance_config = auto_guidance_config + self.dropout = dropout + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + if auto_guidance_layers is None and auto_guidance_config is None: + raise ValueError( + "Either `auto_guidance_layers` or `auto_guidance_config` must be provided to enable Skip Layer Guidance." + ) + if auto_guidance_layers is not None and auto_guidance_config is not None: + raise ValueError("Only one of `auto_guidance_layers` or `auto_guidance_config` can be provided.") + if (dropout is None and auto_guidance_layers is not None) or ( + dropout is not None and auto_guidance_layers is None + ): + raise ValueError("`dropout` must be provided if `auto_guidance_layers` is provided.") + + if auto_guidance_layers is not None: + if isinstance(auto_guidance_layers, int): + auto_guidance_layers = [auto_guidance_layers] + if not isinstance(auto_guidance_layers, list): + raise ValueError( + f"Expected `auto_guidance_layers` to be an int or a list of ints, but got {type(auto_guidance_layers)}." + ) + auto_guidance_config = [ + LayerSkipConfig(layer, fqn="auto", dropout=dropout) for layer in auto_guidance_layers + ] + + if isinstance(auto_guidance_config, dict): + auto_guidance_config = LayerSkipConfig.from_dict(auto_guidance_config) + + if isinstance(auto_guidance_config, LayerSkipConfig): + auto_guidance_config = [auto_guidance_config] + + if not isinstance(auto_guidance_config, list): + raise ValueError( + f"Expected `auto_guidance_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(auto_guidance_config)}." 
+ ) + elif isinstance(next(iter(auto_guidance_config), None), dict): + auto_guidance_config = [LayerSkipConfig.from_dict(config) for config in auto_guidance_config] + + self.auto_guidance_config = auto_guidance_config + self._auto_guidance_hook_names = [f"AutoGuidance_{i}" for i in range(len(self.auto_guidance_config))] + + def prepare_models(self, denoiser: torch.nn.Module) -> None: + self._count_prepared += 1 + if self._is_ag_enabled() and self.is_unconditional: + for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config): + _apply_layer_skip_hook(denoiser, config, name=name) + + def cleanup_models(self, denoiser: torch.nn.Module) -> None: + if self._is_ag_enabled() and self.is_unconditional: + for name in self._auto_guidance_hook_names: + registry = HookRegistry.check_if_exists_or_initialize(denoiser) + registry.remove_hook(name, recurse=True) + + def prepare_inputs( + self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None + ) -> List["BlockState"]: + if input_fields is None: + input_fields = self._input_fields + + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batches.append(data_batch) + return data_batches + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if not self._is_ag_enabled(): + pred = pred_cond + else: + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred, {} + + @property + def is_conditional(self) -> bool: + return self._count_prepared == 1 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_ag_enabled(): + num_conditions += 1 + return num_conditions + + def _is_ag_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py new file mode 100644 index 000000000000..7e72b92fcee2 --- /dev/null +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -0,0 +1,141 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
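The validation logic above expands `auto_guidance_layers` plus a required `dropout` into a list of `LayerSkipConfig` objects, and applies/removes layer-skip hooks around the unconditional pass. Below is a minimal construction sketch, assuming `AutoGuidance` is importable from `diffusers.guiders` as exported in the new `__init__.py`; the layer indices and dropout value are illustrative, not recommendations:

```py
from diffusers.guiders import AutoGuidance

# Layer indices are expanded internally into LayerSkipConfig(layer, fqn="auto", dropout=dropout).
# Passing `auto_guidance_layers` without `dropout`, or together with `auto_guidance_config`,
# raises a ValueError per the checks above.
guider = AutoGuidance(guidance_scale=7.5, auto_guidance_layers=[7, 8, 9], dropout=0.1)
```

During the unconditional batch, `prepare_models` attaches one layer-skip hook per config to the denoiser, and `cleanup_models` removes them again after the forward pass.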
+ +import math +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import torch + +from ..configuration_utils import register_to_config +from .guider_utils import BaseGuidance, rescale_noise_cfg + + +if TYPE_CHECKING: + from ..modular_pipelines.modular_pipeline import BlockState + + +class ClassifierFreeGuidance(BaseGuidance): + """ + Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598 + + CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by + jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during + inference. This allows the model to tradeoff between generation quality and sample diversity. The original paper + proposes scaling and shifting the conditional distribution based on the difference between conditional and + unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)] + + Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen + paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in + theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)] + + The intution behind the original formulation can be thought of as moving the conditional distribution estimates + further away from the unconditional distribution estimates, while the diffusers-native implementation can be + thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of + the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.) + + The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the + paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time. + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which guidance stops. 
+ """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + @register_to_config + def __init__( + self, + guidance_scale: float = 7.5, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + def prepare_inputs( + self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None + ) -> List["BlockState"]: + if input_fields is None: + input_fields = self._input_fields + + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batches.append(data_batch) + return data_batches + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if not self._is_cfg_enabled(): + pred = pred_cond + else: + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred, {} + + @property + def is_conditional(self) -> bool: + return self._count_prepared == 1 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close diff --git a/src/diffusers/guiders/classifier_free_zero_star_guidance.py b/src/diffusers/guiders/classifier_free_zero_star_guidance.py new file mode 100644 index 000000000000..85d5cc62d4e7 --- /dev/null +++ b/src/diffusers/guiders/classifier_free_zero_star_guidance.py @@ -0,0 +1,152 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import torch + +from ..configuration_utils import register_to_config +from .guider_utils import BaseGuidance, rescale_noise_cfg + + +if TYPE_CHECKING: + from ..modular_pipelines.modular_pipeline import BlockState + + +class ClassifierFreeZeroStarGuidance(BaseGuidance): + """ + Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886 + + This is an implementation of the Classifier-Free Zero* guidance technique, which is a variant of classifier-free + guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion + process, and also introduces an optimal rescaling factor for the noise predictions, which can help in improving the + quality of generated images. + + The authors of the paper suggest setting zero initialization in the first 4% of the inference steps. + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + zero_init_steps (`int`, defaults to `1`): + The number of inference steps for which the noise predictions are zeroed out (see Section 4.2). + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.01`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `0.2`): + The fraction of the total number of denoising steps after which guidance stops. 
+ """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + @register_to_config + def __init__( + self, + guidance_scale: float = 7.5, + zero_init_steps: int = 1, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.zero_init_steps = zero_init_steps + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + def prepare_inputs( + self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None + ) -> List["BlockState"]: + if input_fields is None: + input_fields = self._input_fields + + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batches.append(data_batch) + return data_batches + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if self._step < self.zero_init_steps: + pred = torch.zeros_like(pred_cond) + elif not self._is_cfg_enabled(): + pred = pred_cond + else: + pred_cond_flat = pred_cond.flatten(1) + pred_uncond_flat = pred_uncond.flatten(1) + alpha = cfg_zero_star_scale(pred_cond_flat, pred_uncond_flat) + alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1)) + pred_uncond = pred_uncond * alpha + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred, {} + + @property + def is_conditional(self) -> bool: + return self._count_prepared == 1 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + +def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: + cond_dtype = cond.dtype + cond = cond.float() + uncond = uncond.float() + dot_product = torch.sum(cond * uncond, dim=1, keepdim=True) + squared_norm = torch.sum(uncond**2, dim=1, keepdim=True) + eps + # st_star = v_cond^T * v_uncond / ||v_uncond||^2 + scale = dot_product / squared_norm + return scale.to(dtype=cond_dtype) diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py new file mode 100644 index 000000000000..1c0b8cb286e7 --- /dev/null +++ b/src/diffusers/guiders/guider_utils.py @@ -0,0 +1,309 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +import torch +from huggingface_hub.utils import validate_hf_hub_args +from typing_extensions import Self + +from ..configuration_utils import ConfigMixin +from ..utils import PushToHubMixin, get_logger + + +if TYPE_CHECKING: + from ..modular_pipelines.modular_pipeline import BlockState + + +GUIDER_CONFIG_NAME = "guider_config.json" + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +class BaseGuidance(ConfigMixin, PushToHubMixin): + r"""Base class providing the skeleton for implementing guidance techniques.""" + + config_name = GUIDER_CONFIG_NAME + _input_predictions = None + _identifier_key = "__guidance_identifier__" + + def __init__(self, start: float = 0.0, stop: float = 1.0): + self._start = start + self._stop = stop + self._step: int = None + self._num_inference_steps: int = None + self._timestep: torch.LongTensor = None + self._count_prepared = 0 + self._input_fields: Dict[str, Union[str, Tuple[str, str]]] = None + self._enabled = True + + if not (0.0 <= start < 1.0): + raise ValueError(f"Expected `start` to be between 0.0 and 1.0, but got {start}.") + if not (start <= stop <= 1.0): + raise ValueError(f"Expected `stop` to be between {start} and 1.0, but got {stop}.") + + if self._input_predictions is None or not isinstance(self._input_predictions, list): + raise ValueError( + "`_input_predictions` must be a list of required prediction names for the guidance technique." + ) + + def disable(self): + self._enabled = False + + def enable(self): + self._enabled = True + + def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None: + self._step = step + self._num_inference_steps = num_inference_steps + self._timestep = timestep + self._count_prepared = 0 + + def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> None: + """ + Set the input fields for the guidance technique. The input fields are used to specify the names of the returned + attributes containing the prepared data after `prepare_inputs` is called. The prepared data is obtained from + the values of the provided keyword arguments to this method. + + Args: + **kwargs (`Dict[str, Union[str, Tuple[str, str]]]`): + A dictionary where the keys are the names of the fields that will be used to store the data once it is + prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used + to look up the required data provided for preparation. + + If a string is provided, it will be used as the conditional data (or unconditional if used with a + guidance method that requires it). If a tuple of length 2 is provided, the first element must be the + conditional data identifier and the second element must be the unconditional data identifier or None. 
+ + Example: + ``` + data = {"prompt_embeds": , "negative_prompt_embeds": , "latents": } + + BaseGuidance.set_input_fields( + latents="latents", + prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), + ) + ``` + """ + for key, value in kwargs.items(): + is_string = isinstance(value, str) + is_tuple_of_str_with_len_2 = ( + isinstance(value, tuple) and len(value) == 2 and all(isinstance(v, str) for v in value) + ) + if not (is_string or is_tuple_of_str_with_len_2): + raise ValueError( + f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}." + ) + self._input_fields = kwargs + + def prepare_models(self, denoiser: torch.nn.Module) -> None: + """ + Prepares the models for the guidance technique on a given batch of data. This method should be overridden in + subclasses to implement specific model preparation logic. + """ + self._count_prepared += 1 + + def cleanup_models(self, denoiser: torch.nn.Module) -> None: + """ + Cleans up the models for the guidance technique after a given batch of data. This method should be overridden + in subclasses to implement specific model cleanup logic. It is useful for removing any hooks or other stateful + modifications made during `prepare_models`. + """ + pass + + def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.") + + def __call__(self, data: List["BlockState"]) -> Any: + if not all(hasattr(d, "noise_pred") for d in data): + raise ValueError("Expected all data to have `noise_pred` attribute.") + if len(data) != self.num_conditions: + raise ValueError( + f"Expected {self.num_conditions} data items, but got {len(data)}. Please check the input data." + ) + forward_inputs = {getattr(d, self._identifier_key): d.noise_pred for d in data} + return self.forward(**forward_inputs) + + def forward(self, *args, **kwargs) -> Any: + raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.") + + @property + def is_conditional(self) -> bool: + raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.") + + @property + def is_unconditional(self) -> bool: + return not self.is_conditional + + @property + def num_conditions(self) -> int: + raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.") + + @classmethod + def _prepare_batch( + cls, + input_fields: Dict[str, Union[str, Tuple[str, str]]], + data: "BlockState", + tuple_index: int, + identifier: str, + ) -> "BlockState": + """ + Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of the + `BaseGuidance` class. It prepares the batch based on the provided tuple index. + + Args: + input_fields (`Dict[str, Union[str, Tuple[str, str]]]`): + A dictionary where the keys are the names of the fields that will be used to store the data once it is + prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used + to look up the required data provided for preparation. If a string is provided, it will be used as the + conditional data (or unconditional if used with a guidance method that requires it). If a tuple of + length 2 is provided, the first element must be the conditional data identifier and the second element + must be the unconditional data identifier or None. + data (`BlockState`): + The input data to be prepared. 
+ tuple_index (`int`): + The index to use when accessing input fields that are tuples. + + Returns: + `BlockState`: The prepared batch of data. + """ + from ..modular_pipelines.modular_pipeline import BlockState + + if input_fields is None: + raise ValueError( + "Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs." + ) + data_batch = {} + for key, value in input_fields.items(): + try: + if isinstance(value, str): + data_batch[key] = getattr(data, value) + elif isinstance(value, tuple): + data_batch[key] = getattr(data, value[tuple_index]) + else: + # We've already checked that value is a string or a tuple of strings with length 2 + pass + except AttributeError: + logger.debug(f"`data` does not have attribute(s) {value}, skipping.") + data_batch[cls._identifier_key] = identifier + return BlockState(**data_batch) + + @classmethod + @validate_hf_hub_args + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, + subfolder: Optional[str] = None, + return_unused_kwargs=False, + **kwargs, + ) -> Self: + r""" + Instantiate a guider from a pre-defined JSON configuration file in a local directory or Hub repository. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the guider configuration + saved with [`~BaseGuidance.save_pretrained`]. + subfolder (`str`, *optional*): + The subfolder location of a model file within a larger model repository on the Hub or locally. + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + Whether kwargs that are not consumed by the Python class should be returned or not. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + + + + To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with + `huggingface-cli login`. You can also activate the special + ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a + firewalled environment. 
+ + + + """ + config, kwargs, commit_hash = cls.load_config( + pretrained_model_name_or_path=pretrained_model_name_or_path, + subfolder=subfolder, + return_unused_kwargs=True, + return_commit_hash=True, + **kwargs, + ) + return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs) + + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save a guider configuration object to a directory so that it can be reloaded using the + [`~BaseGuidance.from_pretrained`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the configuration JSON file will be saved (will be created if it does not exist). + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg diff --git a/src/diffusers/guiders/perturbed_attention_guidance.py b/src/diffusers/guiders/perturbed_attention_guidance.py new file mode 100644 index 000000000000..1b2256732ffc --- /dev/null +++ b/src/diffusers/guiders/perturbed_attention_guidance.py @@ -0,0 +1,271 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
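With `BaseGuidance` and `rescale_noise_cfg` in place, the per-step contract is: `set_state`, then `prepare_inputs`, then one denoiser call per prepared batch (wrapped in `prepare_models`/`cleanup_models`), then `__call__` to combine the stored `noise_pred` tensors. The sketch below is schematic rather than taken from this PR; `denoiser`, `block_state`, `timesteps`, and the denoiser call signature are hypothetical placeholders:

```py
# Map guider batch fields to (conditional, unconditional) entries on the block state.
guider.set_input_fields(
    latents="latents",
    prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
)

for i, t in enumerate(timesteps):
    guider.set_state(step=i, num_inference_steps=len(timesteps), timestep=t)
    batches = guider.prepare_inputs(block_state)  # one BlockState per condition

    for batch in batches:
        guider.prepare_models(denoiser)  # e.g. attach layer-skip hooks for the right batch
        batch.noise_pred = denoiser(batch.latents, t, batch.prompt_embeds)
        guider.cleanup_models(denoiser)

    noise_pred, _ = guider(batches)  # combine the per-condition predictions (e.g. CFG)
    # ... scheduler step with noise_pred ...
```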
+ +import math +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +import torch + +from ..configuration_utils import register_to_config +from ..hooks import HookRegistry, LayerSkipConfig +from ..hooks.layer_skip import _apply_layer_skip_hook +from ..utils import get_logger +from .guider_utils import BaseGuidance, rescale_noise_cfg + + +if TYPE_CHECKING: + from ..modular_pipelines.modular_pipeline import BlockState + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +class PerturbedAttentionGuidance(BaseGuidance): + """ + Perturbed Attention Guidance (PAG): https://huggingface.co/papers/2403.17377 + + The intution behind PAG can be thought of as moving the CFG predicted distribution estimates further away from + worse versions of the conditional distribution estimates. PAG was one of the first techniques to introduce the idea + of using a worse version of the trained model for better guiding itself in the denoising process. It perturbs the + attention scores of the latent stream by replacing the score matrix with an identity matrix for selectively chosen + layers. + + Additional reading: + - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507) + + PAG is implemented with similar implementation to SkipLayerGuidance due to overlap in the configuration parameters + and implementation details. + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + perturbed_guidance_scale (`float`, defaults to `2.8`): + The scale parameter for perturbed attention guidance. + perturbed_guidance_start (`float`, defaults to `0.01`): + The fraction of the total number of denoising steps after which perturbed attention guidance starts. + perturbed_guidance_stop (`float`, defaults to `0.2`): + The fraction of the total number of denoising steps after which perturbed attention guidance stops. + perturbed_guidance_layers (`int` or `List[int]`, *optional*): + The layer indices to apply perturbed attention guidance to. Can be a single integer or a list of integers. + If not provided, `perturbed_guidance_config` must be provided. + perturbed_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*): + The configuration for the perturbed attention guidance. Can be a single `LayerSkipConfig` or a list of + `LayerSkipConfig`. If not provided, `perturbed_guidance_layers` must be provided. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.01`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `0.2`): + The fraction of the total number of denoising steps after which guidance stops. 
+ """ + + # NOTE: The current implementation does not account for joint latent conditioning (text + image/video tokens in + # the same latent stream). It assumes the entire latent is a single stream of visual tokens. It would be very + # complex to support joint latent conditioning in a model-agnostic manner without specializing the implementation + # for each model architecture. + + _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"] + + @register_to_config + def __init__( + self, + guidance_scale: float = 7.5, + perturbed_guidance_scale: float = 2.8, + perturbed_guidance_start: float = 0.01, + perturbed_guidance_stop: float = 0.2, + perturbed_guidance_layers: Optional[Union[int, List[int]]] = None, + perturbed_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.skip_layer_guidance_scale = perturbed_guidance_scale + self.skip_layer_guidance_start = perturbed_guidance_start + self.skip_layer_guidance_stop = perturbed_guidance_stop + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + if perturbed_guidance_config is None: + if perturbed_guidance_layers is None: + raise ValueError( + "`perturbed_guidance_layers` must be provided if `perturbed_guidance_config` is not specified." + ) + perturbed_guidance_config = LayerSkipConfig( + indices=perturbed_guidance_layers, + fqn="auto", + skip_attention=False, + skip_attention_scores=True, + skip_ff=False, + ) + else: + if perturbed_guidance_layers is not None: + raise ValueError( + "`perturbed_guidance_layers` should not be provided if `perturbed_guidance_config` is specified." + ) + + if isinstance(perturbed_guidance_config, dict): + perturbed_guidance_config = LayerSkipConfig.from_dict(perturbed_guidance_config) + + if isinstance(perturbed_guidance_config, LayerSkipConfig): + perturbed_guidance_config = [perturbed_guidance_config] + + if not isinstance(perturbed_guidance_config, list): + raise ValueError( + "`perturbed_guidance_config` must be a `LayerSkipConfig`, a list of `LayerSkipConfig`, or a dict that can be converted to a `LayerSkipConfig`." + ) + elif isinstance(next(iter(perturbed_guidance_config), None), dict): + perturbed_guidance_config = [LayerSkipConfig.from_dict(config) for config in perturbed_guidance_config] + + for config in perturbed_guidance_config: + if config.skip_attention or not config.skip_attention_scores or config.skip_ff: + logger.warning( + "Perturbed Attention Guidance is designed to perturb attention scores, so `skip_attention` should be False, `skip_attention_scores` should be True, and `skip_ff` should be False. " + "Please check your configuration. Modifying the config to match the expected values." 
+ ) + config.skip_attention = False + config.skip_attention_scores = True + config.skip_ff = False + + self.skip_layer_config = perturbed_guidance_config + self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))] + + # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_models + def prepare_models(self, denoiser: torch.nn.Module) -> None: + self._count_prepared += 1 + if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1: + for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config): + _apply_layer_skip_hook(denoiser, config, name=name) + + # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.cleanup_models + def cleanup_models(self, denoiser: torch.nn.Module) -> None: + if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1: + registry = HookRegistry.check_if_exists_or_initialize(denoiser) + # Remove the hooks after inference + for hook_name in self._skip_layer_hook_names: + registry.remove_hook(hook_name, recurse=True) + + # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs + def prepare_inputs( + self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None + ) -> List["BlockState"]: + if input_fields is None: + input_fields = self._input_fields + + if self.num_conditions == 1: + tuple_indices = [0] + input_predictions = ["pred_cond"] + elif self.num_conditions == 2: + tuple_indices = [0, 1] + input_predictions = ( + ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"] + ) + else: + tuple_indices = [0, 1, 0] + input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i]) + data_batches.append(data_batch) + return data_batches + + # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward + def forward( + self, + pred_cond: torch.Tensor, + pred_uncond: Optional[torch.Tensor] = None, + pred_cond_skip: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + pred = None + + if not self._is_cfg_enabled() and not self._is_slg_enabled(): + pred = pred_cond + elif not self._is_cfg_enabled(): + shift = pred_cond - pred_cond_skip + pred = pred_cond if self.use_original_formulation else pred_cond_skip + pred = pred + self.skip_layer_guidance_scale * shift + elif not self._is_slg_enabled(): + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + else: + shift = pred_cond - pred_uncond + shift_skip = pred_cond - pred_cond_skip + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred, {} + + @property + # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.is_conditional + def is_conditional(self) -> bool: + return self._count_prepared == 1 or self._count_prepared == 3 + + @property + # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.num_conditions + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + if self._is_slg_enabled(): + num_conditions += 1 + return num_conditions + + # 
Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_cfg_enabled + def _is_cfg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_slg_enabled + def _is_slg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps) + skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps) + is_within_range = skip_start_step < self._step < skip_stop_step + + is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0) + + return is_within_range and not is_zero diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py new file mode 100644 index 000000000000..68a657960a45 --- /dev/null +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -0,0 +1,262 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +import torch + +from ..configuration_utils import register_to_config +from ..hooks import HookRegistry, LayerSkipConfig +from ..hooks.layer_skip import _apply_layer_skip_hook +from .guider_utils import BaseGuidance, rescale_noise_cfg + + +if TYPE_CHECKING: + from ..modular_pipelines.modular_pipeline import BlockState + + +class SkipLayerGuidance(BaseGuidance): + """ + Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5 + + Spatio-Temporal Guidance (STG): https://huggingface.co/papers/2411.18664 + + SLG was introduced by StabilityAI for improving structure and anotomy coherence in generated images. It works by + skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional + batch of data, apart from the conditional and unconditional batches already used in CFG + ([~guiders.classifier_free_guidance.ClassifierFreeGuidance]), and then scaling and shifting the CFG predictions + based on the difference between conditional without skipping and conditional with skipping predictions. + + The intution behind SLG can be thought of as moving the CFG predicted distribution estimates further away from + worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse + version of the model for the conditional prediction). 
+ + STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving + generation quality in video diffusion models. + + Additional reading: + - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507) + + The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are + defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium. + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + skip_layer_guidance_scale (`float`, defaults to `2.8`): + The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher + values, but it may also lead to overexposure and saturation. + skip_layer_guidance_start (`float`, defaults to `0.01`): + The fraction of the total number of denoising steps after which skip layer guidance starts. + skip_layer_guidance_stop (`float`, defaults to `0.2`): + The fraction of the total number of denoising steps after which skip layer guidance stops. + skip_layer_guidance_layers (`int` or `List[int]`, *optional*): + The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not + provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion + 3.5 Medium. + skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*): + The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of + `LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.01`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `0.2`): + The fraction of the total number of denoising steps after which guidance stops. 
+ """ + + _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"] + + @register_to_config + def __init__( + self, + guidance_scale: float = 7.5, + skip_layer_guidance_scale: float = 2.8, + skip_layer_guidance_start: float = 0.01, + skip_layer_guidance_stop: float = 0.2, + skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None, + skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.skip_layer_guidance_scale = skip_layer_guidance_scale + self.skip_layer_guidance_start = skip_layer_guidance_start + self.skip_layer_guidance_stop = skip_layer_guidance_stop + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + if not (0.0 <= skip_layer_guidance_start < 1.0): + raise ValueError( + f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}." + ) + if not (skip_layer_guidance_start <= skip_layer_guidance_stop <= 1.0): + raise ValueError( + f"Expected `skip_layer_guidance_stop` to be between 0.0 and 1.0, but got {skip_layer_guidance_stop}." + ) + + if skip_layer_guidance_layers is None and skip_layer_config is None: + raise ValueError( + "Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance." + ) + if skip_layer_guidance_layers is not None and skip_layer_config is not None: + raise ValueError("Only one of `skip_layer_guidance_layers` or `skip_layer_config` can be provided.") + + if skip_layer_guidance_layers is not None: + if isinstance(skip_layer_guidance_layers, int): + skip_layer_guidance_layers = [skip_layer_guidance_layers] + if not isinstance(skip_layer_guidance_layers, list): + raise ValueError( + f"Expected `skip_layer_guidance_layers` to be an int or a list of ints, but got {type(skip_layer_guidance_layers)}." + ) + skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers] + + if isinstance(skip_layer_config, dict): + skip_layer_config = LayerSkipConfig.from_dict(skip_layer_config) + + if isinstance(skip_layer_config, LayerSkipConfig): + skip_layer_config = [skip_layer_config] + + if not isinstance(skip_layer_config, list): + raise ValueError( + f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}." 
+ ) + elif isinstance(next(iter(skip_layer_config), None), dict): + skip_layer_config = [LayerSkipConfig.from_dict(config) for config in skip_layer_config] + + self.skip_layer_config = skip_layer_config + self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))] + + def prepare_models(self, denoiser: torch.nn.Module) -> None: + self._count_prepared += 1 + if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1: + for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config): + _apply_layer_skip_hook(denoiser, config, name=name) + + def cleanup_models(self, denoiser: torch.nn.Module) -> None: + if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1: + registry = HookRegistry.check_if_exists_or_initialize(denoiser) + # Remove the hooks after inference + for hook_name in self._skip_layer_hook_names: + registry.remove_hook(hook_name, recurse=True) + + def prepare_inputs( + self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None + ) -> List["BlockState"]: + if input_fields is None: + input_fields = self._input_fields + + if self.num_conditions == 1: + tuple_indices = [0] + input_predictions = ["pred_cond"] + elif self.num_conditions == 2: + tuple_indices = [0, 1] + input_predictions = ( + ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"] + ) + else: + tuple_indices = [0, 1, 0] + input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i]) + data_batches.append(data_batch) + return data_batches + + def forward( + self, + pred_cond: torch.Tensor, + pred_uncond: Optional[torch.Tensor] = None, + pred_cond_skip: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + pred = None + + if not self._is_cfg_enabled() and not self._is_slg_enabled(): + pred = pred_cond + elif not self._is_cfg_enabled(): + shift = pred_cond - pred_cond_skip + pred = pred_cond if self.use_original_formulation else pred_cond_skip + pred = pred + self.skip_layer_guidance_scale * shift + elif not self._is_slg_enabled(): + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + else: + shift = pred_cond - pred_uncond + shift_skip = pred_cond - pred_cond_skip + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred, {} + + @property + def is_conditional(self) -> bool: + return self._count_prepared == 1 or self._count_prepared == 3 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + if self._is_slg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close 
= math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + def _is_slg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps) + skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps) + is_within_range = skip_start_step < self._step < skip_stop_step + + is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0) + + return is_within_range and not is_zero diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py new file mode 100644 index 000000000000..d8e8a3cf2fa8 --- /dev/null +++ b/src/diffusers/guiders/smoothed_energy_guidance.py @@ -0,0 +1,251 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import torch + +from ..configuration_utils import register_to_config +from ..hooks import HookRegistry +from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook +from .guider_utils import BaseGuidance, rescale_noise_cfg + + +if TYPE_CHECKING: + from ..modular_pipelines.modular_pipeline import BlockState + + +class SmoothedEnergyGuidance(BaseGuidance): + """ + Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760 + + SEG is only supported as an experimental prototype feature for now, so the implementation may be modified in the + future without warning or guarantee of reproducibility. This implementation assumes: + - Generated images are square (height == width) + - The model does not combine different modalities together (e.g., text and image latent streams are not combined + together such as Flux) + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + seg_guidance_scale (`float`, defaults to `3.0`): + The scale parameter for smoothed energy guidance. Anatomy and structure coherence may improve with higher + values, but it may also lead to overexposure and saturation. + seg_blur_sigma (`float`, defaults to `9999999.0`): + The amount by which we blur the attention weights. Setting this value greater than 9999.0 results in + infinite blur, which means uniform queries. Controlling it exponentially is empirically effective. + seg_blur_threshold_inf (`float`, defaults to `9999.0`): + The threshold above which the blur is considered infinite. + seg_guidance_start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which smoothed energy guidance starts. 
+ seg_guidance_stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which smoothed energy guidance stops. + seg_guidance_layers (`int` or `List[int]`, *optional*): + The layer indices to apply smoothed energy guidance to. Can be a single integer or a list of integers. If + not provided, `seg_guidance_config` must be provided. The recommended values are `[7, 8, 9]` for Stable + Diffusion 3.5 Medium. + seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `List[SmoothedEnergyGuidanceConfig]`, *optional*): + The configuration for the smoothed energy layer guidance. Can be a single `SmoothedEnergyGuidanceConfig` or + a list of `SmoothedEnergyGuidanceConfig`. If not provided, `seg_guidance_layers` must be provided. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.01`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `0.2`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"] + + @register_to_config + def __init__( + self, + guidance_scale: float = 7.5, + seg_guidance_scale: float = 2.8, + seg_blur_sigma: float = 9999999.0, + seg_blur_threshold_inf: float = 9999.0, + seg_guidance_start: float = 0.0, + seg_guidance_stop: float = 1.0, + seg_guidance_layers: Optional[Union[int, List[int]]] = None, + seg_guidance_config: Union[SmoothedEnergyGuidanceConfig, List[SmoothedEnergyGuidanceConfig]] = None, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.seg_guidance_scale = seg_guidance_scale + self.seg_blur_sigma = seg_blur_sigma + self.seg_blur_threshold_inf = seg_blur_threshold_inf + self.seg_guidance_start = seg_guidance_start + self.seg_guidance_stop = seg_guidance_stop + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + if not (0.0 <= seg_guidance_start < 1.0): + raise ValueError(f"Expected `seg_guidance_start` to be between 0.0 and 1.0, but got {seg_guidance_start}.") + if not (seg_guidance_start <= seg_guidance_stop <= 1.0): + raise ValueError(f"Expected `seg_guidance_stop` to be between 0.0 and 1.0, but got {seg_guidance_stop}.") + + if seg_guidance_layers is None and seg_guidance_config is None: + raise ValueError( + "Either `seg_guidance_layers` or `seg_guidance_config` must be provided to enable Smoothed Energy Guidance." 
+ ) + if seg_guidance_layers is not None and seg_guidance_config is not None: + raise ValueError("Only one of `seg_guidance_layers` or `seg_guidance_config` can be provided.") + + if seg_guidance_layers is not None: + if isinstance(seg_guidance_layers, int): + seg_guidance_layers = [seg_guidance_layers] + if not isinstance(seg_guidance_layers, list): + raise ValueError( + f"Expected `seg_guidance_layers` to be an int or a list of ints, but got {type(seg_guidance_layers)}." + ) + seg_guidance_config = [SmoothedEnergyGuidanceConfig(layer, fqn="auto") for layer in seg_guidance_layers] + + if isinstance(seg_guidance_config, dict): + seg_guidance_config = SmoothedEnergyGuidanceConfig.from_dict(seg_guidance_config) + + if isinstance(seg_guidance_config, SmoothedEnergyGuidanceConfig): + seg_guidance_config = [seg_guidance_config] + + if not isinstance(seg_guidance_config, list): + raise ValueError( + f"Expected `seg_guidance_config` to be a SmoothedEnergyGuidanceConfig or a list of SmoothedEnergyGuidanceConfig, but got {type(seg_guidance_config)}." + ) + elif isinstance(next(iter(seg_guidance_config), None), dict): + seg_guidance_config = [SmoothedEnergyGuidanceConfig.from_dict(config) for config in seg_guidance_config] + + self.seg_guidance_config = seg_guidance_config + self._seg_layer_hook_names = [f"SmoothedEnergyGuidance_{i}" for i in range(len(self.seg_guidance_config))] + + def prepare_models(self, denoiser: torch.nn.Module) -> None: + if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1: + for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config): + _apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name) + + def cleanup_models(self, denoiser: torch.nn.Module): + if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1: + registry = HookRegistry.check_if_exists_or_initialize(denoiser) + # Remove the hooks after inference + for hook_name in self._seg_layer_hook_names: + registry.remove_hook(hook_name, recurse=True) + + def prepare_inputs( + self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None + ) -> List["BlockState"]: + if input_fields is None: + input_fields = self._input_fields + + if self.num_conditions == 1: + tuple_indices = [0] + input_predictions = ["pred_cond"] + elif self.num_conditions == 2: + tuple_indices = [0, 1] + input_predictions = ( + ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"] + ) + else: + tuple_indices = [0, 1, 0] + input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i]) + data_batches.append(data_batch) + return data_batches + + def forward( + self, + pred_cond: torch.Tensor, + pred_uncond: Optional[torch.Tensor] = None, + pred_cond_seg: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + pred = None + + if not self._is_cfg_enabled() and not self._is_seg_enabled(): + pred = pred_cond + elif not self._is_cfg_enabled(): + shift = pred_cond - pred_cond_seg + pred = pred_cond if self.use_original_formulation else pred_cond_seg + pred = pred + self.seg_guidance_scale * shift + elif not self._is_seg_enabled(): + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + else: + shift = pred_cond - pred_uncond + shift_seg = pred_cond - 
pred_cond_seg + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + self.seg_guidance_scale * shift_seg + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred, {} + + @property + def is_conditional(self) -> bool: + return self._count_prepared == 1 or self._count_prepared == 3 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + if self._is_seg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + def _is_seg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self.seg_guidance_start * self._num_inference_steps) + skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps) + is_within_range = skip_start_step < self._step < skip_stop_step + + is_zero = math.isclose(self.seg_guidance_scale, 0.0) + + return is_within_range and not is_zero diff --git a/src/diffusers/guiders/tangential_classifier_free_guidance.py b/src/diffusers/guiders/tangential_classifier_free_guidance.py new file mode 100644 index 000000000000..b3187e526316 --- /dev/null +++ b/src/diffusers/guiders/tangential_classifier_free_guidance.py @@ -0,0 +1,143 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import torch + +from ..configuration_utils import register_to_config +from .guider_utils import BaseGuidance, rescale_noise_cfg + + +if TYPE_CHECKING: + from ..modular_pipelines.modular_pipeline import BlockState + + +class TangentialClassifierFreeGuidance(BaseGuidance): + """ + Tangential Classifier Free Guidance (TCFG): https://huggingface.co/papers/2503.18137 + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. 
Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + @register_to_config + def __init__( + self, + guidance_scale: float = 7.5, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + def prepare_inputs( + self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None + ) -> List["BlockState"]: + if input_fields is None: + input_fields = self._input_fields + + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batches.append(data_batch) + return data_batches + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if not self._is_tcfg_enabled(): + pred = pred_cond + else: + pred = normalized_guidance(pred_cond, pred_uncond, self.guidance_scale, self.use_original_formulation) + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred, {} + + @property + def is_conditional(self) -> bool: + return self._num_outputs_prepared == 1 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_tcfg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_tcfg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + +def normalized_guidance( + pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False +) -> torch.Tensor: + cond_dtype = pred_cond.dtype + preds = torch.stack([pred_cond, pred_uncond], dim=1).float() + preds = preds.flatten(2) + U, S, Vh = torch.linalg.svd(preds, full_matrices=False) + Vh_modified = Vh.clone() + Vh_modified[:, 1] = 0 + + uncond_flat = pred_uncond.reshape(pred_uncond.size(0), 1, -1).float() + x_Vh = torch.matmul(uncond_flat, Vh.transpose(-2, -1)) + x_Vh_V = torch.matmul(x_Vh, Vh_modified) + pred_uncond = x_Vh_V.reshape(pred_uncond.shape).to(cond_dtype) + + pred = pred_cond if use_original_formulation else pred_uncond + 
shift = pred_cond - pred_uncond + pred = pred + guidance_scale * shift + + return pred diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 365bed371864..525a0747da8b 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -20,5 +20,7 @@ from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache from .group_offloading import apply_group_offloading from .hooks import HookRegistry, ModelHook + from .layer_skip import LayerSkipConfig, apply_layer_skip from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast + from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig diff --git a/src/diffusers/hooks/_common.py b/src/diffusers/hooks/_common.py index 3be77dd4cedf..08f474fc1cc7 100644 --- a/src/diffusers/hooks/_common.py +++ b/src/diffusers/hooks/_common.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. +# Copyright 2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,10 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + +import torch + +from ..models.attention import FeedForward, LuminaFeedForward from ..models.attention_processor import Attention, MochiAttention _ATTENTION_CLASSES = (Attention, MochiAttention) +_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward) _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers") _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",) @@ -28,3 +34,10 @@ *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS, } ) + + +def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]: + for submodule_name, submodule in module.named_modules(): + if submodule_name == fqn: + return submodule + return None diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py new file mode 100644 index 000000000000..14e6c2f8881e --- /dev/null +++ b/src/diffusers/hooks/layer_skip.py @@ -0,0 +1,254 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
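+
+# Layer skipping is used by guidance methods such as skip layer guidance (SLG): hooks registered on selected
+# transformer blocks (or on their attention / feed-forward sublayers) bypass those layers, or apply dropout to
+# their outputs, during a forward pass of the denoiser.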
+ +import math +from dataclasses import asdict, dataclass +from typing import Callable, List, Optional + +import torch + +from ..utils import get_logger +from ..utils.torch_utils import unwrap_module +from ._common import ( + _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, + _ATTENTION_CLASSES, + _FEEDFORWARD_CLASSES, + _get_submodule_from_fqn, +) +from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry +from .hooks import HookRegistry, ModelHook + + +logger = get_logger(__name__) # pylint: disable=invalid-name + +_LAYER_SKIP_HOOK = "layer_skip_hook" + + +# Aryan/YiYi TODO: we need to make guider class a config mixin so I think this is not needed +# either remove or make it serializable +@dataclass +class LayerSkipConfig: + r""" + Configuration for skipping internal transformer blocks when executing a transformer model. + + Args: + indices (`List[int]`): + The indices of the layer to skip. This is typically the first layer in the transformer block. + fqn (`str`, defaults to `"auto"`): + The fully qualified name identifying the stack of transformer blocks. Typically, this is + `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`. + For automatic detection, set this to `"auto"`. "auto" only works on DiT models. For UNet models, you must + provide the correct fqn. + skip_attention (`bool`, defaults to `True`): + Whether to skip attention blocks. + skip_ff (`bool`, defaults to `True`): + Whether to skip feed-forward blocks. + skip_attention_scores (`bool`, defaults to `False`): + Whether to skip attention score computation in the attention blocks. This is equivalent to using `value` + projections as the output of scaled dot product attention. + dropout (`float`, defaults to `1.0`): + The dropout probability for dropping the outputs of the skipped layers. By default, this is set to `1.0`, + meaning that the outputs of the skipped layers are completely ignored. If set to `0.0`, the outputs of the + skipped layers are fully retained, which is equivalent to not skipping any layers. + """ + + indices: List[int] + fqn: str = "auto" + skip_attention: bool = True + skip_attention_scores: bool = False + skip_ff: bool = True + dropout: float = 1.0 + + def __post_init__(self): + if not (0 <= self.dropout <= 1): + raise ValueError(f"Expected `dropout` to be between 0.0 and 1.0, but got {self.dropout}.") + if not math.isclose(self.dropout, 1.0) and self.skip_attention_scores: + raise ValueError( + "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0." 
+ ) + + def to_dict(self): + return asdict(self) + + @staticmethod + def from_dict(data: dict) -> "LayerSkipConfig": + return LayerSkipConfig(**data) + + +class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode): + def __torch_function__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func is torch.nn.functional.scaled_dot_product_attention: + value = kwargs.get("value", None) + if value is None: + value = args[2] + return value + return func(*args, **kwargs) + + +class AttentionProcessorSkipHook(ModelHook): + def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False, dropout: float = 1.0): + self.skip_processor_output_fn = skip_processor_output_fn + self.skip_attention_scores = skip_attention_scores + self.dropout = dropout + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + if self.skip_attention_scores: + if not math.isclose(self.dropout, 1.0): + raise ValueError( + "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0." + ) + with AttentionScoreSkipFunctionMode(): + output = self.fn_ref.original_forward(*args, **kwargs) + else: + if math.isclose(self.dropout, 1.0): + output = self.skip_processor_output_fn(module, *args, **kwargs) + else: + output = self.fn_ref.original_forward(*args, **kwargs) + output = torch.nn.functional.dropout(output, p=self.dropout) + return output + + +class FeedForwardSkipHook(ModelHook): + def __init__(self, dropout: float): + super().__init__() + self.dropout = dropout + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + if math.isclose(self.dropout, 1.0): + output = kwargs.get("hidden_states", None) + if output is None: + output = kwargs.get("x", None) + if output is None and len(args) > 0: + output = args[0] + else: + output = self.fn_ref.original_forward(*args, **kwargs) + output = torch.nn.functional.dropout(output, p=self.dropout) + return output + + +class TransformerBlockSkipHook(ModelHook): + def __init__(self, dropout: float): + super().__init__() + self.dropout = dropout + + def initialize_hook(self, module): + self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__) + return module + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + if math.isclose(self.dropout, 1.0): + original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs) + if self._metadata.return_encoder_hidden_states_index is None: + output = original_hidden_states + else: + original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs( + "encoder_hidden_states", args, kwargs + ) + output = (original_hidden_states, original_encoder_hidden_states) + else: + output = self.fn_ref.original_forward(*args, **kwargs) + output = torch.nn.functional.dropout(output, p=self.dropout) + return output + + +def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None: + r""" + Apply layer skipping to internal layers of a transformer. + + Args: + module (`torch.nn.Module`): + The transformer model to which the layer skip hook should be applied. + config (`LayerSkipConfig`): + The configuration for the layer skip hook. 
+ + Example: + + ```python + >>> from diffusers import apply_layer_skip_hook, CogVideoXTransformer3DModel, LayerSkipConfig + + >>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) + >>> config = LayerSkipConfig(layer_index=[10, 20], fqn="transformer_blocks") + >>> apply_layer_skip_hook(transformer, config) + ``` + """ + _apply_layer_skip_hook(module, config) + + +def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None: + name = name or _LAYER_SKIP_HOOK + + if config.skip_attention and config.skip_attention_scores: + raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.") + if not math.isclose(config.dropout, 1.0) and config.skip_attention_scores: + raise ValueError( + "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0." + ) + + if config.fqn == "auto": + for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS: + if hasattr(module, identifier): + config.fqn = identifier + break + else: + raise ValueError( + "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid " + "`fqn` (fully qualified name) that identifies a stack of transformer blocks." + ) + + transformer_blocks = _get_submodule_from_fqn(module, config.fqn) + if transformer_blocks is None or not isinstance(transformer_blocks, torch.nn.ModuleList): + raise ValueError( + f"Could not find {config.fqn} in the provided module, or configured `fqn` (fully qualified name) does not identify " + f"a `torch.nn.ModuleList`. Please provide a valid `fqn` that identifies a stack of transformer blocks." + ) + if len(config.indices) == 0: + raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.") + + blocks_found = False + for i, block in enumerate(transformer_blocks): + if i not in config.indices: + continue + + blocks_found = True + + if config.skip_attention and config.skip_ff: + logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'") + registry = HookRegistry.check_if_exists_or_initialize(block) + hook = TransformerBlockSkipHook(config.dropout) + registry.register_hook(hook, name) + + elif config.skip_attention or config.skip_attention_scores: + for submodule_name, submodule in block.named_modules(): + if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention: + logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'") + output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn + registry = HookRegistry.check_if_exists_or_initialize(submodule) + hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores, config.dropout) + registry.register_hook(hook, name) + + if config.skip_ff: + for submodule_name, submodule in block.named_modules(): + if isinstance(submodule, _FEEDFORWARD_CLASSES): + logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'") + registry = HookRegistry.check_if_exists_or_initialize(submodule) + hook = FeedForwardSkipHook(config.dropout) + registry.register_hook(hook, name) + + if not blocks_found: + raise ValueError( + f"Could not find any transformer blocks matching the provided indices {config.indices} and " + f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness." 
+ ) diff --git a/src/diffusers/hooks/smoothed_energy_guidance_utils.py b/src/diffusers/hooks/smoothed_energy_guidance_utils.py new file mode 100644 index 000000000000..622f60764762 --- /dev/null +++ b/src/diffusers/hooks/smoothed_energy_guidance_utils.py @@ -0,0 +1,167 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import asdict, dataclass +from typing import List, Optional + +import torch +import torch.nn.functional as F + +from ..utils import get_logger +from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _get_submodule_from_fqn +from .hooks import HookRegistry, ModelHook + + +logger = get_logger(__name__) # pylint: disable=invalid-name + +_SMOOTHED_ENERGY_GUIDANCE_HOOK = "smoothed_energy_guidance_hook" + + +@dataclass +class SmoothedEnergyGuidanceConfig: + r""" + Configuration for skipping internal transformer blocks when executing a transformer model. + + Args: + indices (`List[int]`): + The indices of the layer to skip. This is typically the first layer in the transformer block. + fqn (`str`, defaults to `"auto"`): + The fully qualified name identifying the stack of transformer blocks. Typically, this is + `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`. + For automatic detection, set this to `"auto"`. "auto" only works on DiT models. For UNet models, you must + provide the correct fqn. + _query_proj_identifiers (`List[str]`, defaults to `None`): + The identifiers for the query projection layers. Typically, these are `to_q`, `query`, or `q_proj`. If + `None`, `to_q` is used by default. 
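+
+    Example (illustrative; the indices below are placeholders rather than recommended values):
+
+    ```python
+    >>> from diffusers.hooks import SmoothedEnergyGuidanceConfig
+
+    >>> # Target the self-attention query projections of transformer blocks 7, 8 and 9.
+    >>> config = SmoothedEnergyGuidanceConfig(indices=[7, 8, 9], fqn="transformer_blocks")
+    ```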
+ """ + + indices: List[int] + fqn: str = "auto" + _query_proj_identifiers: List[str] = None + + def to_dict(self): + return asdict(self) + + @staticmethod + def from_dict(data: dict) -> "SmoothedEnergyGuidanceConfig": + return SmoothedEnergyGuidanceConfig(**data) + + +class SmoothedEnergyGuidanceHook(ModelHook): + def __init__(self, blur_sigma: float = 1.0, blur_threshold_inf: float = 9999.9) -> None: + super().__init__() + self.blur_sigma = blur_sigma + self.blur_threshold_inf = blur_threshold_inf + + def post_forward(self, module: torch.nn.Module, output: torch.Tensor) -> torch.Tensor: + # Copied from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L172C31-L172C102 + kernel_size = math.ceil(6 * self.blur_sigma) + 1 - math.ceil(6 * self.blur_sigma) % 2 + smoothed_output = _gaussian_blur_2d(output, kernel_size, self.blur_sigma, self.blur_threshold_inf) + return smoothed_output + + +def _apply_smoothed_energy_guidance_hook( + module: torch.nn.Module, config: SmoothedEnergyGuidanceConfig, blur_sigma: float, name: Optional[str] = None +) -> None: + name = name or _SMOOTHED_ENERGY_GUIDANCE_HOOK + + if config.fqn == "auto": + for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS: + if hasattr(module, identifier): + config.fqn = identifier + break + else: + raise ValueError( + "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid " + "`fqn` (fully qualified name) that identifies a stack of transformer blocks." + ) + + if config._query_proj_identifiers is None: + config._query_proj_identifiers = ["to_q"] + + transformer_blocks = _get_submodule_from_fqn(module, config.fqn) + blocks_found = False + for i, block in enumerate(transformer_blocks): + if i not in config.indices: + continue + + blocks_found = True + + for submodule_name, submodule in block.named_modules(): + if not isinstance(submodule, _ATTENTION_CLASSES) or submodule.is_cross_attention: + continue + for identifier in config._query_proj_identifiers: + query_proj = getattr(submodule, identifier, None) + if query_proj is None or not isinstance(query_proj, torch.nn.Linear): + continue + logger.debug( + f"Registering smoothed energy guidance hook on {config.fqn}.{i}.{submodule_name}.{identifier}" + ) + registry = HookRegistry.check_if_exists_or_initialize(query_proj) + hook = SmoothedEnergyGuidanceHook(blur_sigma) + registry.register_hook(hook, name) + + if not blocks_found: + raise ValueError( + f"Could not find any transformer blocks matching the provided indices {config.indices} and " + f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness." + ) + + +# Modified from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L71 +def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma_threshold_inf: float) -> torch.Tensor: + """ + This implementation assumes that the input query is for visual (image/videos) tokens to apply the 2D gaussian blur. + However, some models use joint text-visual token attention for which this may not be suitable. Additionally, this + implementation also assumes that the visual tokens come from a square image/video. In practice, despite these + assumptions, applying the 2D square gaussian blur on the query projections generates reasonable results for + Smoothed Energy Guidance. 
+ + SEG is only supported as an experimental prototype feature for now, so the implementation may be modified in the + future without warning or guarantee of reproducibility. + """ + assert query.ndim == 3 + + is_inf = sigma > sigma_threshold_inf + batch_size, seq_len, embed_dim = query.shape + + seq_len_sqrt = int(math.sqrt(seq_len)) + num_square_tokens = seq_len_sqrt * seq_len_sqrt + query_slice = query[:, :num_square_tokens, :] + query_slice = query_slice.permute(0, 2, 1) + query_slice = query_slice.reshape(batch_size, embed_dim, seq_len_sqrt, seq_len_sqrt) + + if is_inf: + kernel_size = min(kernel_size, seq_len_sqrt - (seq_len_sqrt % 2 - 1)) + kernel_size_half = (kernel_size - 1) / 2 + + x = torch.linspace(-kernel_size_half, kernel_size_half, steps=kernel_size) + pdf = torch.exp(-0.5 * (x / sigma).pow(2)) + kernel1d = pdf / pdf.sum() + kernel1d = kernel1d.to(query) + kernel2d = torch.matmul(kernel1d[:, None], kernel1d[None, :]) + kernel2d = kernel2d.expand(embed_dim, 1, kernel2d.shape[0], kernel2d.shape[1]) + + padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] + query_slice = F.pad(query_slice, padding, mode="reflect") + query_slice = F.conv2d(query_slice, kernel2d, groups=embed_dim) + else: + query_slice[:] = query_slice.mean(dim=(-2, -1), keepdim=True) + + query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens) + query_slice = query_slice.permute(0, 2, 1) + query[:, :num_square_tokens, :] = query_slice.clone() + + return query diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 84c6d9f32c66..335d7e623f07 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -84,6 +84,7 @@ def text_encoder_attn_modules(text_encoder): "IPAdapterMixin", "FluxIPAdapterMixin", "SD3IPAdapterMixin", + "ModularIPAdapterMixin", ] _import_structure["peft"] = ["PeftAdapterMixin"] @@ -101,6 +102,7 @@ def text_encoder_attn_modules(text_encoder): from .ip_adapter import ( FluxIPAdapterMixin, IPAdapterMixin, + ModularIPAdapterMixin, SD3IPAdapterMixin, ) from .lora_pipeline import ( diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index 521cb3b6fddf..e05d53687a24 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -354,6 +354,256 @@ def unload_ip_adapter(self): self.unet.set_attn_processor(attn_procs) +class ModularIPAdapterMixin: + """Mixin for handling IP Adapters.""" + + @validate_hf_hub_args + def load_ip_adapter( + self, + pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]], + subfolder: Union[str, List[str]], + weight_name: Union[str, List[str]], + **kwargs, + ): + """ + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + subfolder (`str` or `List[str]`): + The subfolder location of a model file within a larger model repository on the Hub or locally. If a + list is passed, it should have the same length as `weight_name`. + weight_name (`str` or `List[str]`): + The name of the weight file to load. 
If a list is passed, it should have the same length as + `subfolder`. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + """ + + # handle the list inputs for multiple IP Adapters + if not isinstance(weight_name, list): + weight_name = [weight_name] + + if not isinstance(pretrained_model_name_or_path_or_dict, list): + pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict] + if len(pretrained_model_name_or_path_or_dict) == 1: + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name) + + if not isinstance(subfolder, list): + subfolder = [subfolder] + if len(subfolder) == 1: + subfolder = subfolder * len(weight_name) + + if len(weight_name) != len(pretrained_model_name_or_path_or_dict): + raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.") + + if len(weight_name) != len(subfolder): + raise ValueError("`weight_name` and `subfolder` must have the same length.") + + # Load the main state dict first. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. 
Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." + ) + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + state_dicts = [] + for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip( + pretrained_model_name_or_path_or_dict, weight_name, subfolder + ): + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + if weight_name.endswith(".safetensors"): + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(model_file, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = load_state_dict(model_file) + else: + state_dict = pretrained_model_name_or_path_or_dict + + keys = list(state_dict.keys()) + if "image_proj" not in keys and "ip_adapter" not in keys: + raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.") + + state_dicts.append(state_dict) + + unet_name = getattr(self, "unet_name", "unet") + unet = getattr(self, unet_name) + unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) + + extra_loras = unet._load_ip_adapter_loras(state_dicts) + if extra_loras != {}: + if not USE_PEFT_BACKEND: + logger.warning("PEFT backend is required to load these weights.") + else: + # apply the IP Adapter Face ID LoRA weights + peft_config = getattr(unet, "peft_config", {}) + for k, lora in extra_loras.items(): + if f"faceid_{k}" not in peft_config: + self.load_lora_weights(lora, adapter_name=f"faceid_{k}") + self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0]) + + def set_ip_adapter_scale(self, scale): + """ + Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for + granular control over each IP-Adapter behavior. A config can be a float or a dictionary. + + Example: + + ```py + # To use original IP-Adapter + scale = 1.0 + pipeline.set_ip_adapter_scale(scale) + + # To use style block only + scale = { + "up": {"block_0": [0.0, 1.0, 0.0]}, + } + pipeline.set_ip_adapter_scale(scale) + + # To use style+layout blocks + scale = { + "down": {"block_2": [0.0, 1.0]}, + "up": {"block_0": [0.0, 1.0, 0.0]}, + } + pipeline.set_ip_adapter_scale(scale) + + # To use style and layout from 2 reference images + scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}] + pipeline.set_ip_adapter_scale(scales) + ``` + """ + unet_name = getattr(self, "unet_name", "unet") + unet = getattr(self, unet_name) + if not isinstance(scale, list): + scale = [scale] + scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0) + + for attn_name, attn_processor in unet.attn_processors.items(): + if isinstance( + attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor) + ): + if len(scale_configs) != len(attn_processor.scale): + raise ValueError( + f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter." 
+ ) + elif len(scale_configs) == 1: + scale_configs = scale_configs * len(attn_processor.scale) + for i, scale_config in enumerate(scale_configs): + if isinstance(scale_config, dict): + for k, s in scale_config.items(): + if attn_name.startswith(k): + attn_processor.scale[i] = s + else: + attn_processor.scale[i] = scale_config + + def unload_ip_adapter(self): + """ + Unloads the IP Adapter weights + + Examples: + + ```python + >>> # Assuming `pipeline` is already loaded with the IP Adapter weights. + >>> pipeline.unload_ip_adapter() + >>> ... + ``` + """ + + # remove hidden encoder + if self.unet is None: + return + + self.unet.encoder_hid_proj = None + self.unet.config.encoder_hid_dim_type = None + + # Kolors: restore `encoder_hid_proj` with `text_encoder_hid_proj` + if hasattr(self.unet, "text_encoder_hid_proj") and self.unet.text_encoder_hid_proj is not None: + self.unet.encoder_hid_proj = self.unet.text_encoder_hid_proj + self.unet.text_encoder_hid_proj = None + self.unet.config.encoder_hid_dim_type = "text_proj" + + # restore original Unet attention processors layers + attn_procs = {} + for name, value in self.unet.attn_processors.items(): + attn_processor_class = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor() + ) + attn_procs[name] = ( + attn_processor_class + if isinstance( + value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor) + ) + else value.__class__() + ) + self.unet.set_attn_processor(attn_procs) + + class FluxIPAdapterMixin: """Mixin for handling Flux IP Adapters.""" diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index cd4738cfa03c..412c05779492 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -25,7 +25,6 @@ from huggingface_hub import model_info from huggingface_hub.constants import HF_HUB_OFFLINE -from ..hooks.group_offloading import _is_group_offload_enabled, _maybe_remove_and_reapply_group_offloading from ..models.modeling_utils import ModelMixin, load_state_dict from ..utils import ( USE_PEFT_BACKEND, @@ -331,6 +330,8 @@ def _load_lora_into_text_encoder( hotswap: bool = False, metadata=None, ): + from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading + if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -442,6 +443,8 @@ def _func_optionally_disable_offloading(_pipeline): tuple: A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` or `is_group_offload` is True. 
""" + from ..hooks.group_offloading import _is_group_offload_enabled + is_model_cpu_offload = False is_sequential_cpu_offload = False is_group_offload = False diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 4ade3374d80e..393c8ee27d05 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -22,7 +22,6 @@ import safetensors import torch -from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading from ..utils import ( MIN_PEFT_VERSION, USE_PEFT_BACKEND, @@ -164,6 +163,8 @@ def load_lora_adapter( from peft import inject_adapter_in_model, set_peft_model_state_dict from peft.tuners.tuners_utils import BaseTunerLayer + from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading + cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) @@ -695,6 +696,7 @@ def unload_lora(self): if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for `unload_lora()`.") + from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading from ..utils import recurse_remove_peft_layers recurse_remove_peft_layers(self) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index c9b6a7d7d862..3546497c195b 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -22,7 +22,6 @@ import torch.nn.functional as F from huggingface_hub.utils import validate_hf_hub_args -from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading from ..models.embeddings import ( ImageProjection, IPAdapterFaceIDImageProjection, @@ -132,6 +131,8 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict ) ``` """ + from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading + cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py new file mode 100644 index 000000000000..bf34eed28b8c --- /dev/null +++ b/src/diffusers/modular_pipelines/__init__.py @@ -0,0 +1,84 @@ +from typing import TYPE_CHECKING + +from ..utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +# These modules contain pipelines from multiple libraries/frameworks +_dummy_objects = {} +_import_structure = {} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_pt_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_pt_objects)) +else: + _import_structure["modular_pipeline"] = [ + "ModularPipelineBlocks", + "ModularPipeline", + "PipelineBlock", + "AutoPipelineBlocks", + "SequentialPipelineBlocks", + "LoopSequentialPipelineBlocks", + "PipelineState", + "BlockState", + ] + _import_structure["modular_pipeline_utils"] = [ + "ComponentSpec", + "ConfigSpec", + "InputParam", + "OutputParam", + "InsertableDict", + ] + _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"] + _import_structure["components_manager"] = ["ComponentsManager"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from 
..utils.dummy_pt_objects import * # noqa F403 + else: + from .components_manager import ComponentsManager + from .modular_pipeline import ( + AutoPipelineBlocks, + BlockState, + LoopSequentialPipelineBlocks, + ModularPipeline, + ModularPipelineBlocks, + PipelineBlock, + PipelineState, + SequentialPipelineBlocks, + ) + from .modular_pipeline_utils import ( + ComponentSpec, + ConfigSpec, + InputParam, + InsertableDict, + OutputParam, + ) + from .stable_diffusion_xl import ( + StableDiffusionXLAutoBlocks, + StableDiffusionXLModularPipeline, + ) +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py new file mode 100644 index 000000000000..08e6d80fefd2 --- /dev/null +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -0,0 +1,1046 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import time +from collections import OrderedDict +from itertools import combinations +from typing import Any, Dict, List, Optional, Union + +import torch + +from ..hooks import ModelHook +from ..utils import ( + is_accelerate_available, + logging, +) + + +if is_accelerate_available(): + from accelerate.hooks import add_hook_to_module, remove_hook_from_module + from accelerate.state import PartialState + from accelerate.utils import send_to_device + from accelerate.utils.memory import clear_device_cache + from accelerate.utils.modeling import convert_file_size_to_int + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class CustomOffloadHook(ModelHook): + """ + A hook that offloads a model on the CPU until its forward pass is called. It ensures the model and its inputs are + on the given device. Optionally offloads other models to the CPU before the forward pass is called. + + Args: + execution_device(`str`, `int` or `torch.device`, *optional*): + The device on which the model should be executed. Will default to the MPS device if it's available, then + GPU 0 if there is a GPU, and finally to the CPU. + """ + + no_grad = False + + def __init__( + self, + execution_device: Optional[Union[str, int, torch.device]] = None, + other_hooks: Optional[List["UserCustomOffloadHook"]] = None, + offload_strategy: Optional["AutoOffloadStrategy"] = None, + ): + self.execution_device = execution_device if execution_device is not None else PartialState().default_device + self.other_hooks = other_hooks + self.offload_strategy = offload_strategy + self.model_id = None + + def set_strategy(self, offload_strategy: "AutoOffloadStrategy"): + self.offload_strategy = offload_strategy + + def add_other_hook(self, hook: "UserCustomOffloadHook"): + """ + Add a hook to the list of hooks to consider for offloading. 
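+        These hooks are treated as offload candidates: when this hook's model is moved to the execution device, the
+        configured offload strategy may move some of their models back to the CPU first.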
+ """ + if self.other_hooks is None: + self.other_hooks = [] + self.other_hooks.append(hook) + + def init_hook(self, module): + return module.to("cpu") + + def pre_forward(self, module, *args, **kwargs): + if module.device != self.execution_device: + if self.other_hooks is not None: + hooks_to_offload = [hook for hook in self.other_hooks if hook.model.device == self.execution_device] + # offload all other hooks + start_time = time.perf_counter() + if self.offload_strategy is not None: + hooks_to_offload = self.offload_strategy( + hooks=hooks_to_offload, + model_id=self.model_id, + model=module, + execution_device=self.execution_device, + ) + end_time = time.perf_counter() + logger.info( + f" time taken to apply offload strategy for {self.model_id}: {(end_time - start_time):.2f} seconds" + ) + + for hook in hooks_to_offload: + logger.info( + f"moving {self.model_id} to {self.execution_device}, offloading {hook.model_id} to cpu" + ) + hook.offload() + + if hooks_to_offload: + clear_device_cache() + module.to(self.execution_device) + return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device) + + +class UserCustomOffloadHook: + """ + A simple hook grouping a model and a `CustomOffloadHook`, which provides easy APIs for to call the init method of + the hook or remove it entirely. + """ + + def __init__(self, model_id, model, hook): + self.model_id = model_id + self.model = model + self.hook = hook + + def offload(self): + self.hook.init_hook(self.model) + + def attach(self): + add_hook_to_module(self.model, self.hook) + self.hook.model_id = self.model_id + + def remove(self): + remove_hook_from_module(self.model) + self.hook.model_id = None + + def add_other_hook(self, hook: "UserCustomOffloadHook"): + self.hook.add_other_hook(hook) + + +def custom_offload_with_hook( + model_id: str, + model: torch.nn.Module, + execution_device: Union[str, int, torch.device] = None, + offload_strategy: Optional["AutoOffloadStrategy"] = None, +): + hook = CustomOffloadHook(execution_device=execution_device, offload_strategy=offload_strategy) + user_hook = UserCustomOffloadHook(model_id=model_id, model=model, hook=hook) + user_hook.attach() + return user_hook + + +# this is the class that user can customize to implement their own offload strategy +class AutoOffloadStrategy: + """ + Offload strategy that should be used with `CustomOffloadHook` to automatically offload models to the CPU based on + the available memory on the device. + """ + + # YiYi TODO: instead of memory_reserve_margin, we should let user set the maximum_total_models_size to keep on device + # the actual memory usage would be higher. 
But it's simpler this way, and can be tested + def __init__(self, memory_reserve_margin="3GB"): + self.memory_reserve_margin = convert_file_size_to_int(memory_reserve_margin) + + def __call__(self, hooks, model_id, model, execution_device): + if len(hooks) == 0: + return [] + + current_module_size = model.get_memory_footprint() + + mem_on_device = torch.cuda.mem_get_info(execution_device.index)[0] + mem_on_device = mem_on_device - self.memory_reserve_margin + if current_module_size < mem_on_device: + return [] + + min_memory_offload = current_module_size - mem_on_device + logger.info(f" search for models to offload in order to free up {min_memory_offload / 1024**3:.2f} GB memory") + + # exlucde models that's not currently loaded on the device + module_sizes = dict( + sorted( + {hook.model_id: hook.model.get_memory_footprint() for hook in hooks}.items(), + key=lambda x: x[1], + reverse=True, + ) + ) + + # YiYi/Dhruv TODO: sort smallest to largest, and offload in that order we would tend to keep the larger models on GPU more often + def search_best_candidate(module_sizes, min_memory_offload): + """ + search the optimal combination of models to offload to cpu, given a dictionary of module sizes and a + minimum memory offload size. the combination of models should add up to the smallest modulesize that is + larger than `min_memory_offload` + """ + model_ids = list(module_sizes.keys()) + best_candidate = None + best_size = float("inf") + for r in range(1, len(model_ids) + 1): + for candidate_model_ids in combinations(model_ids, r): + candidate_size = sum( + module_sizes[candidate_model_id] for candidate_model_id in candidate_model_ids + ) + if candidate_size < min_memory_offload: + continue + else: + if best_candidate is None or candidate_size < best_size: + best_candidate = candidate_model_ids + best_size = candidate_size + + return best_candidate + + best_offload_model_ids = search_best_candidate(module_sizes, min_memory_offload) + + if best_offload_model_ids is None: + # if no combination is found, meaning that we cannot meet the memory requirement, offload all models + logger.warning("no combination of models to offload to cpu is found, offloading all models") + hooks_to_offload = hooks + else: + hooks_to_offload = [hook for hook in hooks if hook.model_id in best_offload_model_ids] + + return hooks_to_offload + + +# utils for display component info in a readable format +# TODO: move to a different file +def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]: + """Summarizes a dictionary by finding common prefixes that share the same value. 
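To make the candidate-search logic above concrete, here is a small standalone sketch that mirrors the nested `search_best_candidate` helper: given per-model footprints and a minimum amount of memory to free, it returns the smallest combination of models that covers the deficit. The model names and sizes below are made up for illustration.

```py
from itertools import combinations


def smallest_combination_to_offload(module_sizes, min_memory_offload):
    # Mirror of the nested helper above: try every combination of models and keep
    # the one with the smallest total size that still frees enough memory.
    best_candidate, best_size = None, float("inf")
    model_ids = list(module_sizes.keys())
    for r in range(1, len(model_ids) + 1):
        for candidate in combinations(model_ids, r):
            candidate_size = sum(module_sizes[m] for m in candidate)
            if candidate_size >= min_memory_offload and candidate_size < best_size:
                best_candidate, best_size = candidate, candidate_size
    return best_candidate


# Hypothetical footprints in GB; we need to free 6 GB.
sizes = {"unet": 10, "text_encoder_2": 5, "text_encoder": 2, "vae": 1}
print(smallest_combination_to_offload(sizes, 6))  # ('text_encoder_2', 'vae') -> exactly 6 GB
```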
+ + For a dictionary with dot-separated keys like: { + 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6], + 'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6], + 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3], + } + + Returns a dictionary where keys are the shortest common prefixes and values are their shared values: { + 'down_blocks': [0.6], 'up_blocks': [0.3] + } + """ + # First group by values - convert lists to tuples to make them hashable + value_to_keys = {} + for key, value in d.items(): + value_tuple = tuple(value) if isinstance(value, list) else value + if value_tuple not in value_to_keys: + value_to_keys[value_tuple] = [] + value_to_keys[value_tuple].append(key) + + def find_common_prefix(keys: List[str]) -> str: + """Find the shortest common prefix among a list of dot-separated keys.""" + if not keys: + return "" + if len(keys) == 1: + return keys[0] + + # Split all keys into parts + key_parts = [k.split(".") for k in keys] + + # Find how many initial parts are common + common_length = 0 + for parts in zip(*key_parts): + if len(set(parts)) == 1: # All parts at this position are the same + common_length += 1 + else: + break + + if common_length == 0: + return "" + + # Return the common prefix + return ".".join(key_parts[0][:common_length]) + + # Create summary by finding common prefixes for each value group + summary = {} + for value_tuple, keys in value_to_keys.items(): + prefix = find_common_prefix(keys) + if prefix: # Only add if we found a common prefix + # Convert tuple back to list if it was originally a list + value = list(value_tuple) if isinstance(d[keys[0]], list) else value_tuple + summary[prefix] = value + else: + summary[""] = value # Use empty string if no common prefix + + return summary + + +class ComponentsManager: + """ + A central registry and management system for model components across multiple pipelines. + + [`ComponentsManager`] provides a unified way to register, track, and reuse model components (like UNet, VAE, text + encoders, etc.) across different modular pipelines. It includes features for duplicate detection, memory + management, and component organization. + + + + This is an experimental feature and is likely to change in the future. + + + + Example: + ```python + from diffusers import ComponentsManager + + # Create a components manager + cm = ComponentsManager() + + # Add components + cm.add("unet", unet_model, collection="sdxl") + cm.add("vae", vae_model, collection="sdxl") + + # Enable auto offloading + cm.enable_auto_cpu_offload(device="cuda") + + # Retrieve components + unet = cm.get_one(name="unet", collection="sdxl") + ``` + """ + + _available_info_fields = [ + "model_id", + "added_time", + "collection", + "class_name", + "size_gb", + "adapters", + "has_hook", + "execution_device", + "ip_adapter", + ] + + def __init__(self): + self.components = OrderedDict() + # YiYi TODO: can remove once confirm we don't need this in mellon + self.added_time = OrderedDict() # Store when components were added + self.collections = OrderedDict() # collection_name -> set of component_names + self.model_hooks = None + self._auto_offload_enabled = False + + def _lookup_ids( + self, + name: Optional[str] = None, + collection: Optional[str] = None, + load_id: Optional[str] = None, + components: Optional[OrderedDict] = None, + ): + """ + Lookup component_ids by name, collection, or load_id. Does not support pattern matching. 
Returns a set of + component_ids + """ + if components is None: + components = self.components + + if name: + ids_by_name = set() + for component_id, component in components.items(): + comp_name = self._id_to_name(component_id) + if comp_name == name: + ids_by_name.add(component_id) + else: + ids_by_name = set(components.keys()) + if collection: + ids_by_collection = set() + for component_id, component in components.items(): + if component_id in self.collections[collection]: + ids_by_collection.add(component_id) + else: + ids_by_collection = set(components.keys()) + if load_id: + ids_by_load_id = set() + for name, component in components.items(): + if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id: + ids_by_load_id.add(name) + else: + ids_by_load_id = set(components.keys()) + + ids = ids_by_name.intersection(ids_by_collection).intersection(ids_by_load_id) + return ids + + @staticmethod + def _id_to_name(component_id: str): + return "_".join(component_id.split("_")[:-1]) + + def add(self, name: str, component: Any, collection: Optional[str] = None): + """ + Add a component to the ComponentsManager. + + Args: + name (str): The name of the component + component (Any): The component to add + collection (Optional[str]): The collection to add the component to + + Returns: + str: The unique component ID, which is generated as "{name}_{id(component)}" where + id(component) is Python's built-in unique identifier for the object + """ + component_id = f"{name}_{id(component)}" + + # check for duplicated components + for comp_id, comp in self.components.items(): + if comp == component: + comp_name = self._id_to_name(comp_id) + if comp_name == name: + logger.warning(f"ComponentsManager: component '{name}' already exists as '{comp_id}'") + component_id = comp_id + break + else: + logger.warning( + f"ComponentsManager: adding component '{name}' as '{component_id}', but it is duplicate of '{comp_id}'" + f"To remove a duplicate, call `components_manager.remove('')`." + ) + + # check for duplicated load_id and warn (we do not delete for you) + if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null": + components_with_same_load_id = self._lookup_ids(load_id=component._diffusers_load_id) + components_with_same_load_id = [id for id in components_with_same_load_id if id != component_id] + + if components_with_same_load_id: + existing = ", ".join(components_with_same_load_id) + logger.warning( + f"ComponentsManager: adding component '{component_id}', but it has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. " + f"To remove a duplicate, call `components_manager.remove('')`." 
+ ) + + # add component to components manager + self.components[component_id] = component + self.added_time[component_id] = time.time() + + if collection: + if collection not in self.collections: + self.collections[collection] = set() + if component_id not in self.collections[collection]: + comp_ids_in_collection = self._lookup_ids(name=name, collection=collection) + for comp_id in comp_ids_in_collection: + logger.warning( + f"ComponentsManager: removing existing {name} from collection '{collection}': {comp_id}" + ) + self.remove(comp_id) + self.collections[collection].add(component_id) + logger.info( + f"ComponentsManager: added component '{name}' in collection '{collection}': {component_id}" + ) + else: + logger.info(f"ComponentsManager: added component '{name}' as '{component_id}'") + + if self._auto_offload_enabled: + self.enable_auto_cpu_offload(self._auto_offload_device) + + return component_id + + def remove(self, component_id: str = None): + """ + Remove a component from the ComponentsManager. + + Args: + component_id (str): The ID of the component to remove + """ + if component_id not in self.components: + logger.warning(f"Component '{component_id}' not found in ComponentsManager") + return + + component = self.components.pop(component_id) + self.added_time.pop(component_id) + + for collection in self.collections: + if component_id in self.collections[collection]: + self.collections[collection].remove(component_id) + + if self._auto_offload_enabled: + self.enable_auto_cpu_offload(self._auto_offload_device) + else: + if isinstance(component, torch.nn.Module): + component.to("cpu") + del component + import gc + + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # YiYi TODO: rename to search_components for now, may remove this method + def search_components( + self, + names: Optional[str] = None, + collection: Optional[str] = None, + load_id: Optional[str] = None, + return_dict_with_names: bool = True, + ): + """ + Search components by name with simple pattern matching. Optionally filter by collection or load_id. 
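A minimal sketch of how registration behaves, using toy `torch.nn.Module` stand-ins in place of real UNet/VAE components and a made-up collection name:

```py
import torch
from diffusers.modular_pipelines import ComponentsManager

cm = ComponentsManager()
unet = torch.nn.Linear(4, 4)  # stand-in for a real UNet
vae = torch.nn.Linear(4, 4)   # stand-in for a real VAE

# Component IDs are "{name}_{id(object)}", so the same name can live in several collections
unet_id = cm.add("unet", unet, collection="sdxl")
vae_id = cm.add("vae", vae, collection="sdxl")

# Re-adding the same object under the same name warns and returns the existing ID
assert cm.add("unet", unet) == unet_id

# Adding a *different* object under an existing name in the same collection
# warns and removes the previously registered "unet" for that collection
new_unet = torch.nn.Linear(4, 4)
cm.add("unet", new_unet, collection="sdxl")

cm.remove(vae_id)  # removal is always by component ID
```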
+ + Args: + names: Component name(s) or pattern(s) + Patterns: + - "unet" : match any component with base name "unet" (e.g., unet_123abc) + - "!unet" : everything except components with base name "unet" + - "unet*" : anything with base name starting with "unet" + - "!unet*" : anything with base name NOT starting with "unet" + - "*unet*" : anything with base name containing "unet" + - "!*unet*" : anything with base name NOT containing "unet" + - "refiner|vae|unet" : anything with base name exactly matching "refiner", "vae", or "unet" + - "!refiner|vae|unet" : anything with base name NOT exactly matching "refiner", "vae", or "unet" + - "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae" + collection: Optional collection to filter by + load_id: Optional load_id to filter by + return_dict_with_names: + If True, returns a dictionary with component names as keys, throw an error if + multiple components with the same name are found If False, returns a dictionary + with component IDs as keys + + Returns: + Dictionary mapping component names to components if return_dict_with_names=True, or a dictionary mapping + component IDs to components if return_dict_with_names=False + """ + + # select components based on collection and load_id filters + selected_ids = self._lookup_ids(collection=collection, load_id=load_id) + components = {k: self.components[k] for k in selected_ids} + + def get_return_dict(components, return_dict_with_names): + """ + Create a dictionary mapping component names to components if return_dict_with_names=True, or a dictionary + mapping component IDs to components if return_dict_with_names=False, throw an error if duplicate component + names are found when return_dict_with_names=True + """ + if return_dict_with_names: + dict_to_return = {} + for comp_id, comp in components.items(): + comp_name = self._id_to_name(comp_id) + if comp_name in dict_to_return: + raise ValueError( + f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys" + ) + dict_to_return[comp_name] = comp + return dict_to_return + else: + return components + + # if no names are provided, return the filtered components as it is + if names is None: + return get_return_dict(components, return_dict_with_names) + + # if names is not a string, raise an error + elif not isinstance(names, str): + raise ValueError(f"Invalid type for `names: {type(names)}, only support string") + + # Create mapping from component_id to base_name for components to be used for pattern matching + base_names = {comp_id: self._id_to_name(comp_id) for comp_id in components.keys()} + + # Helper function to check if a component matches a pattern based on its base name + def matches_pattern(component_id, pattern, exact_match=False): + """ + Helper function to check if a component matches a pattern based on its base name. 
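Continuing with toy components, the pattern syntax documented above can be exercised like this (all names below are hypothetical):

```py
import torch
from diffusers.modular_pipelines import ComponentsManager

cm = ComponentsManager()
for name in ("unet", "vae", "text_encoder", "text_encoder_2"):
    cm.add(name, torch.nn.Linear(4, 4), collection="sdxl")  # toy stand-ins

cm.search_components("unet")                  # exact base-name match
cm.search_components("text_encoder*")         # prefix match
cm.search_components("!vae")                  # everything except "vae"
cm.search_components("unet|vae")              # OR of exact names
matches = cm.search_components("*encoder*", collection="sdxl")  # contains-match, filtered by collection
print(list(matches))  # ['text_encoder', 'text_encoder_2']
```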
+ + Args: + component_id: The component ID to check + pattern: The pattern to match against + exact_match: If True, only exact matches to base_name are considered + """ + base_name = base_names[component_id] + + # Exact match with base name + if exact_match: + return pattern == base_name + + # Prefix match (ends with *) + elif pattern.endswith("*"): + prefix = pattern[:-1] + return base_name.startswith(prefix) + + # Contains match (starts with *) + elif pattern.startswith("*"): + search = pattern[1:-1] if pattern.endswith("*") else pattern[1:] + return search in base_name + + # Exact match (no wildcards) + else: + return pattern == base_name + + # Check if this is a "not" pattern + is_not_pattern = names.startswith("!") + if is_not_pattern: + names = names[1:] # Remove the ! prefix + + # Handle OR patterns (containing |) + if "|" in names: + terms = names.split("|") + matches = {} + + for comp_id, comp in components.items(): + # For OR patterns with exact names (no wildcards), we do exact matching on base names + exact_match = all(not (term.startswith("*") or term.endswith("*")) for term in terms) + + # Check if any of the terms match this component + should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms) + + # Flip the decision if this is a NOT pattern + if is_not_pattern: + should_include = not should_include + + if should_include: + matches[comp_id] = comp + + log_msg = "NOT " if is_not_pattern else "" + match_type = "exactly matching" if exact_match else "matching any of patterns" + logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}") + + # Try exact match with a base name + elif any(names == base_name for base_name in base_names.values()): + # Find all components with this base name + matches = { + comp_id: comp + for comp_id, comp in components.items() + if (base_names[comp_id] == names) != is_not_pattern + } + + if is_not_pattern: + logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}") + else: + logger.info(f"Getting components with base name '{names}': {list(matches.keys())}") + + # Prefix match (ends with *) + elif names.endswith("*"): + prefix = names[:-1] + matches = { + comp_id: comp + for comp_id, comp in components.items() + if base_names[comp_id].startswith(prefix) != is_not_pattern + } + if is_not_pattern: + logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}") + else: + logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}") + + # Contains match (starts with *) + elif names.startswith("*"): + search = names[1:-1] if names.endswith("*") else names[1:] + matches = { + comp_id: comp + for comp_id, comp in components.items() + if (search in base_names[comp_id]) != is_not_pattern + } + if is_not_pattern: + logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}") + else: + logger.info(f"Getting components containing '{search}': {list(matches.keys())}") + + # Substring match (no wildcards, but not an exact component name) + elif any(names in base_name for base_name in base_names.values()): + matches = { + comp_id: comp + for comp_id, comp in components.items() + if (names in base_names[comp_id]) != is_not_pattern + } + if is_not_pattern: + logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}") + else: + logger.info(f"Getting components containing '{names}': {list(matches.keys())}") + + else: + raise ValueError(f"Component or pattern '{names}' not found 
in ComponentsManager") + + if not matches: + raise ValueError(f"No components found matching pattern '{names}'") + + return get_return_dict(matches, return_dict_with_names) + + def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda", memory_reserve_margin="3GB"): + """ + Enable automatic CPU offloading for all components. + + The algorithm works as follows: + 1. All models start on CPU by default + 2. When a model's forward pass is called, it's moved to the execution device + 3. If there's insufficient memory, other models on the device are moved back to CPU + 4. The system tries to offload the smallest combination of models that frees enough memory + 5. Models stay on the execution device until another model needs memory and forces them off + + Args: + device (Union[str, int, torch.device]): The execution device where models are moved for forward passes + memory_reserve_margin (str): The memory reserve margin to use, default is 3GB. This is the amount of + memory to keep free on the device to avoid running out of memory during model + execution (e.g., for intermediate activations, gradients, etc.) + """ + if not is_accelerate_available(): + raise ImportError("Make sure to install accelerate to use auto_cpu_offload") + + for name, component in self.components.items(): + if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"): + remove_hook_from_module(component, recurse=True) + + self.disable_auto_cpu_offload() + offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin) + device = torch.device(device) + if device.index is None: + device = torch.device(f"{device.type}:{0}") + all_hooks = [] + for name, component in self.components.items(): + if isinstance(component, torch.nn.Module): + hook = custom_offload_with_hook(name, component, device, offload_strategy=offload_strategy) + all_hooks.append(hook) + + for hook in all_hooks: + other_hooks = [h for h in all_hooks if h is not hook] + for other_hook in other_hooks: + if other_hook.hook.execution_device == hook.hook.execution_device: + hook.add_other_hook(other_hook) + + self.model_hooks = all_hooks + self._auto_offload_enabled = True + self._auto_offload_device = device + + def disable_auto_cpu_offload(self): + """ + Disable automatic CPU offloading for all components. + """ + if self.model_hooks is None: + self._auto_offload_enabled = False + return + + for hook in self.model_hooks: + hook.offload() + hook.remove() + if self.model_hooks: + clear_device_cache() + self.model_hooks = None + self._auto_offload_enabled = False + + # YiYi TODO: (1) add quantization info + def get_model_info( + self, + component_id: str, + fields: Optional[Union[str, List[str]]] = None, + ) -> Optional[Dict[str, Any]]: + """Get comprehensive information about a component. + + Args: + component_id (str): Name of the component to get info for + fields (Optional[Union[str, List[str]]]): + Field(s) to return. Can be a string for single field or list of fields. If None, uses the + available_info_fields setting. + + Returns: + Dictionary containing requested component metadata. If fields is specified, returns only those fields. + Otherwise, returns all fields. 
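A hedged sketch of enabling the automatic offloading implemented above; it assumes `accelerate` is installed, a CUDA device is available, and uses the SDXL checkpoint only as an example:

```py
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.modular_pipelines import ComponentsManager

repo = "stabilityai/stable-diffusion-xl-base-1.0"  # example checkpoint
unet = UNet2DConditionModel.from_pretrained(repo, subfolder="unet", torch_dtype=torch.float16)
vae = AutoencoderKL.from_pretrained(repo, subfolder="vae", torch_dtype=torch.float16)

cm = ComponentsManager()
unet_id = cm.add("unet", unet, collection="sdxl")
vae_id = cm.add("vae", vae, collection="sdxl")

# Models start on CPU; each forward pass moves the needed model to "cuda" and,
# if memory is tight, offloads other managed models back to CPU first.
cm.enable_auto_cpu_offload(device="cuda", memory_reserve_margin="3GB")

print(cm.get_model_info(unet_id, fields=["size_gb", "has_hook", "execution_device"]))

cm.disable_auto_cpu_offload()  # hooks are removed and models are moved back to CPU
```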
+ """ + if component_id not in self.components: + raise ValueError(f"Component '{component_id}' not found in ComponentsManager") + + component = self.components[component_id] + + # Validate fields if specified + if fields is not None: + if isinstance(fields, str): + fields = [fields] + for field in fields: + if field not in self._available_info_fields: + raise ValueError(f"Field '{field}' not found in available_info_fields") + + # Build complete info dict first + info = { + "model_id": component_id, + "added_time": self.added_time[component_id], + "collection": ", ".join([coll for coll, comps in self.collections.items() if component_id in comps]) + or None, + } + + # Additional info for torch.nn.Module components + if isinstance(component, torch.nn.Module): + # Check for hook information + has_hook = hasattr(component, "_hf_hook") + execution_device = None + if has_hook and hasattr(component._hf_hook, "execution_device"): + execution_device = component._hf_hook.execution_device + + info.update( + { + "class_name": component.__class__.__name__, + "size_gb": component.get_memory_footprint() / (1024**3), + "adapters": None, # Default to None + "has_hook": has_hook, + "execution_device": execution_device, + } + ) + + # Get adapters if applicable + if hasattr(component, "peft_config"): + info["adapters"] = list(component.peft_config.keys()) + + # Check for IP-Adapter scales + if hasattr(component, "_load_ip_adapter_weights") and hasattr(component, "attn_processors"): + processors = copy.deepcopy(component.attn_processors) + # First check if any processor is an IP-Adapter + processor_types = [v.__class__.__name__ for v in processors.values()] + if any("IPAdapter" in ptype for ptype in processor_types): + # Then get scales only from IP-Adapter processors + scales = { + k: v.scale + for k, v in processors.items() + if hasattr(v, "scale") and "IPAdapter" in v.__class__.__name__ + } + if scales: + info["ip_adapter"] = summarize_dict_by_value_and_parts(scales) + + # If fields specified, filter info + if fields is not None: + return {k: v for k, v in info.items() if k in fields} + else: + return info + + # YiYi TODO: (1) add display fields, allow user to set which fields to display in the comnponents table + def __repr__(self): + # Handle empty components case + if not self.components: + return "Components:\n" + "=" * 50 + "\nNo components registered.\n" + "=" * 50 + + # Extract load_id if available + def get_load_id(component): + if hasattr(component, "_diffusers_load_id"): + return component._diffusers_load_id + return "N/A" + + # Format device info compactly + def format_device(component, info): + if not info["has_hook"]: + return str(getattr(component, "device", "N/A")) + else: + device = str(getattr(component, "device", "N/A")) + exec_device = str(info["execution_device"] or "N/A") + return f"{device}({exec_device})" + + # Get max length of load_ids for models + load_ids = [ + get_load_id(component) + for component in self.components.values() + if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id") + ] + max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15 + + # Get all collections for each component + component_collections = {} + for name in self.components.keys(): + component_collections[name] = [] + for coll, comps in self.collections.items(): + if name in comps: + component_collections[name].append(coll) + if not component_collections[name]: + component_collections[name] = ["N/A"] + + # Find the maximum collection name length + 
all_collections = [coll for colls in component_collections.values() for coll in colls] + max_collection_len = max(10, max(len(str(c)) for c in all_collections)) if all_collections else 10 + + col_widths = { + "id": max(15, max(len(name) for name in self.components.keys())), + "class": max(25, max(len(component.__class__.__name__) for component in self.components.values())), + "device": 20, + "dtype": 15, + "size": 10, + "load_id": max_load_id_len, + "collection": max_collection_len, + } + + # Create the header lines + sep_line = "=" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n" + dash_line = "-" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n" + + output = "Components:\n" + sep_line + + # Separate components into models and others + models = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)} + others = {k: v for k, v in self.components.items() if not isinstance(v, torch.nn.Module)} + + # Models section + if models: + output += "Models:\n" + dash_line + # Column headers + output += f"{'Name_ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | " + output += f"{'Device: act(exec)':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | " + output += f"{'Size (GB)':<{col_widths['size']}} | {'Load ID':<{col_widths['load_id']}} | Collection\n" + output += dash_line + + # Model entries + for name, component in models.items(): + info = self.get_model_info(name) + device_str = format_device(component, info) + dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A" + load_id = get_load_id(component) + + # Print first collection on the main line + first_collection = component_collections[name][0] if component_collections[name] else "N/A" + + output += f"{name:<{col_widths['id']}} | {info['class_name']:<{col_widths['class']}} | " + output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | " + output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {first_collection}\n" + + # Print additional collections on separate lines if they exist + for i in range(1, len(component_collections[name])): + collection = component_collections[name][i] + output += f"{'':<{col_widths['id']}} | {'':<{col_widths['class']}} | " + output += f"{'':<{col_widths['device']}} | {'':<{col_widths['dtype']}} | " + output += f"{'':<{col_widths['size']}} | {'':<{col_widths['load_id']}} | {collection}\n" + + output += dash_line + + # Other components section + if others: + if models: # Add extra newline if we had models section + output += "\n" + output += "Other Components:\n" + dash_line + # Column headers for other components + output += f"{'ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | Collection\n" + output += dash_line + + # Other component entries + for name, component in others.items(): + info = self.get_model_info(name) + + # Print first collection on the main line + first_collection = component_collections[name][0] if component_collections[name] else "N/A" + + output += f"{name:<{col_widths['id']}} | {component.__class__.__name__:<{col_widths['class']}} | {first_collection}\n" + + # Print additional collections on separate lines if they exist + for i in range(1, len(component_collections[name])): + collection = component_collections[name][i] + output += f"{'':<{col_widths['id']}} | {'':<{col_widths['class']}} | {collection}\n" + + output += dash_line + + # Add additional component info + output += "\nAdditional Component Info:\n" + "=" * 50 + "\n" + for name in 
self.components: + info = self.get_model_info(name) + if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")): + output += f"\n{name}:\n" + if info.get("adapters") is not None: + output += f" Adapters: {info['adapters']}\n" + if info.get("ip_adapter"): + output += " IP-Adapter: Enabled\n" + + return output + + def get_one( + self, + component_id: Optional[str] = None, + name: Optional[str] = None, + collection: Optional[str] = None, + load_id: Optional[str] = None, + ) -> Any: + """ + Get a single component by either: + - searching name (pattern matching), collection, or load_id. + - passing in a component_id + Raises an error if multiple components match or none are found. + + Args: + component_id (Optional[str]): Optional component ID to get + name (Optional[str]): Component name or pattern + collection (Optional[str]): Optional collection to filter by + load_id (Optional[str]): Optional load_id to filter by + + Returns: + A single component + + Raises: + ValueError: If no components match or multiple components match + """ + + if component_id is not None and (name is not None or collection is not None or load_id is not None): + raise ValueError("If searching by component_id, do not pass name, collection, or load_id") + + # search by component_id + if component_id is not None: + if component_id not in self.components: + raise ValueError(f"Component '{component_id}' not found in ComponentsManager") + return self.components[component_id] + # search with name/collection/load_id + results = self.search_components(name, collection, load_id) + + if not results: + raise ValueError(f"No components found matching '{name}'") + + if len(results) > 1: + raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}") + + return next(iter(results.values())) + + def get_ids(self, names: Union[str, List[str]] = None, collection: Optional[str] = None): + """ + Get component IDs by a list of names, optionally filtered by collection. + + Args: + names (Union[str, List[str]]): List of component names + collection (Optional[str]): Optional collection to filter by + + Returns: + List[str]: List of component IDs + """ + ids = set() + if not isinstance(names, list): + names = [names] + for name in names: + ids.update(self._lookup_ids(name=name, collection=collection)) + return list(ids) + + def get_components_by_ids(self, ids: List[str], return_dict_with_names: Optional[bool] = True): + """ + Get components by a list of IDs. + + Args: + ids (List[str]): + List of component IDs + return_dict_with_names (Optional[bool]): + Whether to return a dictionary with component names as keys: + + Returns: + Dict[str, Any]: Dictionary of components. + - If return_dict_with_names=True, keys are component names. + - If return_dict_with_names=False, keys are component IDs. 
+ + Raises: + ValueError: If duplicate component names are found in the search results when return_dict_with_names=True + """ + components = {id: self.components[id] for id in ids} + + if return_dict_with_names: + dict_to_return = {} + for comp_id, comp in components.items(): + comp_name = self._id_to_name(comp_id) + if comp_name in dict_to_return: + raise ValueError( + f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys" + ) + dict_to_return[comp_name] = comp + return dict_to_return + else: + return components + + def get_components_by_names(self, names: List[str], collection: Optional[str] = None): + """ + Get components by a list of names, optionally filtered by collection. + + Args: + names (List[str]): List of component names + collection (Optional[str]): Optional collection to filter by + + Returns: + Dict[str, Any]: Dictionary of components with component names as keys + + Raises: + ValueError: If duplicate component names are found in the search results + """ + ids = self.get_ids(names, collection) + return self.get_components_by_ids(ids) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py new file mode 100644 index 000000000000..b99478cb58d1 --- /dev/null +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -0,0 +1,2827 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
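The retrieval helpers above compose as follows; component and collection names are again made up, and toy modules stand in for real models:

```py
import torch
from diffusers.modular_pipelines import ComponentsManager

cm = ComponentsManager()
cm.add("unet", torch.nn.Linear(4, 4), collection="sdxl")
cm.add("unet", torch.nn.Linear(4, 4), collection="refiner")
cm.add("vae", torch.nn.Linear(4, 4), collection="sdxl")

# get_one raises if the query is ambiguous, so narrow it with a collection
unet = cm.get_one(name="unet", collection="sdxl")

# get_ids -> component IDs; get_components_by_names -> {name: component}
ids = cm.get_ids(names=["unet", "vae"], collection="sdxl")
components = cm.get_components_by_names(names=["unet", "vae"], collection="sdxl")
assert set(components) == {"unet", "vae"}
```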
+import importlib +import inspect +import os +import traceback +import warnings +from collections import OrderedDict +from copy import deepcopy +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from huggingface_hub import create_repo +from huggingface_hub.utils import validate_hf_hub_args +from tqdm.auto import tqdm +from typing_extensions import Self + +from ..configuration_utils import ConfigMixin, FrozenDict +from ..pipelines.pipeline_loading_utils import _fetch_class_library_tuple, simple_get_class_obj +from ..utils import ( + PushToHubMixin, + is_accelerate_available, + logging, +) +from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code +from ..utils.hub_utils import load_or_create_model_card, populate_model_card +from .components_manager import ComponentsManager +from .modular_pipeline_utils import ( + ComponentSpec, + ConfigSpec, + InputParam, + InsertableDict, + OutputParam, + format_components, + format_configs, + format_inputs_short, + format_intermediates_short, + make_doc_string, +) + + +if is_accelerate_available(): + import accelerate + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +MODULAR_PIPELINE_MAPPING = OrderedDict( + [ + ("stable-diffusion-xl", "StableDiffusionXLModularPipeline"), + ] +) + +MODULAR_PIPELINE_BLOCKS_MAPPING = OrderedDict( + [ + ("StableDiffusionXLModularPipeline", "StableDiffusionXLAutoBlocks"), + ] +) + + +@dataclass +class PipelineState: + """ + [`PipelineState`] stores the state of a pipeline. It is used to pass data between pipeline blocks. + """ + + inputs: Dict[str, Any] = field(default_factory=dict) + intermediates: Dict[str, Any] = field(default_factory=dict) + input_kwargs: Dict[str, List[str]] = field(default_factory=dict) + intermediate_kwargs: Dict[str, List[str]] = field(default_factory=dict) + + def set_input(self, key: str, value: Any, kwargs_type: str = None): + """ + Add an input to the immutable pipeline state, i.e, pipeline_state.inputs. + + The kwargs_type parameter allows you to associate inputs with specific input types. For example, if you call + set_input(prompt_embeds=..., kwargs_type="guider_kwargs"), this input will be automatically fetched when a + pipeline block has "guider_kwargs" in its expected_inputs list. + + Args: + key (str): The key for the input + value (Any): The input value + kwargs_type (str): The kwargs_type with which the input is associated + """ + self.inputs[key] = value + if kwargs_type is not None: + if kwargs_type not in self.input_kwargs: + self.input_kwargs[kwargs_type] = [key] + else: + self.input_kwargs[kwargs_type].append(key) + + def set_intermediate(self, key: str, value: Any, kwargs_type: str = None): + """ + Add an intermediate value to the mutable pipeline state, i.e, pipeline_state.intermediates. + + The kwargs_type parameter allows you to associate intermediate values with specific input types. For example, + if you call set_intermediate(latents=..., kwargs_type="latents_kwargs"), this intermediate value will be + automatically fetched when a pipeline block has "latents_kwargs" in its expected_intermediate_inputs list. 
+ + Args: + key (str): The key for the intermediate value + value (Any): The intermediate value + kwargs_type (str): The kwargs_type with which the intermediate value is associated + """ + self.intermediates[key] = value + if kwargs_type is not None: + if kwargs_type not in self.intermediate_kwargs: + self.intermediate_kwargs[kwargs_type] = [key] + else: + self.intermediate_kwargs[kwargs_type].append(key) + + def get_input(self, key: str, default: Any = None) -> Any: + """ + Get an input from the pipeline state. + + Args: + key (str): The key for the input + default (Any): The default value to return if the input is not found + + Returns: + Any: The input value + """ + value = self.inputs.get(key, default) + if value is not None: + return deepcopy(value) + + def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]: + """ + Get multiple inputs from the pipeline state. + + Args: + keys (List[str]): The keys for the inputs + default (Any): The default value to return if the input is not found + + Returns: + Dict[str, Any]: Dictionary of inputs with matching keys + """ + return {key: self.inputs.get(key, default) for key in keys} + + def get_inputs_kwargs(self, kwargs_type: str) -> Dict[str, Any]: + """ + Get all inputs with matching kwargs_type. + + Args: + kwargs_type (str): The kwargs_type to filter by + + Returns: + Dict[str, Any]: Dictionary of inputs with matching kwargs_type + """ + input_names = self.input_kwargs.get(kwargs_type, []) + return self.get_inputs(input_names) + + def get_intermediate_kwargs(self, kwargs_type: str) -> Dict[str, Any]: + """ + Get all intermediates with matching kwargs_type. + + Args: + kwargs_type (str): The kwargs_type to filter by + + Returns: + Dict[str, Any]: Dictionary of intermediates with matching kwargs_type + """ + intermediate_names = self.intermediate_kwargs.get(kwargs_type, []) + return self.get_intermediates(intermediate_names) + + def get_intermediate(self, key: str, default: Any = None) -> Any: + """ + Get an intermediate value from the pipeline state. + + Args: + key (str): The key for the intermediate value + default (Any): The default value to return if the intermediate value is not found + + Returns: + Any: The intermediate value + """ + return self.intermediates.get(key, default) + + def get_intermediates(self, keys: List[str], default: Any = None) -> Dict[str, Any]: + """ + Get multiple intermediate values from the pipeline state. + + Args: + keys (List[str]): The keys for the intermediate values + default (Any): The default value to return if the intermediate value is not found + + Returns: + Dict[str, Any]: Dictionary of intermediate values with matching keys + """ + return {key: self.intermediates.get(key, default) for key in keys} + + def to_dict(self) -> Dict[str, Any]: + """ + Convert PipelineState to a dictionary. 
+ + Returns: + Dict[str, Any]: Dictionary containing all attributes of the PipelineState + """ + return {**self.__dict__, "inputs": self.inputs, "intermediates": self.intermediates} + + def __repr__(self): + def format_value(v): + if hasattr(v, "shape") and hasattr(v, "dtype"): + return f"Tensor(dtype={v.dtype}, shape={v.shape})" + elif isinstance(v, list) and len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): + return f"[Tensor(dtype={v[0].dtype}, shape={v[0].shape}), ...]" + else: + return repr(v) + + inputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.inputs.items()) + intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items()) + + # Format input_kwargs and intermediate_kwargs + input_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.input_kwargs.items()) + intermediate_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.intermediate_kwargs.items()) + + return ( + f"PipelineState(\n" + f" inputs={{\n{inputs}\n }},\n" + f" intermediates={{\n{intermediates}\n }},\n" + f" input_kwargs={{\n{input_kwargs_str}\n }},\n" + f" intermediate_kwargs={{\n{intermediate_kwargs_str}\n }}\n" + f")" + ) + + +@dataclass +class BlockState: + """ + Container for block state data with attribute access and formatted representation. + """ + + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + def __getitem__(self, key: str): + # allows block_state["foo"] + return getattr(self, key, None) + + def __setitem__(self, key: str, value: Any): + # allows block_state["foo"] = "bar" + setattr(self, key, value) + + def as_dict(self): + """ + Convert BlockState to a dictionary. + + Returns: + Dict[str, Any]: Dictionary containing all attributes of the BlockState + """ + return dict(self.__dict__.items()) + + def __repr__(self): + def format_value(v): + # Handle tensors directly + if hasattr(v, "shape") and hasattr(v, "dtype"): + return f"Tensor(dtype={v.dtype}, shape={v.shape})" + + # Handle lists of tensors + elif isinstance(v, list): + if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): + shapes = [t.shape for t in v] + return f"List[{len(v)}] of Tensors with shapes {shapes}" + return repr(v) + + # Handle tuples of tensors + elif isinstance(v, tuple): + if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): + shapes = [t.shape for t in v] + return f"Tuple[{len(v)}] of Tensors with shapes {shapes}" + return repr(v) + + # Handle dicts with tensor values + elif isinstance(v, dict): + formatted_dict = {} + for k, val in v.items(): + if hasattr(val, "shape") and hasattr(val, "dtype"): + formatted_dict[k] = f"Tensor(shape={val.shape}, dtype={val.dtype})" + elif ( + isinstance(val, list) + and len(val) > 0 + and hasattr(val[0], "shape") + and hasattr(val[0], "dtype") + ): + shapes = [t.shape for t in val] + formatted_dict[k] = f"List[{len(val)}] of Tensors with shapes {shapes}" + else: + formatted_dict[k] = repr(val) + return formatted_dict + + # Default case + return repr(v) + + attributes = "\n".join(f" {k}: {format_value(v)}" for k, v in self.__dict__.items()) + return f"BlockState(\n{attributes}\n)" + + +class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): + """ + Base class for all Pipeline Blocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks, + LoopSequentialPipelineBlocks + + [`ModularPipelineBlocks`] provides method to load and save the defination of pipeline blocks. + + + + This is an experimental feature and is likely to change in the future. 
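A short sketch of the two state containers defined above; the tensor shape, the `guider_kwargs` grouping, and the scaling factor are arbitrary illustrations:

```py
import torch
from diffusers.modular_pipelines import BlockState, PipelineState

state = PipelineState()
state.set_input("prompt", "a photo of a cat")
# kwargs_type groups related values so blocks can fetch them together
state.set_input("guidance_scale", 7.5, kwargs_type="guider_kwargs")
state.set_intermediate("latents", torch.randn(1, 4, 64, 64))

state.get_input("prompt")                   # returns a deep copy of the immutable input
state.get_inputs_kwargs("guider_kwargs")    # {"guidance_scale": 7.5}
latents = state.get_intermediate("latents") # intermediates are returned as-is

block_state = BlockState(prompt="a photo of a cat", latents=latents)
block_state.latents = block_state.latents * 0.18215  # attribute-style access
block_state["prompt"]                                 # item-style access also works
```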
+ + + """ + + config_name = "config.json" + + @classmethod + def _get_signature_keys(cls, obj): + parameters = inspect.signature(obj.__init__).parameters + required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} + optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) + expected_modules = set(required_parameters.keys()) - {"self"} + + return expected_modules, optional_parameters + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + trust_remote_code: Optional[bool] = None, + **kwargs, + ): + hub_kwargs_names = [ + "cache_dir", + "force_download", + "local_files_only", + "proxies", + "resume_download", + "revision", + "subfolder", + "token", + ] + hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs} + + config = cls.load_config(pretrained_model_name_or_path) + has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"] + trust_remote_code = resolve_trust_remote_code( + trust_remote_code, pretrained_model_name_or_path, has_remote_code + ) + if not (has_remote_code and trust_remote_code): + raise ValueError("TODO") + + class_ref = config["auto_map"][cls.__name__] + module_file, class_name = class_ref.split(".") + module_file = module_file + ".py" + block_cls = get_class_from_dynamic_module( + pretrained_model_name_or_path, + module_file=module_file, + class_name=class_name, + is_modular=True, + **hub_kwargs, + **kwargs, + ) + expected_kwargs, optional_kwargs = block_cls._get_signature_keys(block_cls) + block_kwargs = { + name: kwargs.pop(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs + } + + return block_cls(**block_kwargs) + + def save_pretrained(self, save_directory, push_to_hub=False, **kwargs): + # TODO: factor out this logic. + cls_name = self.__class__.__name__ + + full_mod = type(self).__module__ + module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "") + parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0] + auto_map = {f"{parent_module}": f"{module}.{cls_name}"} + + self.register_to_config(auto_map=auto_map) + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + config = dict(self.config) + self._internal_dict = FrozenDict(config) + + def init_pipeline( + self, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, + components_manager: Optional[ComponentsManager] = None, + collection: Optional[str] = None, + ) -> "ModularPipeline": + """ + create a ModularPipeline, optionally accept modular_repo to load from hub. + """ + pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(self.model_name, ModularPipeline.__name__) + diffusers_module = importlib.import_module("diffusers") + pipeline_class = getattr(diffusers_module, pipeline_class_name) + + modular_pipeline = pipeline_class( + blocks=deepcopy(self), + pretrained_model_name_or_path=pretrained_model_name_or_path, + components_manager=components_manager, + collection=collection, + ) + return modular_pipeline + + @staticmethod + def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: + """ + Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if + current default value is None and new default value is not None. Warns if multiple non-None default values + exist for the same input. 
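The merging rule for `combine_inputs` can be seen in a tiny example; it assumes `InputParam` accepts `name` and `default` keyword arguments, and the block names and defaults below are arbitrary:

```py
from diffusers.modular_pipelines import InputParam, ModularPipelineBlocks

text2img_inputs = [InputParam(name="prompt"), InputParam(name="num_inference_steps", default=None)]
img2img_inputs = [InputParam(name="prompt"), InputParam(name="num_inference_steps", default=30)]

# combine_inputs is a staticmethod, so it can be called on the class directly
combined = ModularPipelineBlocks.combine_inputs(
    ("text2img", text2img_inputs),
    ("img2img", img2img_inputs),
)

# "prompt" appears once; "num_inference_steps" keeps the non-None default (30)
print({p.name: p.default for p in combined})
```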
+ + Args: + named_input_lists: List of tuples containing (block_name, input_param_list) pairs + + Returns: + List[InputParam]: Combined list of unique InputParam objects + """ + combined_dict = {} # name -> InputParam + value_sources = {} # name -> block_name + + for block_name, inputs in named_input_lists: + for input_param in inputs: + if input_param.name is None and input_param.kwargs_type is not None: + input_name = "*_" + input_param.kwargs_type + else: + input_name = input_param.name + if input_name in combined_dict: + current_param = combined_dict[input_name] + if ( + current_param.default is not None + and input_param.default is not None + and current_param.default != input_param.default + ): + warnings.warn( + f"Multiple different default values found for input '{input_name}': " + f"{current_param.default} (from block '{value_sources[input_name]}') and " + f"{input_param.default} (from block '{block_name}'). Using {current_param.default}." + ) + if current_param.default is None and input_param.default is not None: + combined_dict[input_name] = input_param + value_sources[input_name] = block_name + else: + combined_dict[input_name] = input_param + value_sources[input_name] = block_name + + return list(combined_dict.values()) + + @staticmethod + def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]: + """ + Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, keeps the first + occurrence of each output name. + + Args: + named_output_lists: List of tuples containing (block_name, output_param_list) pairs + + Returns: + List[OutputParam]: Combined list of unique OutputParam objects + """ + combined_dict = {} # name -> OutputParam + + for block_name, outputs in named_output_lists: + for output_param in outputs: + if (output_param.name not in combined_dict) or ( + combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None + ): + combined_dict[output_param.name] = output_param + + return list(combined_dict.values()) + + +class PipelineBlock(ModularPipelineBlocks): + """ + A Pipeline Block is the basic building block of a Modular Pipeline. + + This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the + library implements for all the pipeline blocks (such as loading or saving etc.) + + + + This is an experimental feature and is likely to change in the future. + + + + Args: + description (str, optional): A description of the block, defaults to None. Define as a property in subclasses. + expected_components (List[ComponentSpec], optional): + A list of components that are expected to be used in the block, defaults to []. To override, define as a + property in subclasses. + expected_configs (List[ConfigSpec], optional): + A list of configs that are expected to be used in the block, defaults to []. To override, define as a + property in subclasses. + inputs (List[InputParam], optional): + A list of inputs that are expected to be used in the block, defaults to []. To override, define as a + property in subclasses. + intermediate_inputs (List[InputParam], optional): + A list of intermediate inputs that are expected to be used in the block, defaults to []. To override, + define as a property in subclasses. + intermediate_outputs (List[OutputParam], optional): + A list of intermediate outputs that are expected to be used in the block, defaults to []. To override, + define as a property in subclasses. 
+ outputs (List[OutputParam], optional): + A list of outputs that are expected to be used in the block, defaults to []. To override, define as a + property in subclasses. + required_inputs (List[str], optional): + A list of required inputs that are expected to be used in the block, defaults to []. To override, define as + a property in subclasses. + required_intermediate_inputs (List[str], optional): + A list of required intermediate inputs that are expected to be used in the block, defaults to []. To + override, define as a property in subclasses. + required_intermediate_outputs (List[str], optional): + A list of required intermediate outputs that are expected to be used in the block, defaults to []. To + override, define as a property in subclasses. + """ + + model_name = None + + def __init__(self): + self.sub_blocks = InsertableDict() + + @property + def description(self) -> str: + """Description of the block. Must be implemented by subclasses.""" + # raise NotImplementedError("description method must be implemented in subclasses") + return "TODO: add a description" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [] + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [] + + @property + def inputs(self) -> List[InputParam]: + """List of input parameters. Must be implemented by subclasses.""" + return [] + + @property + def intermediate_inputs(self) -> List[InputParam]: + """List of intermediate input parameters. Must be implemented by subclasses.""" + return [] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + """List of intermediate output parameters. Must be implemented by subclasses.""" + return [] + + def _get_outputs(self): + return self.intermediate_outputs + + # YiYi TODO: is it too easy for user to unintentionally override these properties? 
+ # Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks + @property + def outputs(self) -> List[OutputParam]: + return self._get_outputs() + + def _get_required_inputs(self): + input_names = [] + for input_param in self.inputs: + if input_param.required: + input_names.append(input_param.name) + return input_names + + @property + def required_inputs(self) -> List[str]: + return self._get_required_inputs() + + def _get_required_intermediate_inputs(self): + input_names = [] + for input_param in self.intermediate_inputs: + if input_param.required: + input_names.append(input_param.name) + return input_names + + # YiYi TODO: maybe we do not need this, it is only used in docstring, + # intermediate_inputs is by default required, unless you manually handle it inside the block + @property + def required_intermediate_inputs(self) -> List[str]: + return self._get_required_intermediate_inputs() + + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + raise NotImplementedError("__call__ method must be implemented in subclasses") + + def __repr__(self): + class_name = self.__class__.__name__ + base_class = self.__class__.__bases__[0].__name__ + + # Format description with proper indentation + desc_lines = self.description.split("\n") + desc = [] + # First line with "Description:" label + desc.append(f" Description: {desc_lines[0]}") + # Subsequent lines with proper indentation + if len(desc_lines) > 1: + desc.extend(f" {line}" for line in desc_lines[1:]) + desc = "\n".join(desc) + "\n" + + # Components section - use format_components with add_empty_lines=False + expected_components = getattr(self, "expected_components", []) + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) + components = " " + components_str.replace("\n", "\n ") + + # Configs section - use format_configs with add_empty_lines=False + expected_configs = getattr(self, "expected_configs", []) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) + configs = " " + configs_str.replace("\n", "\n ") + + # Inputs section + inputs_str = format_inputs_short(self.inputs) + inputs = "Inputs:\n " + inputs_str + + # Intermediates section + intermediates_str = format_intermediates_short( + self.intermediate_inputs, self.required_intermediate_inputs, self.intermediate_outputs + ) + intermediates = f"Intermediates:\n{intermediates_str}" + + return f"{class_name}(\n Class: {base_class}\n{desc}{components}\n{configs}\n {inputs}\n {intermediates}\n)" + + @property + def doc(self): + return make_doc_string( + self.inputs, + self.intermediate_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs, + ) + + # YiYi TODO: input and inteermediate inputs with same name? should warn? 
+ def get_block_state(self, state: PipelineState) -> dict: + """Get all inputs and intermediates in one dictionary""" + data = {} + + # Check inputs + for input_param in self.inputs: + if input_param.name: + value = state.get_input(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all inputs with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) + if inputs_kwargs: + for k, v in inputs_kwargs.items(): + if v is not None: + data[k] = v + data[input_param.kwargs_type][k] = v + + # Check intermediates + for input_param in self.intermediate_inputs: + if input_param.name: + value = state.get_intermediate(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required intermediate input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all intermediates with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type) + if intermediate_kwargs: + for k, v in intermediate_kwargs.items(): + if v is not None: + if k not in data: + data[k] = v + data[input_param.kwargs_type][k] = v + return BlockState(**data) + + def set_block_state(self, state: PipelineState, block_state: BlockState): + for output_param in self.intermediate_outputs: + if not hasattr(block_state, output_param.name): + raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") + param = getattr(block_state, output_param.name) + state.set_intermediate(output_param.name, param, output_param.kwargs_type) + + for input_param in self.intermediate_inputs: + if hasattr(block_state, input_param.name): + param = getattr(block_state, input_param.name) + # Only add if the value is different from what's in the state + current_value = state.get_intermediate(input_param.name) + if current_value is not param: # Using identity comparison to check if object was modified + state.set_intermediate(input_param.name, param, input_param.kwargs_type) + + for input_param in self.intermediate_inputs: + if input_param.name and hasattr(block_state, input_param.name): + param = getattr(block_state, input_param.name) + # Only add if the value is different from what's in the state + current_value = state.get_intermediate(input_param.name) + if current_value is not param: # Using identity comparison to check if object was modified + state.set_intermediate(input_param.name, param, input_param.kwargs_type) + elif input_param.kwargs_type: + # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters + # we need to first find out which inputs are and loop through them. 
+ intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type) + for param_name, current_value in intermediate_kwargs.items(): + param = getattr(block_state, param_name) + if current_value is not param: # Using identity comparison to check if object was modified + state.set_intermediate(param_name, param, input_param.kwargs_type) + + +class AutoPipelineBlocks(ModularPipelineBlocks): + """ + A Pipeline Blocks that automatically selects a block to run based on the inputs. + + This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the + library implements for all the pipeline blocks (such as loading or saving etc.) + + + + This is an experimental feature and is likely to change in the future. + + + + Attributes: + block_classes: List of block classes to be used + block_names: List of prefixes for each block + block_trigger_inputs: List of input names that trigger specific blocks, with None for default + """ + + block_classes = [] + block_names = [] + block_trigger_inputs = [] + + def __init__(self): + sub_blocks = InsertableDict() + for block_name, block_cls in zip(self.block_names, self.block_classes): + sub_blocks[block_name] = block_cls() + self.sub_blocks = sub_blocks + if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)): + raise ValueError( + f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same." + ) + default_blocks = [t for t in self.block_trigger_inputs if t is None] + # can only have 1 or 0 default block, and has to put in the last + # the order of blocks matters here because the first block with matching trigger will be dispatched + # e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"] + # as long as mask is provided, it is inpaint; if only image is provided, it is img2img + if len(default_blocks) > 1 or (len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None): + raise ValueError( + f"In {self.__class__.__name__}, exactly one None must be specified as the last element " + "in block_trigger_inputs." 
+ ) + + # Map trigger inputs to block objects + self.trigger_to_block_map = dict(zip(self.block_trigger_inputs, self.sub_blocks.values())) + self.trigger_to_block_name_map = dict(zip(self.block_trigger_inputs, self.sub_blocks.keys())) + self.block_to_trigger_map = dict(zip(self.sub_blocks.keys(), self.block_trigger_inputs)) + + @property + def model_name(self): + return next(iter(self.sub_blocks.values())).model_name + + @property + def description(self): + return "" + + @property + def expected_components(self): + expected_components = [] + for block in self.sub_blocks.values(): + for component in block.expected_components: + if component not in expected_components: + expected_components.append(component) + return expected_components + + @property + def expected_configs(self): + expected_configs = [] + for block in self.sub_blocks.values(): + for config in block.expected_configs: + if config not in expected_configs: + expected_configs.append(config) + return expected_configs + + @property + def required_inputs(self) -> List[str]: + if None not in self.block_trigger_inputs: + return [] + first_block = next(iter(self.sub_blocks.values())) + required_by_all = set(getattr(first_block, "required_inputs", set())) + + # Intersect with required inputs from all other blocks + for block in list(self.sub_blocks.values())[1:]: + block_required = set(getattr(block, "required_inputs", set())) + required_by_all.intersection_update(block_required) + + return list(required_by_all) + + # YiYi TODO: maybe we do not need this, it is only used in docstring, + # intermediate_inputs is by default required, unless you manually handle it inside the block + @property + def required_intermediate_inputs(self) -> List[str]: + if None not in self.block_trigger_inputs: + return [] + first_block = next(iter(self.sub_blocks.values())) + required_by_all = set(getattr(first_block, "required_intermediate_inputs", set())) + + # Intersect with required inputs from all other blocks + for block in list(self.sub_blocks.values())[1:]: + block_required = set(getattr(block, "required_intermediate_inputs", set())) + required_by_all.intersection_update(block_required) + + return list(required_by_all) + + # YiYi TODO: add test for this + @property + def inputs(self) -> List[Tuple[str, Any]]: + named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()] + combined_inputs = self.combine_inputs(*named_inputs) + # mark Required inputs only if that input is required by all the blocks + for input_param in combined_inputs: + if input_param.name in self.required_inputs: + input_param.required = True + else: + input_param.required = False + return combined_inputs + + @property + def intermediate_inputs(self) -> List[str]: + named_inputs = [(name, block.intermediate_inputs) for name, block in self.sub_blocks.items()] + combined_inputs = self.combine_inputs(*named_inputs) + # mark Required inputs only if that input is required by all the blocks + for input_param in combined_inputs: + if input_param.name in self.required_intermediate_inputs: + input_param.required = True + else: + input_param.required = False + return combined_inputs + + @property + def intermediate_outputs(self) -> List[str]: + named_outputs = [(name, block.intermediate_outputs) for name, block in self.sub_blocks.items()] + combined_outputs = self.combine_outputs(*named_outputs) + return combined_outputs + + @property + def outputs(self) -> List[str]: + named_outputs = [(name, block.outputs) for name, block in self.sub_blocks.items()] + combined_outputs = 
self.combine_outputs(*named_outputs) + return combined_outputs + + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + # Find default block first (if any) + + block = self.trigger_to_block_map.get(None) + for input_name in self.block_trigger_inputs: + if input_name is not None and state.get_input(input_name) is not None: + block = self.trigger_to_block_map[input_name] + break + elif input_name is not None and state.get_intermediate(input_name) is not None: + block = self.trigger_to_block_map[input_name] + break + + if block is None: + logger.warning(f"skipping auto block: {self.__class__.__name__}") + return pipeline, state + + try: + logger.info(f"Running block: {block.__class__.__name__}, trigger: {input_name}") + return block(pipeline, state) + except Exception as e: + error_msg = ( + f"\nError in block: {block.__class__.__name__}\n" + f"Error details: {str(e)}\n" + f"Traceback:\n{traceback.format_exc()}" + ) + logger.error(error_msg) + raise + + def _get_trigger_inputs(self): + """ + Returns a set of all unique trigger input values found in the blocks. Returns: Set[str] containing all unique + block_trigger_inputs values + """ + + def fn_recursive_get_trigger(blocks): + trigger_values = set() + + if blocks is not None: + for name, block in blocks.items(): + # Check if current block has trigger inputs(i.e. auto block) + if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None: + # Add all non-None values from the trigger inputs list + trigger_values.update(t for t in block.block_trigger_inputs if t is not None) + + # If block has sub_blocks, recursively check them + if block.sub_blocks: + nested_triggers = fn_recursive_get_trigger(block.sub_blocks) + trigger_values.update(nested_triggers) + + return trigger_values + + trigger_inputs = set(self.block_trigger_inputs) + trigger_inputs.update(fn_recursive_get_trigger(self.sub_blocks)) + + return trigger_inputs + + @property + def trigger_inputs(self): + return self._get_trigger_inputs() + + def __repr__(self): + class_name = self.__class__.__name__ + base_class = self.__class__.__bases__[0].__name__ + header = ( + f"{class_name}(\n Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n" + ) + + if self.trigger_inputs: + header += "\n" + header += " " + "=" * 100 + "\n" + header += " This pipeline contains blocks that are selected at runtime based on inputs.\n" + header += f" Trigger Inputs: {[inp for inp in self.trigger_inputs if inp is not None]}\n" + header += " " + "=" * 100 + "\n\n" + + # Format description with proper indentation + desc_lines = self.description.split("\n") + desc = [] + # First line with "Description:" label + desc.append(f" Description: {desc_lines[0]}") + # Subsequent lines with proper indentation + if len(desc_lines) > 1: + desc.extend(f" {line}" for line in desc_lines[1:]) + desc = "\n".join(desc) + "\n" + + # Components section - focus only on expected components + expected_components = getattr(self, "expected_components", []) + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) + + # Configs section - use format_configs with add_empty_lines=False + expected_configs = getattr(self, "expected_configs", []) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) + + # Blocks section - moved to the end with simplified format + blocks_str = " Sub-Blocks:\n" + for i, (name, block) in enumerate(self.sub_blocks.items()): + # Get trigger input for 
this block + trigger = None + if hasattr(self, "block_to_trigger_map"): + trigger = self.block_to_trigger_map.get(name) + # Format the trigger info + if trigger is None: + trigger_str = "[default]" + elif isinstance(trigger, (list, tuple)): + trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]" + else: + trigger_str = f"[trigger: {trigger}]" + # For AutoPipelineBlocks, add bullet points + blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n" + else: + # For SequentialPipelineBlocks, show execution order + blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" + + # Add block description + desc_lines = block.description.split("\n") + indented_desc = desc_lines[0] + if len(desc_lines) > 1: + indented_desc += "\n" + "\n".join(" " + line for line in desc_lines[1:]) + blocks_str += f" Description: {indented_desc}\n\n" + + # Build the representation with conditional sections + result = f"{header}\n{desc}" + + # Only add components section if it has content + if components_str.strip(): + result += f"\n\n{components_str}" + + # Only add configs section if it has content + if configs_str.strip(): + result += f"\n\n{configs_str}" + + # Always add blocks section + result += f"\n\n{blocks_str})" + + return result + + @property + def doc(self): + return make_doc_string( + self.inputs, + self.intermediate_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs, + ) + + +class SequentialPipelineBlocks(ModularPipelineBlocks): + """ + A Pipeline Blocks that combines multiple pipeline block classes into one. When called, it will call each block in + sequence. + + This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the + library implements for all the pipeline blocks (such as loading or saving etc.) + + + + This is an experimental feature and is likely to change in the future. + + + + Attributes: + block_classes: List of block classes to be used + block_names: List of prefixes for each block + """ + + block_classes = [] + block_names = [] + + @property + def description(self): + return "" + + @property + def model_name(self): + return next(iter(self.sub_blocks.values())).model_name + + @property + def expected_components(self): + expected_components = [] + for block in self.sub_blocks.values(): + for component in block.expected_components: + if component not in expected_components: + expected_components.append(component) + return expected_components + + @property + def expected_configs(self): + expected_configs = [] + for block in self.sub_blocks.values(): + for config in block.expected_configs: + if config not in expected_configs: + expected_configs.append(config) + return expected_configs + + @classmethod + def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlocks": + """Creates a SequentialPipelineBlocks instance from a dictionary of blocks. 
+ + Args: + blocks_dict: Dictionary mapping block names to block classes or instances + + Returns: + A new SequentialPipelineBlocks instance + """ + instance = cls() + + # Create instances if classes are provided + sub_blocks = InsertableDict() + for name, block in blocks_dict.items(): + if inspect.isclass(block): + sub_blocks[name] = block() + else: + sub_blocks[name] = block + + instance.block_classes = [block.__class__ for block in sub_blocks.values()] + instance.block_names = list(sub_blocks.keys()) + instance.sub_blocks = sub_blocks + return instance + + def __init__(self): + sub_blocks = InsertableDict() + for block_name, block_cls in zip(self.block_names, self.block_classes): + sub_blocks[block_name] = block_cls() + self.sub_blocks = sub_blocks + + @property + def required_inputs(self) -> List[str]: + # Get the first block from the dictionary + first_block = next(iter(self.sub_blocks.values())) + required_by_any = set(getattr(first_block, "required_inputs", set())) + + # Union with required inputs from all other blocks + for block in list(self.sub_blocks.values())[1:]: + block_required = set(getattr(block, "required_inputs", set())) + required_by_any.update(block_required) + + return list(required_by_any) + + # YiYi TODO: maybe we do not need this, it is only used in docstring, + # intermediate_inputs is by default required, unless you manually handle it inside the block + @property + def required_intermediate_inputs(self) -> List[str]: + required_intermediate_inputs = [] + for input_param in self.intermediate_inputs: + if input_param.required: + required_intermediate_inputs.append(input_param.name) + return required_intermediate_inputs + + # YiYi TODO: add test for this + @property + def inputs(self) -> List[Tuple[str, Any]]: + return self.get_inputs() + + def get_inputs(self): + named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()] + combined_inputs = self.combine_inputs(*named_inputs) + # mark Required inputs only if that input is required any of the blocks + for input_param in combined_inputs: + if input_param.name in self.required_inputs: + input_param.required = True + else: + input_param.required = False + return combined_inputs + + @property + def intermediate_inputs(self) -> List[str]: + return self.get_intermediate_inputs() + + def get_intermediate_inputs(self): + inputs = [] + outputs = set() + added_inputs = set() + + # Go through all blocks in order + for block in self.sub_blocks.values(): + # Add inputs that aren't in outputs yet + for inp in block.intermediate_inputs: + if inp.name not in outputs and inp.name not in added_inputs: + inputs.append(inp) + added_inputs.add(inp.name) + + # Only add outputs if the block cannot be skipped + should_add_outputs = True + if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: + should_add_outputs = False + + if should_add_outputs: + # Add this block's outputs + block_intermediate_outputs = [out.name for out in block.intermediate_outputs] + outputs.update(block_intermediate_outputs) + return inputs + + @property + def intermediate_outputs(self) -> List[str]: + named_outputs = [] + for name, block in self.sub_blocks.items(): + inp_names = {inp.name for inp in block.intermediate_inputs} + # so we only need to list new variables as intermediate_outputs, but if user wants to list these they modified it's still fine (a.k.a we don't enforce) + # filter out them here so they do not end up as intermediate_outputs + if name not in inp_names: + named_outputs.append((name, 
block.intermediate_outputs))
+        combined_outputs = self.combine_outputs(*named_outputs)
+        return combined_outputs
+
+    # YiYi TODO: I think we can remove the outputs property
+    @property
+    def outputs(self) -> List[str]:
+        # return next(reversed(self.sub_blocks.values())).intermediate_outputs
+        return self.intermediate_outputs
+
+    @torch.no_grad()
+    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
+        for block_name, block in self.sub_blocks.items():
+            try:
+                pipeline, state = block(pipeline, state)
+            except Exception as e:
+                error_msg = (
+                    f"\nError in block: ({block_name}, {block.__class__.__name__})\n"
+                    f"Error details: {str(e)}\n"
+                    f"Traceback:\n{traceback.format_exc()}"
+                )
+                logger.error(error_msg)
+                raise
+        return pipeline, state
+
+    def _get_trigger_inputs(self):
+        """
+        Returns a set of all unique trigger input values found in the blocks. Returns: Set[str] containing all unique
+        block_trigger_inputs values
+        """
+
+        def fn_recursive_get_trigger(blocks):
+            trigger_values = set()
+
+            if blocks is not None:
+                for name, block in blocks.items():
+                    # Check if current block has trigger inputs (i.e. an auto block)
+                    if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None:
+                        # Add all non-None values from the trigger inputs list
+                        trigger_values.update(t for t in block.block_trigger_inputs if t is not None)
+
+                    # If block has sub_blocks, recursively check them
+                    if block.sub_blocks:
+                        nested_triggers = fn_recursive_get_trigger(block.sub_blocks)
+                        trigger_values.update(nested_triggers)
+
+            return trigger_values
+
+        return fn_recursive_get_trigger(self.sub_blocks)
+
+    @property
+    def trigger_inputs(self):
+        return self._get_trigger_inputs()
+
+    def _traverse_trigger_blocks(self, trigger_inputs):
+        # Convert trigger_inputs to a set for easier manipulation
+        active_triggers = set(trigger_inputs)
+
+        def fn_recursive_traverse(block, block_name, active_triggers):
+            result_blocks = OrderedDict()
+
+            # sequential (including loop sequential) or PipelineBlock
+            if not hasattr(block, "block_trigger_inputs"):
+                if block.sub_blocks:
+                    # sequential or LoopSequentialPipelineBlocks (keep traversing)
+                    for sub_block_name, sub_block in block.sub_blocks.items():
+                        blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers)
+                        blocks_to_update = {f"{block_name}.{k}": v for k, v in blocks_to_update.items()}
+                        result_blocks.update(blocks_to_update)
+                else:
+                    # PipelineBlock
+                    result_blocks[block_name] = block
+                    # Add this block's output names to active triggers if defined
+                    if hasattr(block, "outputs"):
+                        active_triggers.update(out.name for out in block.outputs)
+                return result_blocks
+
+            # auto
+            else:
+                # Find first block_trigger_input that matches any value in our active_triggers
+                this_block = None
+                for trigger_input in block.block_trigger_inputs:
+                    if trigger_input is not None and trigger_input in active_triggers:
+                        this_block = block.trigger_to_block_map[trigger_input]
+                        break
+
+                # If no matches found, try to get the default (None) block
+                if this_block is None and None in block.block_trigger_inputs:
+                    this_block = block.trigger_to_block_map[None]
+
+                if this_block is not None:
+                    # sequential/auto (keep traversing)
+                    if this_block.sub_blocks:
+                        result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers))
+                    else:
+                        # PipelineBlock
+                        result_blocks[block_name] = this_block
+                        # Add this block's output names to active triggers if
defined + # YiYi TODO: do we need outputs here? can it just be intermediate_outputs? can we get rid of outputs attribute? + if hasattr(this_block, "outputs"): + active_triggers.update(out.name for out in this_block.outputs) + + return result_blocks + + all_blocks = OrderedDict() + for block_name, block in self.sub_blocks.items(): + blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers) + all_blocks.update(blocks_to_update) + return all_blocks + + def get_execution_blocks(self, *trigger_inputs): + trigger_inputs_all = self.trigger_inputs + + if trigger_inputs is not None: + if not isinstance(trigger_inputs, (list, tuple, set)): + trigger_inputs = [trigger_inputs] + invalid_inputs = [x for x in trigger_inputs if x not in trigger_inputs_all] + if invalid_inputs: + logger.warning( + f"The following trigger inputs will be ignored as they are not supported: {invalid_inputs}" + ) + trigger_inputs = [x for x in trigger_inputs if x in trigger_inputs_all] + + if trigger_inputs is None: + if None in trigger_inputs_all: + trigger_inputs = [None] + else: + trigger_inputs = [trigger_inputs_all[0]] + blocks_triggered = self._traverse_trigger_blocks(trigger_inputs) + return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered) + + def __repr__(self): + class_name = self.__class__.__name__ + base_class = self.__class__.__bases__[0].__name__ + header = ( + f"{class_name}(\n Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n" + ) + + if self.trigger_inputs: + header += "\n" + header += " " + "=" * 100 + "\n" + header += " This pipeline contains blocks that are selected at runtime based on inputs.\n" + header += f" Trigger Inputs: {[inp for inp in self.trigger_inputs if inp is not None]}\n" + # Get first trigger input as example + example_input = next(t for t in self.trigger_inputs if t is not None) + header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. 
`get_execution_blocks('{example_input}')`).\n" + header += " " + "=" * 100 + "\n\n" + + # Format description with proper indentation + desc_lines = self.description.split("\n") + desc = [] + # First line with "Description:" label + desc.append(f" Description: {desc_lines[0]}") + # Subsequent lines with proper indentation + if len(desc_lines) > 1: + desc.extend(f" {line}" for line in desc_lines[1:]) + desc = "\n".join(desc) + "\n" + + # Components section - focus only on expected components + expected_components = getattr(self, "expected_components", []) + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) + + # Configs section - use format_configs with add_empty_lines=False + expected_configs = getattr(self, "expected_configs", []) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) + + # Blocks section - moved to the end with simplified format + blocks_str = " Sub-Blocks:\n" + for i, (name, block) in enumerate(self.sub_blocks.items()): + # Get trigger input for this block + trigger = None + if hasattr(self, "block_to_trigger_map"): + trigger = self.block_to_trigger_map.get(name) + # Format the trigger info + if trigger is None: + trigger_str = "[default]" + elif isinstance(trigger, (list, tuple)): + trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]" + else: + trigger_str = f"[trigger: {trigger}]" + # For AutoPipelineBlocks, add bullet points + blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n" + else: + # For SequentialPipelineBlocks, show execution order + blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" + + # Add block description + desc_lines = block.description.split("\n") + indented_desc = desc_lines[0] + if len(desc_lines) > 1: + indented_desc += "\n" + "\n".join(" " + line for line in desc_lines[1:]) + blocks_str += f" Description: {indented_desc}\n\n" + + # Build the representation with conditional sections + result = f"{header}\n{desc}" + + # Only add components section if it has content + if components_str.strip(): + result += f"\n\n{components_str}" + + # Only add configs section if it has content + if configs_str.strip(): + result += f"\n\n{configs_str}" + + # Always add blocks section + result += f"\n\n{blocks_str})" + + return result + + @property + def doc(self): + return make_doc_string( + self.inputs, + self.intermediate_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs, + ) + + +class LoopSequentialPipelineBlocks(ModularPipelineBlocks): + """ + A Pipeline blocks that combines multiple pipeline block classes into a For Loop. When called, it will call each + block in sequence. + + This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the + library implements for all the pipeline blocks (such as loading or saving etc.) + + + + This is an experimental feature and is likely to change in the future. + + + + Attributes: + block_classes: List of block classes to be used + block_names: List of prefixes for each block + """ + + model_name = None + block_classes = [] + block_names = [] + + @property + def description(self) -> str: + """Description of the block. 
Must be implemented by subclasses.""" + raise NotImplementedError("description method must be implemented in subclasses") + + @property + def loop_expected_components(self) -> List[ComponentSpec]: + return [] + + @property + def loop_expected_configs(self) -> List[ConfigSpec]: + return [] + + @property + def loop_inputs(self) -> List[InputParam]: + """List of input parameters. Must be implemented by subclasses.""" + return [] + + @property + def loop_intermediate_inputs(self) -> List[InputParam]: + """List of intermediate input parameters. Must be implemented by subclasses.""" + return [] + + @property + def loop_intermediate_outputs(self) -> List[OutputParam]: + """List of intermediate output parameters. Must be implemented by subclasses.""" + return [] + + @property + def loop_required_inputs(self) -> List[str]: + input_names = [] + for input_param in self.loop_inputs: + if input_param.required: + input_names.append(input_param.name) + return input_names + + @property + def loop_required_intermediate_inputs(self) -> List[str]: + input_names = [] + for input_param in self.loop_intermediate_inputs: + if input_param.required: + input_names.append(input_param.name) + return input_names + + # modified from SequentialPipelineBlocks to include loop_expected_components + @property + def expected_components(self): + expected_components = [] + for block in self.sub_blocks.values(): + for component in block.expected_components: + if component not in expected_components: + expected_components.append(component) + for component in self.loop_expected_components: + if component not in expected_components: + expected_components.append(component) + return expected_components + + # modified from SequentialPipelineBlocks to include loop_expected_configs + @property + def expected_configs(self): + expected_configs = [] + for block in self.sub_blocks.values(): + for config in block.expected_configs: + if config not in expected_configs: + expected_configs.append(config) + for config in self.loop_expected_configs: + if config not in expected_configs: + expected_configs.append(config) + return expected_configs + + # modified from SequentialPipelineBlocks to include loop_inputs + def get_inputs(self): + named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()] + named_inputs.append(("loop", self.loop_inputs)) + combined_inputs = self.combine_inputs(*named_inputs) + # mark Required inputs only if that input is required any of the blocks + for input_param in combined_inputs: + if input_param.name in self.required_inputs: + input_param.required = True + else: + input_param.required = False + return combined_inputs + + @property + # Copied from diffusers.modular_pipelines.modular_pipeline.SequentialPipelineBlocks.inputs + def inputs(self): + return self.get_inputs() + + # modified from SequentialPipelineBlocks to include loop_intermediate_inputs + @property + def intermediate_inputs(self): + intermediates = self.get_intermediate_inputs() + intermediate_names = [input.name for input in intermediates] + for loop_intermediate_input in self.loop_intermediate_inputs: + if loop_intermediate_input.name not in intermediate_names: + intermediates.append(loop_intermediate_input) + return intermediates + + # modified from SequentialPipelineBlocks + def get_intermediate_inputs(self): + inputs = [] + outputs = set() + + # Go through all blocks in order + for block in self.sub_blocks.values(): + # Add inputs that aren't in outputs yet + inputs.extend(input_name for input_name in block.intermediate_inputs if 
input_name.name not in outputs) + + # Only add outputs if the block cannot be skipped + should_add_outputs = True + if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: + should_add_outputs = False + + if should_add_outputs: + # Add this block's outputs + block_intermediate_outputs = [out.name for out in block.intermediate_outputs] + outputs.update(block_intermediate_outputs) + return inputs + + # modified from SequentialPipelineBlocks, if any additionan input required by the loop is required by the block + @property + def required_inputs(self) -> List[str]: + # Get the first block from the dictionary + first_block = next(iter(self.sub_blocks.values())) + required_by_any = set(getattr(first_block, "required_inputs", set())) + + required_by_loop = set(getattr(self, "loop_required_inputs", set())) + required_by_any.update(required_by_loop) + + # Union with required inputs from all other blocks + for block in list(self.sub_blocks.values())[1:]: + block_required = set(getattr(block, "required_inputs", set())) + required_by_any.update(block_required) + + return list(required_by_any) + + # YiYi TODO: maybe we do not need this, it is only used in docstring, + # intermediate_inputs is by default required, unless you manually handle it inside the block + @property + def required_intermediate_inputs(self) -> List[str]: + required_intermediate_inputs = [] + for input_param in self.intermediate_inputs: + if input_param.required: + required_intermediate_inputs.append(input_param.name) + for input_param in self.loop_intermediate_inputs: + if input_param.required: + required_intermediate_inputs.append(input_param.name) + return required_intermediate_inputs + + # YiYi TODO: this need to be thought about more + # modified from SequentialPipelineBlocks to include loop_intermediate_outputs + @property + def intermediate_outputs(self) -> List[str]: + named_outputs = [(name, block.intermediate_outputs) for name, block in self.sub_blocks.items()] + combined_outputs = self.combine_outputs(*named_outputs) + for output in self.loop_intermediate_outputs: + if output.name not in {output.name for output in combined_outputs}: + combined_outputs.append(output) + return combined_outputs + + # YiYi TODO: this need to be thought about more + @property + def outputs(self) -> List[str]: + return next(reversed(self.sub_blocks.values())).intermediate_outputs + + def __init__(self): + sub_blocks = InsertableDict() + for block_name, block_cls in zip(self.block_names, self.block_classes): + sub_blocks[block_name] = block_cls() + self.sub_blocks = sub_blocks + + @classmethod + def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "LoopSequentialPipelineBlocks": + """ + Creates a LoopSequentialPipelineBlocks instance from a dictionary of blocks. 
+ + Args: + blocks_dict: Dictionary mapping block names to block instances + + Returns: + A new LoopSequentialPipelineBlocks instance + """ + instance = cls() + + # Create instances if classes are provided + sub_blocks = InsertableDict() + for name, block in blocks_dict.items(): + if inspect.isclass(block): + sub_blocks[name] = block() + else: + sub_blocks[name] = block + + instance.block_classes = [block.__class__ for block in blocks_dict.values()] + instance.block_names = list(blocks_dict.keys()) + instance.sub_blocks = blocks_dict + return instance + + def loop_step(self, components, state: PipelineState, **kwargs): + for block_name, block in self.sub_blocks.items(): + try: + components, state = block(components, state, **kwargs) + except Exception as e: + error_msg = ( + f"\nError in block: ({block_name}, {block.__class__.__name__})\n" + f"Error details: {str(e)}\n" + f"Traceback:\n{traceback.format_exc()}" + ) + logger.error(error_msg) + raise + return components, state + + def __call__(self, components, state: PipelineState) -> PipelineState: + raise NotImplementedError("`__call__` method needs to be implemented by the subclass") + + def get_block_state(self, state: PipelineState) -> dict: + """Get all inputs and intermediates in one dictionary""" + data = {} + + # Check inputs + for input_param in self.inputs: + if input_param.name: + value = state.get_input(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all inputs with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) + if inputs_kwargs: + for k, v in inputs_kwargs.items(): + if v is not None: + data[k] = v + data[input_param.kwargs_type][k] = v + + # Check intermediates + for input_param in self.intermediate_inputs: + if input_param.name: + value = state.get_intermediate(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required intermediate input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all intermediates with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type) + if intermediate_kwargs: + for k, v in intermediate_kwargs.items(): + if v is not None: + if k not in data: + data[k] = v + data[input_param.kwargs_type][k] = v + return BlockState(**data) + + def set_block_state(self, state: PipelineState, block_state: BlockState): + for output_param in self.intermediate_outputs: + if not hasattr(block_state, output_param.name): + raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") + param = getattr(block_state, output_param.name) + state.set_intermediate(output_param.name, param, output_param.kwargs_type) + + for input_param in self.intermediate_inputs: + if input_param.name and hasattr(block_state, input_param.name): + param = getattr(block_state, input_param.name) + # Only add if the value is different from what's in the state + current_value = state.get_intermediate(input_param.name) + if 
current_value is not param: # Using identity comparison to check if object was modified + state.set_intermediate(input_param.name, param, input_param.kwargs_type) + elif input_param.kwargs_type: + # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters + # we need to first find out which inputs are and loop through them. + intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type) + for param_name, current_value in intermediate_kwargs.items(): + if not hasattr(block_state, param_name): + continue + param = getattr(block_state, param_name) + if current_value is not param: # Using identity comparison to check if object was modified + state.set_intermediate(param_name, param, input_param.kwargs_type) + + @property + def doc(self): + return make_doc_string( + self.inputs, + self.intermediate_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs, + ) + + # modified from SequentialPipelineBlocks, + # (does not need trigger_inputs related part so removed them, + # do not need to support auto block for loop blocks) + def __repr__(self): + class_name = self.__class__.__name__ + base_class = self.__class__.__bases__[0].__name__ + header = ( + f"{class_name}(\n Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n" + ) + + # Format description with proper indentation + desc_lines = self.description.split("\n") + desc = [] + # First line with "Description:" label + desc.append(f" Description: {desc_lines[0]}") + # Subsequent lines with proper indentation + if len(desc_lines) > 1: + desc.extend(f" {line}" for line in desc_lines[1:]) + desc = "\n".join(desc) + "\n" + + # Components section - focus only on expected components + expected_components = getattr(self, "expected_components", []) + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) + + # Configs section - use format_configs with add_empty_lines=False + expected_configs = getattr(self, "expected_configs", []) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) + + # Blocks section - moved to the end with simplified format + blocks_str = " Sub-Blocks:\n" + for i, (name, block) in enumerate(self.sub_blocks.items()): + # For SequentialPipelineBlocks, show execution order + blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" + + # Add block description + desc_lines = block.description.split("\n") + indented_desc = desc_lines[0] + if len(desc_lines) > 1: + indented_desc += "\n" + "\n".join(" " + line for line in desc_lines[1:]) + blocks_str += f" Description: {indented_desc}\n\n" + + # Build the representation with conditional sections + result = f"{header}\n{desc}" + + # Only add components section if it has content + if components_str.strip(): + result += f"\n\n{components_str}" + + # Only add configs section if it has content + if configs_str.strip(): + result += f"\n\n{configs_str}" + + # Always add blocks section + result += f"\n\n{blocks_str})" + + return result + + @torch.compiler.disable + def progress_bar(self, iterable=None, total=None): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." 
+            )
+
+        if iterable is not None:
+            return tqdm(iterable, **self._progress_bar_config)
+        elif total is not None:
+            return tqdm(total=total, **self._progress_bar_config)
+        else:
+            raise ValueError("Either `total` or `iterable` has to be defined.")
+
+    def set_progress_bar_config(self, **kwargs):
+        self._progress_bar_config = kwargs
+
+
+# YiYi TODO:
+# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess)
+# 2. do we need ConfigSpec? they are basically just key/val kwargs
+# 3. improve docstrings and potentially add a validator for methods where we accept kwargs to be passed to from_pretrained/save_pretrained/load_default_components()/load_components()
+class ModularPipeline(ConfigMixin, PushToHubMixin):
+    """
+    Base class for all Modular pipelines.
+
+
+
+    This is an experimental feature and is likely to change in the future.
+
+
+
+    Args:
+        blocks: ModularPipelineBlocks, the blocks to be used in the pipeline
+    """
+
+    config_name = "modular_model_index.json"
+    hf_device_map = None
+
+    # YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name
+    def __init__(
+        self,
+        blocks: Optional[ModularPipelineBlocks] = None,
+        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
+        components_manager: Optional[ComponentsManager] = None,
+        collection: Optional[str] = None,
+        **kwargs,
+    ):
+        """
+        Initialize a ModularPipeline instance.
+
+        This method sets up the pipeline by:
+        - creating default pipeline blocks if none are provided
+        - gathering component and config specifications based on the pipeline blocks' requirements (e.g.
+          expected_components, expected_configs)
+        - updating the loading specs of from_pretrained components based on the modular_model_index.json file from
+          the huggingface hub if `pretrained_model_name_or_path` is provided
+        - creating default from_config components and registering everything
+
+        Args:
+            blocks: `ModularPipelineBlocks` instance. If None, will attempt to load
+                default blocks based on the pipeline class name.
+            pretrained_model_name_or_path: Path to a pretrained pipeline configuration. If provided,
+                will load component specs (only for from_pretrained components) and config values from the saved
+                modular_model_index.json file.
+            components_manager:
+                Optional ComponentsManager for managing components across different pipelines and applying
+                offloading strategies.
+            collection: Optional collection name for organizing components in the ComponentsManager.
+            **kwargs: Additional arguments passed to `load_config()` when loading the pretrained configuration.
+ + Examples: + ```python + # Initialize with custom blocks + pipeline = ModularPipeline(blocks=my_custom_blocks) + + # Initialize from pretrained configuration + pipeline = ModularPipeline(blocks=my_blocks, pretrained_model_name_or_path="my-repo/modular-pipeline") + + # Initialize with components manager + pipeline = ModularPipeline( + blocks=my_blocks, components_manager=ComponentsManager(), collection="my_collection" + ) + ``` + + Notes: + - If blocks is None, the method will try to find default blocks based on the pipeline class name + - Components with default_creation_method="from_config" are created immediately, its specs are not included + in config dict and will not be saved in `modular_model_index.json` + - Components with default_creation_method="from_pretrained" are set to None and can be loaded later with + `load_default_components()`/`load_components()` + - The pipeline's config dict is populated with component specs (only for from_pretrained components) and + config values, which will be saved as `modular_model_index.json` during `save_pretrained` + - The pipeline's config dict is also used to store the pipeline blocks's class name, which will be saved as + `_blocks_class_name` in the config dict + """ + if blocks is None: + blocks_class_name = MODULAR_PIPELINE_BLOCKS_MAPPING.get(self.__class__.__name__) + if blocks_class_name is not None: + diffusers_module = importlib.import_module("diffusers") + blocks_class = getattr(diffusers_module, blocks_class_name) + blocks = blocks_class() + else: + logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}") + + self.blocks = blocks + self._components_manager = components_manager + self._collection = collection + self._component_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_components} + self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs} + + # update component_specs and config_specs from modular_repo + if pretrained_model_name_or_path is not None: + config_dict = self.load_config(pretrained_model_name_or_path, **kwargs) + + for name, value in config_dict.items(): + # all the components in modular_model_index.json are from_pretrained components + if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 3: + library, class_name, component_spec_dict = value + component_spec = self._dict_to_component_spec(name, component_spec_dict) + component_spec.default_creation_method = "from_pretrained" + self._component_specs[name] = component_spec + + elif name in self._config_specs: + self._config_specs[name].default = value + + register_components_dict = {} + for name, component_spec in self._component_specs.items(): + if component_spec.default_creation_method == "from_config": + component = component_spec.create() + else: + component = None + register_components_dict[name] = component + self.register_components(**register_components_dict) + + default_configs = {} + for name, config_spec in self._config_specs.items(): + default_configs[name] = config_spec.default + self.register_to_config(**default_configs) + + self.register_to_config(_blocks_class_name=self.blocks.__class__.__name__ if self.blocks is not None else None) + + @property + def default_call_parameters(self) -> Dict[str, Any]: + """ + Returns: + - Dictionary mapping input names to their default values + """ + params = {} + for input_param in self.blocks.inputs: + params[input_param.name] = input_param.default + return params + + def __call__(self, 
state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): + """ + Execute the pipeline by running the pipeline blocks with the given inputs. + + Args: + state (`PipelineState`, optional): + PipelineState instance contains inputs and intermediate values. If None, a new `PipelineState` will be + created based on the user inputs and the pipeline blocks's requirement. + output (`str` or `List[str]`, optional): + Optional specification of what to return: + - None: Returns the complete `PipelineState` with all inputs and intermediates (default) + - str: Returns a specific intermediate value from the state (e.g. `output="image"`) + - List[str]: Returns a dictionary of specific intermediate values (e.g. `output=["image", + "latents"]`) + + + Examples: + ```python + # Get complete pipeline state + state = pipeline(prompt="A beautiful sunset", num_inference_steps=20) + print(state.intermediates) # All intermediate outputs + + # Get specific output + image = pipeline(prompt="A beautiful sunset", output="image") + + # Get multiple specific outputs + results = pipeline(prompt="A beautiful sunset", output=["image", "latents"]) + image, latents = results["image"], results["latents"] + + # Continue from previous state + state = pipeline(prompt="A beautiful sunset") + new_state = pipeline(state=state, output="image") # Continue processing + ``` + + Returns: + - If `output` is None: Complete `PipelineState` containing all inputs and intermediates + - If `output` is str: The specific intermediate value from the state (e.g. `output="image"`) + - If `output` is List[str]: Dictionary mapping output names to their values from the state (e.g. + `output=["image", "latents"]`) + """ + if state is None: + state = PipelineState() + + # Make a copy of the input kwargs + passed_kwargs = kwargs.copy() + + # Add inputs to state, using defaults if not provided in the kwargs or the state + # if same input already in the state, will override it if provided in the kwargs + + intermediate_inputs = [inp.name for inp in self.blocks.intermediate_inputs] + for expected_input_param in self.blocks.inputs: + name = expected_input_param.name + default = expected_input_param.default + kwargs_type = expected_input_param.kwargs_type + if name in passed_kwargs: + if name not in intermediate_inputs: + state.set_input(name, passed_kwargs.pop(name), kwargs_type) + else: + state.set_input(name, passed_kwargs[name], kwargs_type) + elif name not in state.inputs: + state.set_input(name, default, kwargs_type) + + for expected_intermediate_param in self.blocks.intermediate_inputs: + name = expected_intermediate_param.name + kwargs_type = expected_intermediate_param.kwargs_type + if name in passed_kwargs: + state.set_intermediate(name, passed_kwargs.pop(name), kwargs_type) + + # Warn about unexpected inputs + if len(passed_kwargs) > 0: + warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. 
This input will be ignored.") + # Run the pipeline + with torch.no_grad(): + try: + _, state = self.blocks(self, state) + except Exception: + error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n" + logger.error(error_msg) + raise + + if output is None: + return state + + elif isinstance(output, str): + return state.get_intermediate(output) + + elif isinstance(output, (list, tuple)): + return state.get_intermediates(output) + else: + raise ValueError(f"Output '{output}' is not a valid output type") + + def load_default_components(self, **kwargs): + """ + Load from_pretrained components using the loading specs in the config dict. + + Args: + **kwargs: Additional arguments passed to `from_pretrained` method, e.g. torch_dtype, cache_dir, etc. + """ + names = [ + name + for name in self._component_specs.keys() + if self._component_specs[name].default_creation_method == "from_pretrained" + ] + self.load_components(names=names, **kwargs) + + @classmethod + @validate_hf_hub_args + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + trust_remote_code: Optional[bool] = None, + components_manager: Optional[ComponentsManager] = None, + collection: Optional[str] = None, + **kwargs, + ): + """ + Load a ModularPipeline from a huggingface hub repo. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`, optional): + Path to a pretrained pipeline configuration. If provided, will load component specs (only for + from_pretrained components) and config values from the modular_model_index.json file. + trust_remote_code (`bool`, optional): + Whether to trust remote code when loading the pipeline, need to be set to True if you want to create + pipeline blocks based on the custom code in `pretrained_model_name_or_path` + components_manager (`ComponentsManager`, optional): + ComponentsManager instance for managing multiple component cross different pipelines and apply + offloading strategies. + collection (`str`, optional):` + Collection name for organizing components in the ComponentsManager. + """ + from ..pipelines.pipeline_loading_utils import _get_pipeline_class + + try: + blocks = ModularPipelineBlocks.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs + ) + except EnvironmentError: + blocks = None + + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) + + load_config_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "token": token, + "local_files_only": local_files_only, + "revision": revision, + } + + try: + config_dict = cls.load_config(pretrained_model_name_or_path, **load_config_kwargs) + pipeline_class = _get_pipeline_class(cls, config=config_dict) + except EnvironmentError: + pipeline_class = cls + pretrained_model_name_or_path = None + + pipeline = pipeline_class( + blocks=blocks, + pretrained_model_name_or_path=pretrained_model_name_or_path, + components_manager=components_manager, + collection=collection, + **kwargs, + ) + return pipeline + + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save the pipeline to a directory. It does not save components, you need to save them separately. 
+ + Args: + save_directory (`str` or `os.PathLike`): + Path to the directory where the pipeline will be saved. + push_to_hub (`bool`, optional): + Whether to push the pipeline to the huggingface hub. + **kwargs: Additional arguments passed to `save_config()` method + """ + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + private = kwargs.pop("private", None) + create_pr = kwargs.pop("create_pr", False) + token = kwargs.pop("token", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id + + # Create a new empty model card and eventually tag it + model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True) + model_card = populate_model_card(model_card) + model_card.save(os.path.join(save_directory, "README.md")) + + # YiYi TODO: maybe order the json file to make it more readable: configs first, then components + self.save_config(save_directory=save_directory) + + if push_to_hub: + self._upload_folder( + save_directory, + repo_id, + token=token, + commit_message=commit_message, + create_pr=create_pr, + ) + + @property + def doc(self): + """ + Returns: + - The docstring of the pipeline blocks + """ + return self.blocks.doc + + def register_components(self, **kwargs): + """ + Register components with their corresponding specifications. + + This method is responsible for: + 1. Sets component objects as attributes on the loader (e.g., self.unet = unet) + 2. Updates the config dict, which will be saved as `modular_model_index.json` during `save_pretrained` (only + for from_pretrained components) + 3. Adds components to the component manager if one is attached (only for from_pretrained components) + + This method is called when: + - Components are first initialized in __init__: + - from_pretrained components not loaded during __init__ so they are registered as None; + - non from_pretrained components are created during __init__ and registered as the object itself + - Components are updated with the `update_components()` method: e.g. loader.update_components(unet=unet) or + loader.update_components(guider=guider_spec) + - (from_pretrained) Components are loaded with the `load_default_components()` method: e.g. + loader.load_default_components(names=["unet"]) + + Args: + **kwargs: Keyword arguments where keys are component names and values are component objects. + E.g., register_components(unet=unet_model, text_encoder=encoder_model) + + Notes: + - When registering None for a component, it sets attribute to None but still syncs specs with the config + dict, which will be saved as `modular_model_index.json` during `save_pretrained` + - component_specs are updated to match the new component outside of this method, e.g. in + `update_components()` method + """ + for name, module in kwargs.items(): + # current component spec + component_spec = self._component_specs.get(name) + if component_spec is None: + logger.warning(f"ModularPipeline.register_components: skipping unknown component '{name}'") + continue + + # check if it is the first time registration, i.e. calling from __init__ + is_registered = hasattr(self, name) + is_from_pretrained = component_spec.default_creation_method == "from_pretrained" + + if module is not None: + # actual library and class name of the module + library, class_name = _fetch_class_library_tuple(module) # e.g. ("diffusers", "UNet2DConditionModel") + else: + # if module is None, e.g. 
self.register_components(unet=None) during __init__ + # we do not update the spec, + # but we still need to update the modular_model_index.json config based on component spec + library, class_name = None, None + + # extract the loading spec from the updated component spec that'll be used as part of modular_model_index.json config + # e.g. {"repo": "stabilityai/stable-diffusion-2-1", + # "type_hint": ("diffusers", "UNet2DConditionModel"), + # "subfolder": "unet", + # "variant": None, + # "revision": None} + component_spec_dict = self._component_spec_to_dict(component_spec) + + register_dict = {name: (library, class_name, component_spec_dict)} + + # set the component as attribute + # if it is not set yet, just set it and skip the process to check and warn below + if not is_registered: + if is_from_pretrained: + self.register_to_config(**register_dict) + setattr(self, name, module) + if module is not None and is_from_pretrained and self._components_manager is not None: + self._components_manager.add(name, module, self._collection) + continue + + current_module = getattr(self, name, None) + # skip if the component is already registered with the same object + if current_module is module: + logger.info( + f"ModularPipeline.register_components: {name} is already registered with same object, skipping" + ) + continue + + # warn if unregister + if current_module is not None and module is None: + logger.info( + f"ModularPipeline.register_components: setting '{name}' to None " + f"(was {current_module.__class__.__name__})" + ) + # same type, new instance → replace but send debug log + elif ( + current_module is not None + and module is not None + and isinstance(module, current_module.__class__) + and current_module != module + ): + logger.debug( + f"ModularPipeline.register_components: replacing existing '{name}' " + f"(same type {type(current_module).__name__}, new instance)" + ) + + # update modular_model_index.json config + if is_from_pretrained: + self.register_to_config(**register_dict) + # finally set models + setattr(self, name, module) + # add to component manager if one is attached + if module is not None and is_from_pretrained and self._components_manager is not None: + self._components_manager.add(name, module, self._collection) + + @property + def device(self) -> torch.device: + r""" + Returns: + `torch.device`: The torch device on which the pipeline is located. + """ + modules = self.components.values() + modules = [m for m in modules if isinstance(m, torch.nn.Module)] + + for module in modules: + return module.device + + return torch.device("cpu") + + @property + # Modified from diffusers.pipelines.pipeline_utils.DiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from + Accelerate's module hooks. + """ + for name, model in self.components.items(): + if not isinstance(model, torch.nn.Module): + continue + + if not hasattr(model, "_hf_hook"): + return self.device + for module in model.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + @property + def dtype(self) -> torch.dtype: + r""" + Returns: + `torch.dtype`: The torch dtype on which the pipeline is located. 
+ """ + modules = self.components.values() + modules = [m for m in modules if isinstance(m, torch.nn.Module)] + + for module in modules: + return module.dtype + + return torch.float32 + + @property + def null_component_names(self) -> List[str]: + """ + Returns: + - List of names for components that needs to be loaded + """ + return [name for name in self._component_specs.keys() if hasattr(self, name) and getattr(self, name) is None] + + @property + def component_names(self) -> List[str]: + """ + Returns: + - List of names for all components + """ + return list(self.components.keys()) + + @property + def pretrained_component_names(self) -> List[str]: + """ + Returns: + - List of names for from_pretrained components + """ + return [ + name + for name in self._component_specs.keys() + if self._component_specs[name].default_creation_method == "from_pretrained" + ] + + @property + def config_component_names(self) -> List[str]: + """ + Returns: + - List of names for from_config components + """ + return [ + name + for name in self._component_specs.keys() + if self._component_specs[name].default_creation_method == "from_config" + ] + + @property + def components(self) -> Dict[str, Any]: + """ + Returns: + - Dictionary mapping component names to their objects (include both from_pretrained and from_config + components) + """ + # return only components we've actually set as attributes on self + return {name: getattr(self, name) for name in self._component_specs.keys() if hasattr(self, name)} + + def get_component_spec(self, name: str) -> ComponentSpec: + """ + Returns: + - a copy of the ComponentSpec object for the given component name + """ + return deepcopy(self._component_specs[name]) + + def update_components(self, **kwargs): + """ + Update components and configuration values and specs after the pipeline has been instantiated. + + This method allows you to: + 1. Replace existing components with new ones (e.g., updating `self.unet` or `self.text_encoder`) + 2. Update configuration values (e.g., changing `self.requires_safety_checker` flag) + + In addition to updating the components and configuration values as pipeline attributes, the method also + updates: + - the corresponding specs in `_component_specs` and `_config_specs` + - the `config` dict, which will be saved as `modular_model_index.json` during `save_pretrained` + + Args: + **kwargs: Component objects, ComponentSpec objects, or configuration values to update: + - Component objects: Only supports components we can extract specs using + `ComponentSpec.from_component()` method i.e. 
components created with ComponentSpec.load() or + ConfigMixin subclasses that aren't nn.Modules (e.g., `unet=new_unet, text_encoder=new_encoder`) + - ComponentSpec objects: Only supports default_creation_method == "from_config", will call create() + method to create a new component (e.g., `guider=ComponentSpec(name="guider", + type_hint=ClassifierFreeGuidance, config={...}, default_creation_method="from_config")`) + - Configuration values: Simple values to update configuration settings (e.g., + `requires_safety_checker=False`) + + Raises: + ValueError: If a component object is not supported in ComponentSpec.from_component() method: + - nn.Module components without a valid `_diffusers_load_id` attribute + - Non-ConfigMixin components without a valid `_diffusers_load_id` attribute + + Examples: + ```python + # Update multiple components at once + pipeline.update_components(unet=new_unet_model, text_encoder=new_text_encoder) + + # Update configuration values + pipeline.update_components(requires_safety_checker=False) + + # Update both components and configs together + pipeline.update_components(unet=new_unet_model, requires_safety_checker=False) + + # Update with ComponentSpec objects (from_config only) + pipeline.update_components( + guider=ComponentSpec( + name="guider", + type_hint=ClassifierFreeGuidance, + config={"guidance_scale": 5.0}, + default_creation_method="from_config", + ) + ) + ``` + + Notes: + - Components with trained weights must be created using ComponentSpec.load(). If the component has not been + shared in huggingface hub and you don't have loading specs, you can upload it using `push_to_hub()` + - ConfigMixin objects without weights (e.g., schedulers, guiders) can be passed directly + - ComponentSpec objects with default_creation_method="from_pretrained" are not supported in + update_components() + """ + + # extract component_specs_updates & config_specs_updates from `specs` + passed_component_specs = { + k: kwargs.pop(k) for k in self._component_specs if k in kwargs and isinstance(kwargs[k], ComponentSpec) + } + passed_components = { + k: kwargs.pop(k) for k in self._component_specs if k in kwargs and not isinstance(kwargs[k], ComponentSpec) + } + passed_config_values = {k: kwargs.pop(k) for k in self._config_specs if k in kwargs} + + for name, component in passed_components.items(): + current_component_spec = self._component_specs[name] + + # warn if type changed + if current_component_spec.type_hint is not None and not isinstance( + component, current_component_spec.type_hint + ): + logger.warning( + f"ModularPipeline.update_components: adding {name} with new type: {component.__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}" + ) + # update _component_specs based on the new component + new_component_spec = ComponentSpec.from_component(name, component) + if new_component_spec.default_creation_method != current_component_spec.default_creation_method: + logger.warning( + f"ModularPipeline.update_components: changing the default_creation_method of {name} from {current_component_spec.default_creation_method} to {new_component_spec.default_creation_method}." 
+
+                )
+
+            self._component_specs[name] = new_component_spec
+
+        if len(kwargs) > 0:
+            logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}")
+
+        created_components = {}
+        for name, component_spec in passed_component_specs.items():
+            if component_spec.default_creation_method == "from_pretrained":
+                raise ValueError(
+                    "ComponentSpec object with default_creation_method == 'from_pretrained' is not supported in update_components() method"
+                )
+            created_components[name] = component_spec.create()
+            current_component_spec = self._component_specs[name]
+            # warn if type changed
+            if current_component_spec.type_hint is not None and not isinstance(
+                created_components[name], current_component_spec.type_hint
+            ):
+                logger.warning(
+                    f"ModularPipeline.update_components: adding {name} with new type: {created_components[name].__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}"
+                )
+            # update _component_specs based on the user passed component_spec
+            self._component_specs[name] = component_spec
+        self.register_components(**passed_components, **created_components)
+
+        config_to_register = {}
+        for name, new_value in passed_config_values.items():
+            # e.g. requires_aesthetics_score = False
+            self._config_specs[name].default = new_value
+            config_to_register[name] = new_value
+        self.register_to_config(**config_to_register)
+
+    # YiYi TODO: support map for additional from_pretrained kwargs
+    # YiYi/Dhruv TODO: consolidate load_components and load_default_components?
+    def load_components(self, names: Union[List[str], str], **kwargs):
+        """
+        Load selected components from specs.
+
+        Args:
+            names: A component name or list of component names to load.
+            **kwargs: Additional kwargs passed to `from_pretrained()`. Can be:
+              - a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16
+              - a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32}
+              - loading fields (e.g. `repo`, `variant`, `revision`) that, when passed, override the corresponding
+                values stored on the ComponentSpec
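+
+        Example (illustrative; assumes the pipeline defines `unet` and `vae` components):
+
+        ```python
+        # load the unet in bfloat16 and fall back to float32 for everything else (here the vae)
+        pipe.load_components(
+            names=["unet", "vae"],
+            torch_dtype={"unet": torch.bfloat16, "default": torch.float32},
+        )
+        ```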
+ """ + + if isinstance(names, str): + names = [names] + elif not isinstance(names, list): + raise ValueError(f"Invalid type for names: {type(names)}") + + components_to_load = {name for name in names if name in self._component_specs} + unknown_names = {name for name in names if name not in self._component_specs} + if len(unknown_names) > 0: + logger.warning(f"Unknown components will be ignored: {unknown_names}") + + components_to_register = {} + for name in components_to_load: + spec = self._component_specs[name] + component_load_kwargs = {} + for key, value in kwargs.items(): + if not isinstance(value, dict): + # if the value is a single value, apply it to all components + component_load_kwargs[key] = value + else: + if name in value: + # if it is a dict, check if the component name is in the dict + component_load_kwargs[key] = value[name] + elif "default" in value: + # check if the default is specified + component_load_kwargs[key] = value["default"] + try: + components_to_register[name] = spec.load(**component_load_kwargs) + except Exception as e: + logger.warning(f"Failed to create component '{name}': {e}") + + # Register all components at once + self.register_components(**components_to_register) + + # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline._maybe_raise_error_if_group_offload_active + def _maybe_raise_error_if_group_offload_active( + self, raise_error: bool = False, module: Optional[torch.nn.Module] = None + ) -> bool: + from ..hooks.group_offloading import _is_group_offload_enabled + + components = self.components.values() if module is None else [module] + components = [component for component in components if isinstance(component, torch.nn.Module)] + for component in components: + if _is_group_offload_enabled(component): + if raise_error: + raise ValueError( + "You are trying to apply model/sequential CPU offloading to a pipeline that contains components " + "with group offloading enabled. This is not supported. Please disable group offloading for " + "components of the pipeline to use other offloading methods." + ) + return True + return False + + # Modified from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to + def to(self, *args, **kwargs) -> Self: + r""" + Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the + arguments of `self.to(*args, **kwargs).` + + + + If the pipeline already has the correct torch.dtype and torch.device, then it is returned as is. Otherwise, + the returned pipeline is a copy of self with the desired torch.dtype and torch.device. 
+ + + + + Here are the ways to call `to`: + + - `to(dtype, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified + [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype) + - `to(device, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified + [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) + - `to(device=None, dtype=None, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the + specified [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) and + [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype) + + Arguments: + dtype (`torch.dtype`, *optional*): + Returns a pipeline with the specified + [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype) + device (`torch.Device`, *optional*): + Returns a pipeline with the specified + [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) + silence_dtype_warnings (`str`, *optional*, defaults to `False`): + Whether to omit warnings if the target `dtype` is not compatible with the target `device`. + + Returns: + [`DiffusionPipeline`]: The pipeline converted to specified `dtype` and/or `dtype`. + """ + from ..pipelines.pipeline_utils import _check_bnb_status + from ..utils import is_accelerate_available, is_accelerate_version, is_hpu_available, is_transformers_version + + dtype = kwargs.pop("dtype", None) + device = kwargs.pop("device", None) + silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False) + + dtype_arg = None + device_arg = None + if len(args) == 1: + if isinstance(args[0], torch.dtype): + dtype_arg = args[0] + else: + device_arg = torch.device(args[0]) if args[0] is not None else None + elif len(args) == 2: + if isinstance(args[0], torch.dtype): + raise ValueError( + "When passing two arguments, make sure the first corresponds to `device` and the second to `dtype`." + ) + device_arg = torch.device(args[0]) if args[0] is not None else None + dtype_arg = args[1] + elif len(args) > 2: + raise ValueError("Please make sure to pass at most two arguments (`device` and `dtype`) `.to(...)`") + + if dtype is not None and dtype_arg is not None: + raise ValueError( + "You have passed `dtype` both as an argument and as a keyword argument. Please only pass one of the two." + ) + + dtype = dtype or dtype_arg + + if device is not None and device_arg is not None: + raise ValueError( + "You have passed `device` both as an argument and as a keyword argument. Please only pass one of the two." + ) + + device = device or device_arg + device_type = torch.device(device).type if device is not None else None + pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items()) + + # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU. 
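+        # The two helpers below distinguish the offloading flavors: sequential CPU offload (AlignDevicesHook)
+        # makes an explicit move to an accelerator an error (unless bitsandbytes quantization is involved),
+        # while model CPU offload (CpuOffload hook) only triggers the warning emitted further down.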
+ def module_is_sequentially_offloaded(module): + if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"): + return False + + _, _, is_loaded_in_8bit_bnb = _check_bnb_status(module) + + if is_loaded_in_8bit_bnb: + return False + + return hasattr(module, "_hf_hook") and ( + isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook) + or hasattr(module._hf_hook, "hooks") + and isinstance(module._hf_hook.hooks[0], accelerate.hooks.AlignDevicesHook) + ) + + def module_is_offloaded(module): + if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"): + return False + + return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload) + + # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer + pipeline_is_sequentially_offloaded = any( + module_is_sequentially_offloaded(module) for _, module in self.components.items() + ) + + is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 + if is_pipeline_device_mapped: + raise ValueError( + "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline." + ) + + if device_type in ["cuda", "xpu"]: + if pipeline_is_sequentially_offloaded and not pipeline_has_bnb: + raise ValueError( + "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading." + ) + # PR: https://github.com/huggingface/accelerate/pull/3223/ + elif pipeline_has_bnb and is_accelerate_version("<", "1.1.0.dev0"): + raise ValueError( + "You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation." + ) + + # Display a warning in this case (the operation succeeds but the benefits are lost) + pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items()) + if pipeline_is_offloaded and device_type in ["cuda", "xpu"]: + logger.warning( + f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading." 
+ ) + + # Enable generic support for Intel Gaudi accelerator using GPU/HPU migration + if device_type == "hpu" and kwargs.pop("hpu_migration", True) and is_hpu_available(): + os.environ["PT_HPU_GPU_MIGRATION"] = "1" + logger.debug("Environment variable set: PT_HPU_GPU_MIGRATION=1") + + import habana_frameworks.torch # noqa: F401 + + # HPU hardware check + if not (hasattr(torch, "hpu") and torch.hpu.is_available()): + raise ValueError("You are trying to call `.to('hpu')` but HPU device is unavailable.") + + os.environ["PT_HPU_MAX_COMPOUND_OP_SIZE"] = "1" + logger.debug("Environment variable set: PT_HPU_MAX_COMPOUND_OP_SIZE=1") + + modules = self.components.values() + modules = [m for m in modules if isinstance(m, torch.nn.Module)] + + is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded + for module in modules: + _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module) + is_group_offloaded = self._maybe_raise_error_if_group_offload_active(module=module) + + if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None: + logger.warning( + f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {'4bit' if is_loaded_in_4bit_bnb else '8bit'} and conversion to {dtype} is not supported. Module is still in {'4bit' if is_loaded_in_4bit_bnb else '8bit'} precision." + ) + + if is_loaded_in_8bit_bnb and device is not None: + logger.warning( + f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}." + ) + + # Note: we also handle this at the ModelMixin level. The reason for doing it here too is that modeling + # components can be from outside diffusers too, but still have group offloading enabled. + if ( + self._maybe_raise_error_if_group_offload_active(raise_error=False, module=module) + and device is not None + ): + logger.warning( + f"The module '{module.__class__.__name__}' is group offloaded and moving it to {device} via `.to()` is not supported." + ) + + # This can happen for `transformer` models. CPU placement was added in + # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly. + if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"): + module.to(device=device) + elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb and not is_group_offloaded: + module.to(device, dtype) + + if ( + module.dtype == torch.float16 + and str(device) in ["cpu"] + and not silence_dtype_warnings + and not is_offloaded + ): + logger.warning( + "Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It" + " is not recommended to move them to `cpu` as running them will fail. Please make" + " sure to use an accelerator to run the pipeline in inference, due to the lack of" + " support for`float16` operations on this device in PyTorch. Please, remove the" + " `torch_dtype=torch.float16` argument, or use another device for inference." + ) + return self + + @staticmethod + def _component_spec_to_dict(component_spec: ComponentSpec) -> Any: + """ + Convert a ComponentSpec into a JSON‐serializable dict for saving as an entry in `modular_model_index.json`. If + the `default_creation_method` is not `from_pretrained`, return None. + + This dict contains: + - "type_hint": Tuple[str, str] + Library name and class name of the component. (e.g. 
("diffusers", "UNet2DConditionModel")) + - All loading fields defined by `component_spec.loading_fields()`, typically: + - "repo": Optional[str] + The model repository (e.g., "stabilityai/stable-diffusion-xl"). + - "subfolder": Optional[str] + A subfolder within the repo where this component lives. + - "variant": Optional[str] + An optional variant identifier for the model. + - "revision": Optional[str] + A specific git revision (commit hash, tag, or branch). + - ... any other loading fields defined on the spec. + + Args: + component_spec (ComponentSpec): + The spec object describing one pipeline component. + + Returns: + Dict[str, Any]: A mapping suitable for JSON serialization. + + Example: + >>> from diffusers.pipelines.modular_pipeline_utils import ComponentSpec >>> from diffusers import + UNet2DConditionModel >>> spec = ComponentSpec( + ... name="unet", ... type_hint=UNet2DConditionModel, ... config=None, ... repo="path/to/repo", ... + subfolder="subfolder", ... variant=None, ... revision=None, ... + default_creation_method="from_pretrained", + ... ) >>> ModularPipeline._component_spec_to_dict(spec) { + "type_hint": ("diffusers", "UNet2DConditionModel"), "repo": "path/to/repo", "subfolder": "subfolder", + "variant": None, "revision": None, + } + """ + if component_spec.default_creation_method != "from_pretrained": + return None + + if component_spec.type_hint is not None: + lib_name, cls_name = _fetch_class_library_tuple(component_spec.type_hint) + else: + lib_name = None + cls_name = None + load_spec_dict = {k: getattr(component_spec, k) for k in component_spec.loading_fields()} + return { + "type_hint": (lib_name, cls_name), + **load_spec_dict, + } + + @staticmethod + def _dict_to_component_spec( + name: str, + spec_dict: Dict[str, Any], + ) -> ComponentSpec: + """ + Reconstruct a ComponentSpec from a loading specdict. + + This method converts a dictionary representation back into a ComponentSpec object. The dict should contain: + - "type_hint": Tuple[str, str] + Library name and class name of the component. (e.g. ("diffusers", "UNet2DConditionModel")) + - All loading fields defined by `component_spec.loading_fields()`, typically: + - "repo": Optional[str] + The model repository (e.g., "stabilityai/stable-diffusion-xl"). + - "subfolder": Optional[str] + A subfolder within the repo where this component lives. + - "variant": Optional[str] + An optional variant identifier for the model. + - "revision": Optional[str] + A specific git revision (commit hash, tag, or branch). + - ... any other loading fields defined on the spec. + + Args: + name (str): + The name of the component. + specdict (Dict[str, Any]): + A dictionary containing the component specification data. + + Returns: + ComponentSpec: A reconstructed ComponentSpec object. + + Example: + >>> spec_dict = { ... "type_hint": ("diffusers", "UNet2DConditionModel"), ... "repo": + "stabilityai/stable-diffusion-xl", ... "subfolder": "unet", ... "variant": None, ... "revision": None, ... 
+ } >>> ModularPipeline._dict_to_component_spec("unet", spec_dict) ComponentSpec( + name="unet", type_hint=UNet2DConditionModel, config=None, repo="stabilityai/stable-diffusion-xl", + subfolder="unet", variant=None, revision=None, default_creation_method="from_pretrained" + ) + """ + # make a shallow copy so we can pop() safely + spec_dict = spec_dict.copy() + # pull out and resolve the stored type_hint + lib_name, cls_name = spec_dict.pop("type_hint") + if lib_name is not None and cls_name is not None: + type_hint = simple_get_class_obj(lib_name, cls_name) + else: + type_hint = None + + # re‐assemble the ComponentSpec + return ComponentSpec( + name=name, + type_hint=type_hint, + **spec_dict, + ) diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py new file mode 100644 index 000000000000..4fac5ef4f2d5 --- /dev/null +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -0,0 +1,671 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import re +from collections import OrderedDict +from dataclasses import dataclass, field, fields +from typing import Any, Dict, List, Literal, Optional, Type, Union + +import torch + +from ..configuration_utils import ConfigMixin, FrozenDict +from ..utils import is_torch_available, logging + + +if is_torch_available(): + pass + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class InsertableDict(OrderedDict): + def insert(self, key, value, index): + items = list(self.items()) + + # Remove key if it already exists to avoid duplicates + items = [(k, v) for k, v in items if k != key] + + # Insert at the specified index + items.insert(index, (key, value)) + + # Clear and update self + self.clear() + self.update(items) + + # Return self for method chaining + return self + + def __repr__(self): + if not self: + return "InsertableDict()" + + items = [] + for i, (key, value) in enumerate(self.items()): + if isinstance(value, type): + # For classes, show class name and + obj_repr = f"" + else: + # For objects (instances) and other types, show class name and module + obj_repr = f"" + items.append(f"{i}: ({repr(key)}, {obj_repr})") + + return "InsertableDict([\n " + ",\n ".join(items) + "\n])" + + +# YiYi TODO: +# 1. validate the dataclass fields +# 2. improve the docstring and potentially add a validator for load methods, make sure they are valid inputs to pass to from_pretrained() +@dataclass +class ComponentSpec: + """Specification for a pipeline component. + + A component can be created in two ways: + 1. From scratch using __init__ with a config dict + 2. using `from_pretrained` + + Attributes: + name: Name of the component + type_hint: Type of the component (e.g. 
UNet2DConditionModel) + description: Optional description of the component + config: Optional config dict for __init__ creation + repo: Optional repo path for from_pretrained creation + subfolder: Optional subfolder in repo + variant: Optional variant in repo + revision: Optional revision in repo + default_creation_method: Preferred creation method - "from_config" or "from_pretrained" + """ + + name: Optional[str] = None + type_hint: Optional[Type] = None + description: Optional[str] = None + config: Optional[FrozenDict] = None + # YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name + repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True}) + subfolder: Optional[str] = field(default=None, metadata={"loading": True}) + variant: Optional[str] = field(default=None, metadata={"loading": True}) + revision: Optional[str] = field(default=None, metadata={"loading": True}) + default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained" + + def __hash__(self): + """Make ComponentSpec hashable, using load_id as the hash value.""" + return hash((self.name, self.load_id, self.default_creation_method)) + + def __eq__(self, other): + """Compare ComponentSpec objects based on name and load_id.""" + if not isinstance(other, ComponentSpec): + return False + return ( + self.name == other.name + and self.load_id == other.load_id + and self.default_creation_method == other.default_creation_method + ) + + @classmethod + def from_component(cls, name: str, component: Any) -> Any: + """Create a ComponentSpec from a Component. + + Currently supports: + - Components created with `ComponentSpec.load()` method + - Components that are ConfigMixin subclasses but not nn.Modules (e.g. schedulers, guiders) + + Args: + name: Name of the component + component: Component object to create spec from + + Returns: + ComponentSpec object + + Raises: + ValueError: If component is not supported (e.g. nn.Module without load_id, non-ConfigMixin) + """ + + # Check if component was created with ComponentSpec.load() + if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null": + # component has a usable load_id -> from_pretrained, no warning needed + default_creation_method = "from_pretrained" + else: + # Component doesn't have a usable load_id, check if it's a nn.Module + if isinstance(component, torch.nn.Module): + raise ValueError( + "Cannot create ComponentSpec from a nn.Module that was not created with `ComponentSpec.load()` method." + ) + # ConfigMixin objects without weights (e.g. scheduler & guider) can be recreated with from_config + elif isinstance(component, ConfigMixin): + # warn if component was not created with `ComponentSpec` + if not hasattr(component, "_diffusers_load_id"): + logger.warning( + "Component was not created using `ComponentSpec`, defaulting to `from_config` creation method" + ) + default_creation_method = "from_config" + else: + # Not a ConfigMixin and not created with `ComponentSpec.load()` method -> throw error + raise ValueError( + f"Cannot create ComponentSpec from {name}({component.__class__.__name__}). Currently ComponentSpec.from_component() only supports: " + f" - components created with `ComponentSpec.load()` method" + f" - components that are a subclass of ConfigMixin but not a nn.Module (e.g. guider, scheduler)." 
+ ) + + type_hint = component.__class__ + + if isinstance(component, ConfigMixin) and default_creation_method == "from_config": + config = component.config + else: + config = None + if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null": + load_spec = cls.decode_load_id(component._diffusers_load_id) + else: + load_spec = {} + + return cls( + name=name, type_hint=type_hint, config=config, default_creation_method=default_creation_method, **load_spec + ) + + @classmethod + def loading_fields(cls) -> List[str]: + """ + Return the names of all loading‐related fields (i.e. those whose field.metadata["loading"] is True). + """ + return [f.name for f in fields(cls) if f.metadata.get("loading", False)] + + @property + def load_id(self) -> str: + """ + Unique identifier for this spec's pretrained load, composed of repo|subfolder|variant|revision (no empty + segments). + """ + parts = [getattr(self, k) for k in self.loading_fields()] + parts = ["null" if p is None else p for p in parts] + return "|".join(p for p in parts if p) + + @classmethod + def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]: + """ + Decode a load_id string back into a dictionary of loading fields and values. + + Args: + load_id: The load_id string to decode, format: "repo|subfolder|variant|revision" + where None values are represented as "null" + + Returns: + Dict mapping loading field names to their values. e.g. { + "repo": "path/to/repo", "subfolder": "subfolder", "variant": "variant", "revision": "revision" + } If a segment value is "null", it's replaced with None. Returns None if load_id is "null" (indicating + component not created with `load` method). + """ + + # Get all loading fields in order + loading_fields = cls.loading_fields() + result = {f: None for f in loading_fields} + + if load_id == "null": + return result + + # Split the load_id + parts = load_id.split("|") + + # Map parts to loading fields by position + for i, part in enumerate(parts): + if i < len(loading_fields): + # Convert "null" string back to None + result[loading_fields[i]] = None if part == "null" else part + + return result + + # YiYi TODO: I think we should only support ConfigMixin for this method (after we make guider and image_processors config mixin) + # otherwise we cannot do spec -> spec.create() -> component -> ComponentSpec.from_component(component) + # the config info is lost in the process + # remove error check in from_component spec and ModularPipeline.update_components() if we remove support for non configmixin in `create()` method + def create(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any: + """Create component using from_config with config.""" + + if self.type_hint is None or not isinstance(self.type_hint, type): + raise ValueError("`type_hint` is required when using from_config creation method.") + + config = config or self.config or {} + + if issubclass(self.type_hint, ConfigMixin): + component = self.type_hint.from_config(config, **kwargs) + else: + signature_params = inspect.signature(self.type_hint.__init__).parameters + init_kwargs = {} + for k, v in config.items(): + if k in signature_params: + init_kwargs[k] = v + for k, v in kwargs.items(): + if k in signature_params: + init_kwargs[k] = v + component = self.type_hint(**init_kwargs) + + component._diffusers_load_id = "null" + if hasattr(component, "config"): + self.config = component.config + + return component + + # YiYi TODO: add guard for type of model, if it is supported by from_pretrained 
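+    # Note: loading fields passed to `load()` (e.g. repo, subfolder, variant, revision) take precedence over the
+    # values stored on the spec; after a successful load the spec is updated in place and the component is tagged
+    # with `_diffusers_load_id`, which is what allows `ComponentSpec.from_component()` to round-trip it later.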
+ def load(self, **kwargs) -> Any: + """Load component using from_pretrained.""" + + # select loading fields from kwargs passed from user: e.g. repo, subfolder, variant, revision, note the list could change + passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs} + # merge loading field value in the spec with user passed values to create load_kwargs + load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()} + # repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path + repo = load_kwargs.pop("repo", None) + if repo is None: + raise ValueError( + "`repo` info is required when using `load` method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)" + ) + + if self.type_hint is None: + try: + from diffusers import AutoModel + + component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs) + except Exception as e: + raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}") + # update type_hint if AutoModel load successfully + self.type_hint = component.__class__ + else: + try: + component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs) + except Exception as e: + raise ValueError(f"Unable to load {self.name} using load method: {e}") + + self.repo = repo + for k, v in load_kwargs.items(): + setattr(self, k, v) + component._diffusers_load_id = self.load_id + + return component + + +@dataclass +class ConfigSpec: + """Specification for a pipeline configuration parameter.""" + + name: str + default: Any + description: Optional[str] = None + + +# YiYi Notes: both inputs and intermediate_inputs are InputParam objects +# however some fields are not relevant for intermediate_inputs +# e.g. unlike inputs, required only used in docstring for intermediate_inputs, we do not check if a required intermediate inputs is passed +# default is not used for intermediate_inputs, we only use default from inputs, so it is ignored if it is set for intermediate_inputs +# -> should we use different class for inputs and intermediate_inputs? +@dataclass +class InputParam: + """Specification for an input parameter.""" + + name: str = None + type_hint: Any = None + default: Any = None + required: bool = False + description: str = "" + kwargs_type: str = None # YiYi Notes: remove this feature (maybe) + + def __repr__(self): + return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" + + +@dataclass +class OutputParam: + """Specification for an output parameter.""" + + name: str + type_hint: Any = None + description: str = "" + kwargs_type: str = None # YiYi notes: remove this feature (maybe) + + def __repr__(self): + return ( + f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>" + ) + + +def format_inputs_short(inputs): + """ + Format input parameters into a string representation, with required params first followed by optional ones. + + Args: + inputs: List of input parameters with 'required' and 'name' attributes, and 'default' for optional params + + Returns: + str: Formatted string of input parameters + + Example: + >>> inputs = [ ... InputParam(name="prompt", required=True), ... InputParam(name="image", required=True), ... + InputParam(name="guidance_scale", required=False, default=7.5), ... InputParam(name="num_inference_steps", + required=False, default=50) ... 
] >>> format_inputs_short(inputs) 'prompt, image, guidance_scale=7.5, + num_inference_steps=50' + """ + required_inputs = [param for param in inputs if param.required] + optional_inputs = [param for param in inputs if not param.required] + + required_str = ", ".join(param.name for param in required_inputs) + optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs) + + inputs_str = required_str + if optional_str: + inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str + + return inputs_str + + +def format_intermediates_short(intermediate_inputs, required_intermediate_inputs, intermediate_outputs): + """ + Formats intermediate inputs and outputs of a block into a string representation. + + Args: + intermediate_inputs: List of intermediate input parameters + required_intermediate_inputs: List of required intermediate input names + intermediate_outputs: List of intermediate output parameters + + Returns: + str: Formatted string like: + Intermediates: + - inputs: Required(latents), dtype + - modified: latents # variables that appear in both inputs and outputs + - outputs: images # new outputs only + """ + # Handle inputs + input_parts = [] + for inp in intermediate_inputs: + if inp.name in required_intermediate_inputs: + input_parts.append(f"Required({inp.name})") + else: + if inp.name is None and inp.kwargs_type is not None: + inp_name = "*_" + inp.kwargs_type + else: + inp_name = inp.name + input_parts.append(inp_name) + + # Handle modified variables (appear in both inputs and outputs) + inputs_set = {inp.name for inp in intermediate_inputs} + modified_parts = [] + new_output_parts = [] + + for out in intermediate_outputs: + if out.name in inputs_set: + modified_parts.append(out.name) + else: + new_output_parts.append(out.name) + + result = [] + if input_parts: + result.append(f" - inputs: {', '.join(input_parts)}") + if modified_parts: + result.append(f" - modified: {', '.join(modified_parts)}") + if new_output_parts: + result.append(f" - outputs: {', '.join(new_output_parts)}") + + return "\n".join(result) if result else " (none)" + + +def format_params(params, header="Args", indent_level=4, max_line_length=115): + """Format a list of InputParam or OutputParam objects into a readable string representation. + + Args: + params: List of InputParam or OutputParam objects to format + header: Header text to use (e.g. 
"Args" or "Returns") + indent_level: Number of spaces to indent each parameter line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + + Returns: + A formatted string representing all parameters + """ + if not params: + return "" + + base_indent = " " * indent_level + param_indent = " " * (indent_level + 4) + desc_indent = " " * (indent_level + 8) + formatted_params = [] + + def get_type_str(type_hint): + if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union: + types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__] + return f"Union[{', '.join(types)}]" + return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint) + + def wrap_text(text, indent, max_length): + """Wrap text while preserving markdown links and maintaining indentation.""" + words = text.split() + lines = [] + current_line = [] + current_length = 0 + + for word in words: + word_length = len(word) + (1 if current_line else 0) + + if current_line and current_length + word_length > max_length: + lines.append(" ".join(current_line)) + current_line = [word] + current_length = len(word) + else: + current_line.append(word) + current_length += word_length + + if current_line: + lines.append(" ".join(current_line)) + + return f"\n{indent}".join(lines) + + # Add the header + formatted_params.append(f"{base_indent}{header}:") + + for param in params: + # Format parameter name and type + type_str = get_type_str(param.type_hint) if param.type_hint != Any else "" + # YiYi Notes: remove this line if we remove kwargs_type + name = f"**{param.kwargs_type}" if param.name is None and param.kwargs_type is not None else param.name + param_str = f"{param_indent}{name} (`{type_str}`" + + # Add optional tag and default value if parameter is an InputParam and optional + if hasattr(param, "required"): + if not param.required: + param_str += ", *optional*" + if param.default is not None: + param_str += f", defaults to {param.default}" + param_str += "):" + + # Add description on a new line with additional indentation and wrapping + if param.description: + desc = re.sub(r"\[(.*?)\]\((https?://[^\s\)]+)\)", r"[\1](\2)", param.description) + wrapped_desc = wrap_text(desc, desc_indent, max_line_length) + param_str += f"\n{desc_indent}{wrapped_desc}" + + formatted_params.append(param_str) + + return "\n\n".join(formatted_params) + + +def format_input_params(input_params, indent_level=4, max_line_length=115): + """Format a list of InputParam objects into a readable string representation. + + Args: + input_params: List of InputParam objects to format + indent_level: Number of spaces to indent each parameter line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + + Returns: + A formatted string representing all input parameters + """ + return format_params(input_params, "Inputs", indent_level, max_line_length) + + +def format_output_params(output_params, indent_level=4, max_line_length=115): + """Format a list of OutputParam objects into a readable string representation. 
+ + Args: + output_params: List of OutputParam objects to format + indent_level: Number of spaces to indent each parameter line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + + Returns: + A formatted string representing all output parameters + """ + return format_params(output_params, "Outputs", indent_level, max_line_length) + + +def format_components(components, indent_level=4, max_line_length=115, add_empty_lines=True): + """Format a list of ComponentSpec objects into a readable string representation. + + Args: + components: List of ComponentSpec objects to format + indent_level: Number of spaces to indent each component line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + add_empty_lines: Whether to add empty lines between components (default: True) + + Returns: + A formatted string representing all components + """ + if not components: + return "" + + base_indent = " " * indent_level + component_indent = " " * (indent_level + 4) + formatted_components = [] + + # Add the header + formatted_components.append(f"{base_indent}Components:") + if add_empty_lines: + formatted_components.append("") + + # Add each component with optional empty lines between them + for i, component in enumerate(components): + # Get type name, handling special cases + type_name = ( + component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint) + ) + + component_desc = f"{component_indent}{component.name} (`{type_name}`)" + if component.description: + component_desc += f": {component.description}" + + # Get the loading fields dynamically + loading_field_values = [] + for field_name in component.loading_fields(): + field_value = getattr(component, field_name) + if field_value is not None: + loading_field_values.append(f"{field_name}={field_value}") + + # Add loading field information if available + if loading_field_values: + component_desc += f" [{', '.join(loading_field_values)}]" + + formatted_components.append(component_desc) + + # Add an empty line after each component except the last one + if add_empty_lines and i < len(components) - 1: + formatted_components.append("") + + return "\n".join(formatted_components) + + +def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines=True): + """Format a list of ConfigSpec objects into a readable string representation. 
+ + Args: + configs: List of ConfigSpec objects to format + indent_level: Number of spaces to indent each config line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + add_empty_lines: Whether to add empty lines between configs (default: True) + + Returns: + A formatted string representing all configs + """ + if not configs: + return "" + + base_indent = " " * indent_level + config_indent = " " * (indent_level + 4) + formatted_configs = [] + + # Add the header + formatted_configs.append(f"{base_indent}Configs:") + if add_empty_lines: + formatted_configs.append("") + + # Add each config with optional empty lines between them + for i, config in enumerate(configs): + config_desc = f"{config_indent}{config.name} (default: {config.default})" + if config.description: + config_desc += f": {config.description}" + formatted_configs.append(config_desc) + + # Add an empty line after each config except the last one + if add_empty_lines and i < len(configs) - 1: + formatted_configs.append("") + + return "\n".join(formatted_configs) + + +def make_doc_string( + inputs, + intermediate_inputs, + outputs, + description="", + class_name=None, + expected_components=None, + expected_configs=None, +): + """ + Generates a formatted documentation string describing the pipeline block's parameters and structure. + + Args: + inputs: List of input parameters + intermediate_inputs: List of intermediate input parameters + outputs: List of output parameters + description (str, *optional*): Description of the block + class_name (str, *optional*): Name of the class to include in the documentation + expected_components (List[ComponentSpec], *optional*): List of expected components + expected_configs (List[ConfigSpec], *optional*): List of expected configurations + + Returns: + str: A formatted string containing information about components, configs, call parameters, + intermediate inputs/outputs, and final outputs. 
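+
+    Example (illustrative):
+
+    ```python
+    doc = make_doc_string(
+        inputs=[InputParam("prompt", type_hint=str, required=True, description="text prompt to condition on")],
+        intermediate_inputs=[],
+        outputs=[OutputParam("images", description="the generated images")],
+        description="Generates images from a text prompt.",
+        class_name="MyBlock",
+    )
+    print(doc)
+    ```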
+ """ + output = "" + + # Add class name if provided + if class_name: + output += f"class {class_name}\n\n" + + # Add description + if description: + desc_lines = description.strip().split("\n") + aligned_desc = "\n".join(" " + line for line in desc_lines) + output += aligned_desc + "\n\n" + + # Add components section if provided + if expected_components and len(expected_components) > 0: + components_str = format_components(expected_components, indent_level=2) + output += components_str + "\n\n" + + # Add configs section if provided + if expected_configs and len(expected_configs) > 0: + configs_str = format_configs(expected_configs, indent_level=2) + output += configs_str + "\n\n" + + # Add inputs section + output += format_input_params(inputs + intermediate_inputs, indent_level=2) + + # Add outputs section + output += "\n\n" + output += format_output_params(outputs, indent_level=2) + + return output diff --git a/src/diffusers/modular_pipelines/node_utils.py b/src/diffusers/modular_pipelines/node_utils.py new file mode 100644 index 000000000000..fb9a03c755ac --- /dev/null +++ b/src/diffusers/modular_pipelines/node_utils.py @@ -0,0 +1,665 @@ +import json +import logging +import os +from pathlib import Path +from typing import List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch + +from ..configuration_utils import ConfigMixin +from ..image_processor import PipelineImageInput +from .modular_pipeline import ModularPipelineBlocks, SequentialPipelineBlocks +from .modular_pipeline_utils import InputParam + + +logger = logging.getLogger(__name__) + +# YiYi Notes: this is actually for SDXL, put it here for now +SDXL_INPUTS_SCHEMA = { + "prompt": InputParam( + "prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation" + ), + "prompt_2": InputParam( + "prompt_2", + type_hint=Union[str, List[str]], + description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2", + ), + "negative_prompt": InputParam( + "negative_prompt", + type_hint=Union[str, List[str]], + description="The prompt or prompts not to guide the image generation", + ), + "negative_prompt_2": InputParam( + "negative_prompt_2", + type_hint=Union[str, List[str]], + description="The negative prompt or prompts for text_encoder_2", + ), + "cross_attention_kwargs": InputParam( + "cross_attention_kwargs", + type_hint=Optional[dict], + description="Kwargs dictionary passed to the AttentionProcessor", + ), + "clip_skip": InputParam( + "clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder" + ), + "image": InputParam( + "image", + type_hint=PipelineImageInput, + required=True, + description="The image(s) to modify for img2img or inpainting", + ), + "mask_image": InputParam( + "mask_image", + type_hint=PipelineImageInput, + required=True, + description="Mask image for inpainting, white pixels will be repainted", + ), + "generator": InputParam( + "generator", + type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], + description="Generator(s) for deterministic generation", + ), + "height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"), + "width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"), + "num_images_per_prompt": InputParam( + "num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt" + ), + "num_inference_steps": InputParam( + 
"num_inference_steps", type_hint=int, default=50, description="Number of denoising steps" + ), + "timesteps": InputParam( + "timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process" + ), + "sigmas": InputParam( + "sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process" + ), + "denoising_end": InputParam( + "denoising_end", + type_hint=Optional[float], + description="Fraction of denoising process to complete before termination", + ), + # YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999 + "strength": InputParam( + "strength", type_hint=float, default=0.3, description="How much to transform the reference image" + ), + "denoising_start": InputParam( + "denoising_start", type_hint=Optional[float], description="Starting point of the denoising process" + ), + "latents": InputParam( + "latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation" + ), + "padding_mask_crop": InputParam( + "padding_mask_crop", + type_hint=Optional[Tuple[int, int]], + description="Size of margin in crop for image and mask", + ), + "original_size": InputParam( + "original_size", + type_hint=Optional[Tuple[int, int]], + description="Original size of the image for SDXL's micro-conditioning", + ), + "target_size": InputParam( + "target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning" + ), + "negative_original_size": InputParam( + "negative_original_size", + type_hint=Optional[Tuple[int, int]], + description="Negative conditioning based on image resolution", + ), + "negative_target_size": InputParam( + "negative_target_size", + type_hint=Optional[Tuple[int, int]], + description="Negative conditioning based on target resolution", + ), + "crops_coords_top_left": InputParam( + "crops_coords_top_left", + type_hint=Tuple[int, int], + default=(0, 0), + description="Top-left coordinates for SDXL's micro-conditioning", + ), + "negative_crops_coords_top_left": InputParam( + "negative_crops_coords_top_left", + type_hint=Tuple[int, int], + default=(0, 0), + description="Negative conditioning crop coordinates", + ), + "aesthetic_score": InputParam( + "aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image" + ), + "negative_aesthetic_score": InputParam( + "negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score" + ), + "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"), + "output_type": InputParam( + "output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)" + ), + "ip_adapter_image": InputParam( + "ip_adapter_image", + type_hint=PipelineImageInput, + required=True, + description="Image(s) to be used as IP adapter", + ), + "control_image": InputParam( + "control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition" + ), + "control_guidance_start": InputParam( + "control_guidance_start", + type_hint=Union[float, List[float]], + default=0.0, + description="When ControlNet starts applying", + ), + "control_guidance_end": InputParam( + "control_guidance_end", + type_hint=Union[float, List[float]], + default=1.0, + description="When ControlNet stops applying", + ), + "controlnet_conditioning_scale": InputParam( + "controlnet_conditioning_scale", + type_hint=Union[float, List[float]], + default=1.0, + description="Scale 
factor for ControlNet outputs", + ), + "guess_mode": InputParam( + "guess_mode", + type_hint=bool, + default=False, + description="Enables ControlNet encoder to recognize input without prompts", + ), + "control_mode": InputParam( + "control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet" + ), +} + +SDXL_INTERMEDIATE_INPUTS_SCHEMA = { + "prompt_embeds": InputParam( + "prompt_embeds", + type_hint=torch.Tensor, + required=True, + description="Text embeddings used to guide image generation", + ), + "negative_prompt_embeds": InputParam( + "negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings" + ), + "pooled_prompt_embeds": InputParam( + "pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings" + ), + "negative_pooled_prompt_embeds": InputParam( + "negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings" + ), + "batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"), + "dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + "preprocess_kwargs": InputParam( + "preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor" + ), + "latents": InputParam( + "latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process" + ), + "timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"), + "num_inference_steps": InputParam( + "num_inference_steps", type_hint=int, required=True, description="Number of denoising steps" + ), + "latent_timestep": InputParam( + "latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep" + ), + "image_latents": InputParam( + "image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image" + ), + "mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"), + "masked_image_latents": InputParam( + "masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting" + ), + "add_time_ids": InputParam( + "add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning" + ), + "negative_add_time_ids": InputParam( + "negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids" + ), + "timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), + "noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), + "crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"), + "ip_adapter_embeds": InputParam( + "ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter" + ), + "negative_ip_adapter_embeds": InputParam( + "negative_ip_adapter_embeds", + type_hint=List[torch.Tensor], + description="Negative image embeddings for IP-Adapter", + ), + "images": InputParam( + "images", + type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], + required=True, + description="Generated images", + ), +} + +SDXL_PARAM_SCHEMA = {**SDXL_INPUTS_SCHEMA, **SDXL_INTERMEDIATE_INPUTS_SCHEMA} + + +DEFAULT_PARAM_MAPS = { + "prompt": { + "label": "Prompt", + "type": "string", + "default": "a bear sitting in a chair drinking a milkshake", + "display": 
"textarea", + }, + "negative_prompt": { + "label": "Negative Prompt", + "type": "string", + "default": "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality", + "display": "textarea", + }, + "num_inference_steps": { + "label": "Steps", + "type": "int", + "default": 25, + "min": 1, + "max": 1000, + }, + "seed": { + "label": "Seed", + "type": "int", + "default": 0, + "min": 0, + "display": "random", + }, + "width": { + "label": "Width", + "type": "int", + "display": "text", + "default": 1024, + "min": 8, + "max": 8192, + "step": 8, + "group": "dimensions", + }, + "height": { + "label": "Height", + "type": "int", + "display": "text", + "default": 1024, + "min": 8, + "max": 8192, + "step": 8, + "group": "dimensions", + }, + "images": { + "label": "Images", + "type": "image", + "display": "output", + }, + "image": { + "label": "Image", + "type": "image", + "display": "input", + }, +} + +DEFAULT_TYPE_MAPS = { + "int": { + "type": "int", + "default": 0, + "min": 0, + }, + "float": { + "type": "float", + "default": 0.0, + "min": 0.0, + }, + "str": { + "type": "string", + "default": "", + }, + "bool": { + "type": "boolean", + "default": False, + }, + "image": { + "type": "image", + }, +} + +DEFAULT_MODEL_KEYS = ["unet", "vae", "text_encoder", "tokenizer", "controlnet", "transformer", "image_encoder"] +DEFAULT_CATEGORY = "Modular Diffusers" +DEFAULT_EXCLUDE_MODEL_KEYS = ["processor", "feature_extractor", "safety_checker"] +DEFAULT_PARAMS_GROUPS_KEYS = { + "text_encoders": ["text_encoder", "tokenizer"], + "ip_adapter_embeds": ["ip_adapter_embeds"], + "prompt_embeddings": ["prompt_embeds"], +} + + +def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS): + """ + Get the group name for a given parameter name, if not part of a group, return None e.g. "prompt_embeds" -> + "text_embeds", "text_encoder" -> "text_encoders", "prompt" -> None + """ + if name is None: + return None + for group_name, group_keys in group_params_keys.items(): + for group_key in group_keys: + if group_key in name: + return group_name + return None + + +class ModularNode(ConfigMixin): + """ + A ModularNode is a base class to build UI nodes using diffusers. Currently only supports Mellon. It is a wrapper + around a ModularPipelineBlocks object. + + + + This is an experimental feature and is likely to change in the future. + + + """ + + config_name = "node_config.json" + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + trust_remote_code: Optional[bool] = None, + **kwargs, + ): + blocks = ModularPipelineBlocks.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs + ) + return cls(blocks, **kwargs) + + def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs): + self.blocks = blocks + + if label is None: + label = self.blocks.__class__.__name__ + # blocks param name -> mellon param name + self.name_mapping = {} + + input_params = {} + # pass or create a default param dict for each input + # e.g. for prompt, + # prompt = { + # "name": "text_input", # the name of the input in node defination, could be different from the input name in diffusers + # "label": "Prompt", + # "type": "string", + # "default": "a bear sitting in a chair drinking a milkshake", + # "display": "textarea"} + # if type is not specified, it'll be a "custom" param of its own type + # e.g. 
you can pass ModularNode(scheduler = {name :"scheduler"}) + # it will get this spec in node defination {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}} + # name can be a dict, in that case, it is part of a "dict" input in mellon nodes, e.g. text_encoder= {name: {"text_encoders": "text_encoder"}} + inputs = self.blocks.inputs + self.blocks.intermediate_inputs + for inp in inputs: + param = kwargs.pop(inp.name, None) + if param: + # user can pass a param dict for all inputs, e.g. ModularNode(prompt = {...}) + input_params[inp.name] = param + mellon_name = param.pop("name", inp.name) + if mellon_name != inp.name: + self.name_mapping[inp.name] = mellon_name + continue + + if inp.name not in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name): + continue + + if inp.name in DEFAULT_PARAM_MAPS: + # first check if it's in the default param map, if so, directly use that + param = DEFAULT_PARAM_MAPS[inp.name].copy() + elif get_group_name(inp.name): + param = get_group_name(inp.name) + if inp.name not in self.name_mapping: + self.name_mapping[inp.name] = param + else: + # if not, check if it's in the SDXL input schema, if so, + # 1. use the type hint to determine the type + # 2. use the default param dict for the type e.g. if "steps" is a "int" type, {"steps": {"type": "int", "default": 0, "min": 0}} + if inp.type_hint is not None: + type_str = str(inp.type_hint).lower() + else: + inp_spec = SDXL_PARAM_SCHEMA.get(inp.name, None) + type_str = str(inp_spec.type_hint).lower() if inp_spec else "" + for type_key, type_param in DEFAULT_TYPE_MAPS.items(): + if type_key in type_str: + param = type_param.copy() + param["label"] = inp.name + param["display"] = "input" + break + else: + param = inp.name + # add the param dict to the inp_params dict + input_params[inp.name] = param + + component_params = {} + for comp in self.blocks.expected_components: + param = kwargs.pop(comp.name, None) + if param: + component_params[comp.name] = param + mellon_name = param.pop("name", comp.name) + if mellon_name != comp.name: + self.name_mapping[comp.name] = mellon_name + continue + + to_exclude = False + for exclude_key in DEFAULT_EXCLUDE_MODEL_KEYS: + if exclude_key in comp.name: + to_exclude = True + break + if to_exclude: + continue + + if get_group_name(comp.name): + param = get_group_name(comp.name) + if comp.name not in self.name_mapping: + self.name_mapping[comp.name] = param + elif comp.name in DEFAULT_MODEL_KEYS: + param = {"label": comp.name, "type": "diffusers_auto_model", "display": "input"} + else: + param = comp.name + # add the param dict to the model_params dict + component_params[comp.name] = param + + output_params = {} + if isinstance(self.blocks, SequentialPipelineBlocks): + last_block_name = list(self.blocks.sub_blocks.keys())[-1] + outputs = self.blocks.sub_blocks[last_block_name].intermediate_outputs + else: + outputs = self.blocks.intermediate_outputs + + for out in outputs: + param = kwargs.pop(out.name, None) + if param: + output_params[out.name] = param + mellon_name = param.pop("name", out.name) + if mellon_name != out.name: + self.name_mapping[out.name] = mellon_name + continue + + if out.name in DEFAULT_PARAM_MAPS: + param = DEFAULT_PARAM_MAPS[out.name].copy() + param["display"] = "output" + else: + group_name = get_group_name(out.name) + if group_name: + param = group_name + if out.name not in self.name_mapping: + self.name_mapping[out.name] = param + else: + param = out.name + # add the param dict to the outputs dict + 
output_params[out.name] = param + + if len(kwargs) > 0: + logger.warning(f"Unused kwargs: {kwargs}") + + register_dict = { + "category": category, + "label": label, + "input_params": input_params, + "component_params": component_params, + "output_params": output_params, + "name_mapping": self.name_mapping, + } + self.register_to_config(**register_dict) + + def setup(self, components_manager, collection=None): + self.pipeline = self.blocks.init_pipeline(components_manager=components_manager, collection=collection) + self._components_manager = components_manager + + @property + def mellon_config(self): + return self._convert_to_mellon_config() + + def _convert_to_mellon_config(self): + node = {} + node["label"] = self.config.label + node["category"] = self.config.category + + node_param = {} + for inp_name, inp_param in self.config.input_params.items(): + if inp_name in self.name_mapping: + mellon_name = self.name_mapping[inp_name] + else: + mellon_name = inp_name + if isinstance(inp_param, str): + param = { + "label": inp_param, + "type": inp_param, + "display": "input", + } + else: + param = inp_param + + if mellon_name not in node_param: + node_param[mellon_name] = param + else: + logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}") + + for comp_name, comp_param in self.config.component_params.items(): + if comp_name in self.name_mapping: + mellon_name = self.name_mapping[comp_name] + else: + mellon_name = comp_name + if isinstance(comp_param, str): + param = { + "label": comp_param, + "type": comp_param, + "display": "input", + } + else: + param = comp_param + + if mellon_name not in node_param: + node_param[mellon_name] = param + else: + logger.debug(f"Component param {comp_param} already exists in node_param, skipping {comp_name}") + + for out_name, out_param in self.config.output_params.items(): + if out_name in self.name_mapping: + mellon_name = self.name_mapping[out_name] + else: + mellon_name = out_name + if isinstance(out_param, str): + param = { + "label": out_param, + "type": out_param, + "display": "output", + } + else: + param = out_param + + if mellon_name not in node_param: + node_param[mellon_name] = param + else: + logger.debug(f"Output param {out_param} already exists in node_param, skipping {out_name}") + node["params"] = node_param + return node + + def save_mellon_config(self, file_path): + """ + Save the Mellon configuration to a JSON file. + + Args: + file_path (str or Path): Path where the JSON file will be saved + + Returns: + Path: Path to the saved config file + """ + file_path = Path(file_path) + + # Create directory if it doesn't exist + os.makedirs(file_path.parent, exist_ok=True) + + # Create a combined dictionary with module definition and name mapping + config = {"module": self.mellon_config, "name_mapping": self.name_mapping} + + # Save the config to file + with open(file_path, "w", encoding="utf-8") as f: + json.dump(config, f, indent=2) + + logger.info(f"Mellon config and name mapping saved to {file_path}") + + return file_path + + @classmethod + def load_mellon_config(cls, file_path): + """ + Load a Mellon configuration from a JSON file. 
+ + Args: + file_path (str or Path): Path to the JSON file containing Mellon config + + Returns: + dict: The loaded combined configuration containing 'module' and 'name_mapping' + """ + file_path = Path(file_path) + + if not file_path.exists(): + raise FileNotFoundError(f"Config file not found: {file_path}") + + with open(file_path, "r", encoding="utf-8") as f: + config = json.load(f) + + logger.info(f"Mellon config loaded from {file_path}") + + return config + + def process_inputs(self, **kwargs): + params_components = {} + for comp_name, comp_param in self.config.component_params.items(): + logger.debug(f"component: {comp_name}") + mellon_comp_name = self.name_mapping.get(comp_name, comp_name) + if mellon_comp_name in kwargs: + if isinstance(kwargs[mellon_comp_name], dict) and comp_name in kwargs[mellon_comp_name]: + comp = kwargs[mellon_comp_name].pop(comp_name) + else: + comp = kwargs.pop(mellon_comp_name) + if comp: + params_components[comp_name] = self._components_manager.get_one(comp["model_id"]) + + params_run = {} + for inp_name, inp_param in self.config.input_params.items(): + logger.debug(f"input: {inp_name}") + mellon_inp_name = self.name_mapping.get(inp_name, inp_name) + if mellon_inp_name in kwargs: + if isinstance(kwargs[mellon_inp_name], dict) and inp_name in kwargs[mellon_inp_name]: + inp = kwargs[mellon_inp_name].pop(inp_name) + else: + inp = kwargs.pop(mellon_inp_name) + if inp is not None: + params_run[inp_name] = inp + + return_output_names = list(self.config.output_params.keys()) + + return params_components, params_run, return_output_names + + def execute(self, **kwargs): + params_components, params_run, return_output_names = self.process_inputs(**kwargs) + + self.pipeline.update_components(**params_components) + output = self.pipeline(**params_run, output=return_output_names) + return output diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py new file mode 100644 index 000000000000..59ec46dc6d36 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py @@ -0,0 +1,77 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["encoders"] = ["StableDiffusionXLTextEncoderStep"] + _import_structure["modular_blocks"] = [ + "ALL_BLOCKS", + "AUTO_BLOCKS", + "CONTROLNET_BLOCKS", + "IMAGE2IMAGE_BLOCKS", + "INPAINT_BLOCKS", + "IP_ADAPTER_BLOCKS", + "TEXT2IMAGE_BLOCKS", + "StableDiffusionXLAutoBlocks", + "StableDiffusionXLAutoControlnetStep", + "StableDiffusionXLAutoDecodeStep", + "StableDiffusionXLAutoIPAdapterStep", + "StableDiffusionXLAutoVaeEncoderStep", + ] + _import_structure["modular_pipeline"] = ["StableDiffusionXLModularPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from 
.encoders import ( + StableDiffusionXLTextEncoderStep, + ) + from .modular_blocks import ( + ALL_BLOCKS, + AUTO_BLOCKS, + CONTROLNET_BLOCKS, + IMAGE2IMAGE_BLOCKS, + INPAINT_BLOCKS, + IP_ADAPTER_BLOCKS, + TEXT2IMAGE_BLOCKS, + StableDiffusionXLAutoBlocks, + StableDiffusionXLAutoControlnetStep, + StableDiffusionXLAutoDecodeStep, + StableDiffusionXLAutoIPAdapterStep, + StableDiffusionXLAutoVaeEncoderStep, + ) + from .modular_pipeline import StableDiffusionXLModularPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py new file mode 100644 index 000000000000..c56f4af1b8a5 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py @@ -0,0 +1,1929 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, List, Optional, Tuple, Union + +import PIL +import torch + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, UNet2DConditionModel +from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel +from ...schedulers import EulerDiscreteScheduler +from ...utils import logging +from ...utils.torch_utils import randn_tensor, unwrap_module +from ..modular_pipeline import ( + PipelineBlock, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from .modular_pipeline import StableDiffusionXLModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that +# things like when to do guidance and how many conditions to be prepared can be determined. Currently, this is done by +# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the +# configuration of guider is. + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. 
+ num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def prepare_latents_img2img( + vae, scheduler, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True +): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError(f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}") + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + latents_mean = latents_std = None + if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None: + latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None: + latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1) + # make sure the VAE is in float32 mode, as it overflows in float16 + if vae.config.force_upcast: + image = image.float() + vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(vae.encode(image), generator=generator) + + if vae.config.force_upcast: + vae.to(dtype) + + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + init_latents = (init_latents - latents_mean) * vae.config.scaling_factor / latents_std + else: + init_latents = vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + +class StableDiffusionXLInputStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "Input processing step that:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n" + "All input tensors are expected to have either batch_size=1 or match the batch_size\n" + "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" + "have a final batch_size of batch_size * num_images_per_prompt." + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediate_inputs(self) -> List[str]: + return [ + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Pre-generated text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + description="Pre-generated negative text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Pre-generated pooled text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "negative_pooled_prompt_embeds", + description="Pre-generated negative pooled text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "ip_adapter_embeds", + type_hint=List[torch.Tensor], + description="Pre-generated image embeddings for IP-Adapter. 
Can be generated from ip_adapter step.",
+            ),
+            InputParam(
+                "negative_ip_adapter_embeds",
+                type_hint=List[torch.Tensor],
+                description="Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step.",
+            ),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                "batch_size",
+                type_hint=int,
+                description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
+            ),
+            OutputParam(
+                "dtype",
+                type_hint=torch.dtype,
+                description="Data type of model tensor inputs (determined by `prompt_embeds`)",
+            ),
+            OutputParam(
+                "prompt_embeds",
+                type_hint=torch.Tensor,
+                kwargs_type="guider_input_fields",  # already in the intermediate state but declared here again for guider_input_fields
+                description="text embeddings used to guide the image generation",
+            ),
+            OutputParam(
+                "negative_prompt_embeds",
+                type_hint=torch.Tensor,
+                kwargs_type="guider_input_fields",  # already in the intermediate state but declared here again for guider_input_fields
+                description="negative text embeddings used to guide the image generation",
+            ),
+            OutputParam(
+                "pooled_prompt_embeds",
+                type_hint=torch.Tensor,
+                kwargs_type="guider_input_fields",  # already in the intermediate state but declared here again for guider_input_fields
+                description="pooled text embeddings used to guide the image generation",
+            ),
+            OutputParam(
+                "negative_pooled_prompt_embeds",
+                type_hint=torch.Tensor,
+                kwargs_type="guider_input_fields",  # already in the intermediate state but declared here again for guider_input_fields
+                description="negative pooled text embeddings used to guide the image generation",
+            ),
+            OutputParam(
+                "ip_adapter_embeds",
+                type_hint=List[torch.Tensor],
+                kwargs_type="guider_input_fields",  # already in the intermediate state but declared here again for guider_input_fields
+                description="image embeddings for IP-Adapter",
+            ),
+            OutputParam(
+                "negative_ip_adapter_embeds",
+                type_hint=List[torch.Tensor],
+                kwargs_type="guider_input_fields",  # already in the intermediate state but declared here again for guider_input_fields
+                description="negative image embeddings for IP-Adapter",
+            ),
+        ]
+
+    def check_inputs(self, components, block_state):
+        if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None:
+            if block_state.prompt_embeds.shape != block_state.negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {block_state.negative_prompt_embeds.shape}."
+                )
+
+        if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+            )
+
+        if block_state.negative_prompt_embeds is not None and block_state.negative_pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ ) + + if block_state.ip_adapter_embeds is not None and not isinstance(block_state.ip_adapter_embeds, list): + raise ValueError("`ip_adapter_embeds` must be a list") + + if block_state.negative_ip_adapter_embeds is not None and not isinstance( + block_state.negative_ip_adapter_embeds, list + ): + raise ValueError("`negative_ip_adapter_embeds` must be a list") + + if block_state.ip_adapter_embeds is not None and block_state.negative_ip_adapter_embeds is not None: + for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds): + if ip_adapter_embed.shape != block_state.negative_ip_adapter_embeds[i].shape: + raise ValueError( + "`ip_adapter_embeds` and `negative_ip_adapter_embeds` must have the same shape when passed directly, but" + f" got: `ip_adapter_embeds` {ip_adapter_embed.shape} != `negative_ip_adapter_embeds`" + f" {block_state.negative_ip_adapter_embeds[i].shape}." + ) + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + + if block_state.negative_prompt_embeds is not None: + _, seq_len, _ = block_state.negative_prompt_embeds.shape + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat( + 1, block_state.num_images_per_prompt, 1 + ) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + + block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat( + 1, block_state.num_images_per_prompt, 1 + ) + block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, -1 + ) + + if block_state.negative_pooled_prompt_embeds is not None: + block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.repeat( + 1, block_state.num_images_per_prompt, 1 + ) + block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, -1 + ) + + if block_state.ip_adapter_embeds is not None: + for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds): + block_state.ip_adapter_embeds[i] = torch.cat( + [ip_adapter_embed] * block_state.num_images_per_prompt, dim=0 + ) + + if block_state.negative_ip_adapter_embeds is not None: + for i, negative_ip_adapter_embed in enumerate(block_state.negative_ip_adapter_embeds): + block_state.negative_ip_adapter_embeds[i] = torch.cat( + [negative_ip_adapter_embed] * block_state.num_images_per_prompt, dim=0 + ) + + self.set_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "Step 
that sets the timesteps for the scheduler and determines the initial noise level (latent_timestep) for image-to-image/inpainting generation.\n" + + "The latent_timestep is calculated from the `strength` parameter - higher strength means starting from a noisier version of the input image." + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("denoising_end"), + InputParam("strength", default=0.3), + InputParam("denoising_start"), + # YiYi TODO: do we need num_images_per_prompt here? + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediate_inputs(self) -> List[str]: + return [ + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", + ), + ] + + @property + def intermediate_outputs(self) -> List[str]: + return [ + OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam( + "num_inference_steps", + type_hint=int, + description="The number of denoising steps to perform at inference time", + ), + OutputParam( + "latent_timestep", + type_hint=torch.Tensor, + description="The timestep that represents the initial noise level for image-to-image generation", + ), + ] + + @staticmethod + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps with self->components + def get_timesteps(components, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + + timesteps = components.scheduler.timesteps[t_start * components.scheduler.order :] + if hasattr(components.scheduler, "set_begin_index"): + components.scheduler.set_begin_index(t_start * components.scheduler.order) + + return timesteps, num_inference_steps - t_start + + else: + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + discrete_timestep_cutoff = int( + round( + components.scheduler.config.num_train_timesteps + - (denoising_start * components.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (components.scheduler.timesteps < discrete_timestep_cutoff).sum().item() + if components.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. 
By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + t_start = len(components.scheduler.timesteps) - num_inference_steps + timesteps = components.scheduler.timesteps[t_start:] + if hasattr(components.scheduler, "set_begin_index"): + components.scheduler.set_begin_index(t_start) + return timesteps, num_inference_steps + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.device = components._execution_device + + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, + block_state.num_inference_steps, + block_state.device, + block_state.timesteps, + block_state.sigmas, + ) + + def denoising_value_valid(dnv): + return isinstance(dnv, float) and 0 < dnv < 1 + + block_state.timesteps, block_state.num_inference_steps = self.get_timesteps( + components, + block_state.num_inference_steps, + block_state.strength, + block_state.device, + denoising_start=block_state.denoising_start + if denoising_value_valid(block_state.denoising_start) + else None, + ) + block_state.latent_timestep = block_state.timesteps[:1].repeat( + block_state.batch_size * block_state.num_images_per_prompt + ) + + if ( + block_state.denoising_end is not None + and isinstance(block_state.denoising_end, float) + and block_state.denoising_end > 0 + and block_state.denoising_end < 1 + ): + block_state.discrete_timestep_cutoff = int( + round( + components.scheduler.config.num_train_timesteps + - (block_state.denoising_end * components.scheduler.config.num_train_timesteps) + ) + ) + block_state.num_inference_steps = len( + list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps)) + ) + block_state.timesteps = block_state.timesteps[: block_state.num_inference_steps] + + self.set_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLSetTimestepsStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for inference" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("denoising_end"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam( + "num_inference_steps", + type_hint=int, + description="The number of denoising steps to perform at inference time", + ), + ] + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.device = components._execution_device + + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, + block_state.num_inference_steps, + block_state.device, + block_state.timesteps, + block_state.sigmas, + ) + + if ( + block_state.denoising_end is not None + and isinstance(block_state.denoising_end, float) + and 
block_state.denoising_end > 0 + and block_state.denoising_end < 1 + ): + block_state.discrete_timestep_cutoff = int( + round( + components.scheduler.config.num_train_timesteps + - (block_state.denoising_end * components.scheduler.config.num_train_timesteps) + ) + ) + block_state.num_inference_steps = len( + list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps)) + ) + block_state.timesteps = block_state.timesteps[: block_state.num_inference_steps] + + self.set_block_state(state, block_state) + return components, state + + +class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "Step that prepares the latents for the inpainting process" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("latents"), + InputParam("num_images_per_prompt", default=1), + InputParam("denoising_start"), + InputParam( + "strength", + default=0.9999, + description="Conceptually, indicates how much to transform the reference `image` (the masked portion of image for inpainting). Must be between 0 and 1. `image` " + "will be used as a starting point, adding more noise to it the larger the `strength`. The number of " + "denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will " + "be maximum and the denoising process will run for the full number of iterations specified in " + "`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of " + "`denoising_start` being declared as an integer, the value of `strength` will be ignored.", + ), + ] + + @property + def intermediate_inputs(self) -> List[str]: + return [ + InputParam("generator"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.", + ), + InputParam( + "latent_timestep", + required=True, + type_hint=torch.Tensor, + description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.", + ), + InputParam( + "image_latents", + required=True, + type_hint=torch.Tensor, + description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.", + ), + InputParam( + "mask", + required=True, + type_hint=torch.Tensor, + description="The mask for the inpainting generation. Can be generated in vae_encode step.", + ), + InputParam( + "masked_image_latents", + type_hint=torch.Tensor, + description="The masked image latents for the inpainting generation (only for inpainting-specific unet). 
Can be generated in vae_encode step.", + ), + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + ] + + @property + def intermediate_outputs(self) -> List[str]: + return [ + OutputParam( + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" + ), + OutputParam( + "noise", + type_hint=torch.Tensor, + description="The noise added to the image latents, used for inpainting generation", + ), + ] + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self->components + # YiYi TODO: update the _encode_vae_image so that we can use #Coped from + @staticmethod + def _encode_vae_image(components, image: torch.Tensor, generator: torch.Generator): + latents_mean = latents_std = None + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) + + dtype = image.dtype + if components.vae.config.force_upcast: + image = image.float() + components.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(components.vae.encode(image), generator=generator) + + if components.vae.config.force_upcast: + components.vae.to(dtype) + + image_latents = image_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) + latents_std = latents_std.to(device=image_latents.device, dtype=dtype) + image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std + else: + image_latents = components.vae.config.scaling_factor * image_latents + + return image_latents + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents adding components as first argument + def prepare_latents_inpaint( + self, + components, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + add_noise=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // components.vae_scale_factor, + int(width) // components.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." 
+ ) + + if image.shape[1] == 4: + image_latents = image.to(device=device, dtype=dtype) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + elif return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(components, image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None and add_noise: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else components.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * components.scheduler.init_noise_sigma if is_strength_max else latents + elif add_noise: + noise = latents.to(device) + latents = noise * components.scheduler.init_noise_sigma + else: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = image_latents.to(device) + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents + # do not accept do_classifier_free_guidance + def prepare_mask_latents( + self, components, mask, masked_image, batch_size, height, width, dtype, device, generator + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + + if masked_image is not None and masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = None + + if masked_image is not None: + if masked_image_latents is None: + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator) + + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." 
+ ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + block_state.device = components._execution_device + + block_state.is_strength_max = block_state.strength == 1.0 + + # for non-inpainting specific unet, we do not need masked_image_latents + if hasattr(components, "unet") and components.unet is not None: + if components.unet.config.in_channels == 4: + block_state.masked_image_latents = None + + block_state.add_noise = True if block_state.denoising_start is None else False + + block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor + block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor + + block_state.latents, block_state.noise = self.prepare_latents_inpaint( + components, + block_state.batch_size * block_state.num_images_per_prompt, + components.num_channels_latents, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.latents, + image=block_state.image_latents, + timestep=block_state.latent_timestep, + is_strength_max=block_state.is_strength_max, + add_noise=block_state.add_noise, + return_noise=True, + return_image_latents=False, + ) + + # 7. Prepare mask latent variables + block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents( + components, + block_state.mask, + block_state.masked_image_latents, + block_state.batch_size * block_state.num_images_per_prompt, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + ) + + self.set_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "Step that prepares the latents for the image-to-image generation process" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("latents"), + InputParam("num_images_per_prompt", default=1), + InputParam("denoising_start"), + ] + + @property + def intermediate_inputs(self) -> List[InputParam]: + return [ + InputParam("generator"), + InputParam( + "latent_timestep", + required=True, + type_hint=torch.Tensor, + description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.", + ), + InputParam( + "image_latents", + required=True, + type_hint=torch.Tensor, + description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.", + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. 
Can be generated in input step.", + ), + InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" + ) + ] + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + block_state.device = components._execution_device + block_state.add_noise = True if block_state.denoising_start is None else False + if block_state.latents is None: + block_state.latents = prepare_latents_img2img( + components.vae, + components.scheduler, + block_state.image_latents, + block_state.latent_timestep, + block_state.batch_size, + block_state.num_images_per_prompt, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.add_noise, + ) + + self.set_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLPrepareLatentsStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("vae", AutoencoderKL), + ] + + @property + def description(self) -> str: + return "Prepare latents step that prepares the latents for the text-to-image generation process" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("height"), + InputParam("width"), + InputParam("latents"), + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediate_inputs(self) -> List[InputParam]: + return [ + InputParam("generator"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.", + ), + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" + ) + ] + + @staticmethod + def check_inputs(components, block_state): + if ( + block_state.height is not None + and block_state.height % components.vae_scale_factor != 0 + or block_state.width is not None + and block_state.width % components.vae_scale_factor != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {block_state.height} and {block_state.width}." + ) + + @staticmethod + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with self->comp + def prepare_latents(comp, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // comp.vae_scale_factor, + int(width) // comp.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * comp.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if block_state.dtype is None: + block_state.dtype = components.vae.dtype + + block_state.device = components._execution_device + + self.check_inputs(components, block_state) + + block_state.height = block_state.height or components.default_sample_size * components.vae_scale_factor + block_state.width = block_state.width or components.default_sample_size * components.vae_scale_factor + block_state.num_channels_latents = components.num_channels_latents + block_state.latents = self.prepare_latents( + components, + block_state.batch_size * block_state.num_images_per_prompt, + block_state.num_channels_latents, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.latents, + ) + + self.set_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [ + ConfigSpec("requires_aesthetics_score", False), + ] + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "Step that prepares the additional conditioning for the image-to-image/inpainting generation process" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("original_size"), + InputParam("target_size"), + InputParam("negative_original_size"), + InputParam("negative_target_size"), + InputParam("crops_coords_top_left", default=(0, 0)), + InputParam("negative_crops_coords_top_left", default=(0, 0)), + InputParam("num_images_per_prompt", default=1), + InputParam("aesthetic_score", default=6.0), + InputParam("negative_aesthetic_score", default=2.0), + ] + + @property + def intermediate_inputs(self) -> List[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step.", + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. 
Can be generated in input step.", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "add_time_ids", + type_hint=torch.Tensor, + kwargs_type="guider_input_fields", + description="The time ids to condition the denoising process", + ), + OutputParam( + "negative_add_time_ids", + type_hint=torch.Tensor, + kwargs_type="guider_input_fields", + description="The negative time ids to condition the denoising process", + ), + OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"), + ] + + @staticmethod + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self->components + def _get_add_time_ids( + components, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if components.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + passed_add_embed_dim = ( + components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == components.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == components.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
+ ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + block_state.vae_scale_factor = components.vae_scale_factor + + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * block_state.vae_scale_factor + block_state.width = block_state.width * block_state.vae_scale_factor + + block_state.original_size = block_state.original_size or (block_state.height, block_state.width) + block_state.target_size = block_state.target_size or (block_state.height, block_state.width) + + block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) + + if block_state.negative_original_size is None: + block_state.negative_original_size = block_state.original_size + if block_state.negative_target_size is None: + block_state.negative_target_size = block_state.target_size + + block_state.add_time_ids, block_state.negative_add_time_ids = self._get_add_time_ids( + components, + block_state.original_size, + block_state.crops_coords_top_left, + block_state.target_size, + block_state.aesthetic_score, + block_state.negative_aesthetic_score, + block_state.negative_original_size, + block_state.negative_crops_coords_top_left, + block_state.negative_target_size, + dtype=block_state.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=block_state.text_encoder_projection_dim, + ) + block_state.add_time_ids = block_state.add_time_ids.repeat( + block_state.batch_size * block_state.num_images_per_prompt, 1 + ).to(device=block_state.device) + block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat( + block_state.batch_size * block_state.num_images_per_prompt, 1 + ).to(device=block_state.device) + + # Optionally get Guidance Scale Embedding for LCM + block_state.timestep_cond = None + if ( + hasattr(components, "unet") + and components.unet is not None + and components.unet.config.time_cond_proj_dim 
is not None + ): + # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! + block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat( + block_state.batch_size * block_state.num_images_per_prompt + ) + block_state.timestep_cond = self.get_guidance_scale_embedding( + block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim + ).to(device=block_state.device, dtype=block_state.latents.dtype) + + self.set_block_state(state, block_state) + return components, state + + +class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return "Step that prepares the additional conditioning for the text-to-image generation process" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("original_size"), + InputParam("target_size"), + InputParam("negative_original_size"), + InputParam("negative_target_size"), + InputParam("crops_coords_top_left", default=(0, 0)), + InputParam("negative_crops_coords_top_left", default=(0, 0)), + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediate_inputs(self) -> List[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step.", + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. 
Can be generated in input step.", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "add_time_ids", + type_hint=torch.Tensor, + kwargs_type="guider_input_fields", + description="The time ids to condition the denoising process", + ), + OutputParam( + "negative_add_time_ids", + type_hint=torch.Tensor, + kwargs_type="guider_input_fields", + description="The negative time ids to condition the denoising process", + ), + OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"), + ] + + @staticmethod + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with self->components + def _get_add_time_ids( + components, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. 
+ """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor + + block_state.original_size = block_state.original_size or (block_state.height, block_state.width) + block_state.target_size = block_state.target_size or (block_state.height, block_state.width) + + block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) + + block_state.add_time_ids = self._get_add_time_ids( + components, + block_state.original_size, + block_state.crops_coords_top_left, + block_state.target_size, + block_state.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=block_state.text_encoder_projection_dim, + ) + if block_state.negative_original_size is not None and block_state.negative_target_size is not None: + block_state.negative_add_time_ids = self._get_add_time_ids( + components, + block_state.negative_original_size, + block_state.negative_crops_coords_top_left, + block_state.negative_target_size, + block_state.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=block_state.text_encoder_projection_dim, + ) + else: + block_state.negative_add_time_ids = block_state.add_time_ids + + block_state.add_time_ids = block_state.add_time_ids.repeat( + block_state.batch_size * block_state.num_images_per_prompt, 1 + ).to(device=block_state.device) + block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat( + block_state.batch_size * block_state.num_images_per_prompt, 1 + ).to(device=block_state.device) + + # Optionally get Guidance Scale Embedding for LCM + block_state.timestep_cond = None + if ( + hasattr(components, "unet") + and components.unet is not None + and components.unet.config.time_cond_proj_dim is not None + ): + # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! 
+ block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat( + block_state.batch_size * block_state.num_images_per_prompt + ) + block_state.timestep_cond = self.get_guidance_scale_embedding( + block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim + ).to(device=block_state.device, dtype=block_state.latents.dtype) + + self.set_block_state(state, block_state) + return components, state + + +class StableDiffusionXLControlNetInputStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("controlnet", ControlNetModel), + ComponentSpec( + "control_image_processor", + VaeImageProcessor, + config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "step that prepare inputs for controlnet" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("control_image", required=True), + InputParam("control_guidance_start", default=0.0), + InputParam("control_guidance_end", default=1.0), + InputParam("controlnet_conditioning_scale", default=1.0), + InputParam("guess_mode", default=False), + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediate_inputs(self) -> List[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.", + ), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam( + "crops_coords", + type_hint=Optional[Tuple[int]], + description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam("controlnet_cond", type_hint=torch.Tensor, description="The processed control image"), + OutputParam( + "control_guidance_start", type_hint=List[float], description="The controlnet guidance start values" + ), + OutputParam( + "control_guidance_end", type_hint=List[float], description="The controlnet guidance end values" + ), + OutputParam( + "conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values" + ), + OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), + OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), + ] + + # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + # 1. return image without apply any guidance + # 2. 
add crops_coords and resize_mode to preprocess() + @staticmethod + def prepare_control_image( + components, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + crops_coords=None, + ): + if crops_coords is not None: + image = components.control_image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill" + ).to(dtype=torch.float32) + else: + image = components.control_image_processor.preprocess(image, height=height, width=width).to( + dtype=torch.float32 + ) + + image_batch_size = image.shape[0] + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + image = image.to(device=device, dtype=dtype) + return image + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # (1) prepare controlnet inputs + block_state.device = components._execution_device + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor + + controlnet = unwrap_module(components.controlnet) + + # (1.1) + # control_guidance_start/control_guidance_end (align format) + if not isinstance(block_state.control_guidance_start, list) and isinstance( + block_state.control_guidance_end, list + ): + block_state.control_guidance_start = len(block_state.control_guidance_end) * [ + block_state.control_guidance_start + ] + elif not isinstance(block_state.control_guidance_end, list) and isinstance( + block_state.control_guidance_start, list + ): + block_state.control_guidance_end = len(block_state.control_guidance_start) * [ + block_state.control_guidance_end + ] + elif not isinstance(block_state.control_guidance_start, list) and not isinstance( + block_state.control_guidance_end, list + ): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + block_state.control_guidance_start, block_state.control_guidance_end = ( + mult * [block_state.control_guidance_start], + mult * [block_state.control_guidance_end], + ) + + # (1.2) + # controlnet_conditioning_scale (align format) + if isinstance(controlnet, MultiControlNetModel) and isinstance( + block_state.controlnet_conditioning_scale, float + ): + block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * len( + controlnet.nets + ) + + # (1.3) + # global_pool_conditions + block_state.global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + # (1.4) + # guess_mode + block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions + + # (1.5) + # control_image + if isinstance(controlnet, ControlNetModel): + block_state.control_image = self.prepare_control_image( + components, + image=block_state.control_image, + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=block_state.device, + dtype=controlnet.dtype, + crops_coords=block_state.crops_coords, + ) + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in 
block_state.control_image: + control_image = self.prepare_control_image( + components, + image=control_image_, + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=block_state.device, + dtype=controlnet.dtype, + crops_coords=block_state.crops_coords, + ) + + control_images.append(control_image) + + block_state.control_image = control_images + else: + assert False + + # (1.6) + # controlnet_keep + block_state.controlnet_keep = [] + for i in range(len(block_state.timesteps)): + keeps = [ + 1.0 - float(i / len(block_state.timesteps) < s or (i + 1) / len(block_state.timesteps) > e) + for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end) + ] + block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + block_state.controlnet_cond = block_state.control_image + block_state.conditioning_scale = block_state.controlnet_conditioning_scale + + self.set_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("controlnet", ControlNetUnionModel), + ComponentSpec( + "control_image_processor", + VaeImageProcessor, + config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "step that prepares inputs for the ControlNetUnion model" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("control_image", required=True), + InputParam("control_mode", required=True), + InputParam("control_guidance_start", default=0.0), + InputParam("control_guidance_end", default=1.0), + InputParam("controlnet_conditioning_scale", default=1.0), + InputParam("guess_mode", default=False), + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediate_inputs(self) -> List[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Used to determine the shape of the control images. Can be generated in prepare_latent step.", + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.", + ), + InputParam( + "dtype", + required=True, + type_hint=torch.dtype, + description="The dtype of model tensor inputs. Can be generated in input step.", + ), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Needed to determine `controlnet_keep`. Can be generated in set_timesteps step.", + ), + InputParam( + "crops_coords", + type_hint=Optional[Tuple[int]], + description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. 
Can be generated in vae_encode step.", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam("controlnet_cond", type_hint=List[torch.Tensor], description="The processed control images"), + OutputParam( + "control_type_idx", + type_hint=List[int], + description="The control mode indices", + kwargs_type="controlnet_kwargs", + ), + OutputParam( + "control_type", + type_hint=torch.Tensor, + description="The control type tensor that specifies which control type is active", + kwargs_type="controlnet_kwargs", + ), + OutputParam("control_guidance_start", type_hint=float, description="The controlnet guidance start value"), + OutputParam("control_guidance_end", type_hint=float, description="The controlnet guidance end value"), + OutputParam( + "conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values" + ), + OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), + OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), + ] + + # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + # 1. return image without apply any guidance + # 2. add crops_coords and resize_mode to preprocess() + @staticmethod + def prepare_control_image( + components, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + crops_coords=None, + ): + if crops_coords is not None: + image = components.control_image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill" + ).to(dtype=torch.float32) + else: + image = components.control_image_processor.preprocess(image, height=height, width=width).to( + dtype=torch.float32 + ) + + image_batch_size = image.shape[0] + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + image = image.to(device=device, dtype=dtype) + return image + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + controlnet = unwrap_module(components.controlnet) + + device = components._execution_device + dtype = block_state.dtype or components.controlnet.dtype + + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor + + # control_guidance_start/control_guidance_end (align format) + if not isinstance(block_state.control_guidance_start, list) and isinstance( + block_state.control_guidance_end, list + ): + block_state.control_guidance_start = len(block_state.control_guidance_end) * [ + block_state.control_guidance_start + ] + elif not isinstance(block_state.control_guidance_end, list) and isinstance( + block_state.control_guidance_start, list + ): + block_state.control_guidance_end = len(block_state.control_guidance_start) * [ + block_state.control_guidance_end + ] + + # guess_mode + block_state.global_pool_conditions = controlnet.config.global_pool_conditions + block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions + + # control_image + if not isinstance(block_state.control_image, list): + block_state.control_image = [block_state.control_image] + # 
control_mode + if not isinstance(block_state.control_mode, list): + block_state.control_mode = [block_state.control_mode] + + if len(block_state.control_image) != len(block_state.control_mode): + raise ValueError("Expected len(control_image) == len(control_type)") + + # control_type + block_state.num_control_type = controlnet.config.num_control_type + block_state.control_type = [0 for _ in range(block_state.num_control_type)] + for control_idx in block_state.control_mode: + block_state.control_type[control_idx] = 1 + block_state.control_type = torch.Tensor(block_state.control_type) + + block_state.control_type = block_state.control_type.reshape(1, -1).to(device, dtype=block_state.dtype) + repeat_by = block_state.batch_size * block_state.num_images_per_prompt // block_state.control_type.shape[0] + block_state.control_type = block_state.control_type.repeat_interleave(repeat_by, dim=0) + + # prepare control_image + for idx, _ in enumerate(block_state.control_image): + block_state.control_image[idx] = self.prepare_control_image( + components, + image=block_state.control_image[idx], + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=device, + dtype=dtype, + crops_coords=block_state.crops_coords, + ) + block_state.height, block_state.width = block_state.control_image[idx].shape[-2:] + + # controlnet_keep + block_state.controlnet_keep = [] + for i in range(len(block_state.timesteps)): + block_state.controlnet_keep.append( + 1.0 + - float( + i / len(block_state.timesteps) < block_state.control_guidance_start + or (i + 1) / len(block_state.timesteps) > block_state.control_guidance_end + ) + ) + block_state.control_type_idx = block_state.control_mode + block_state.controlnet_cond = block_state.control_image + block_state.conditioning_scale = block_state.controlnet_conditioning_scale + + self.set_block_state(state, block_state) + + return components, state diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py new file mode 100644 index 000000000000..e9f627636e8c --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py @@ -0,0 +1,218 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
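+
+# This module defines the post-denoising blocks of the modular SDXL pipeline:
+# `StableDiffusionXLDecodeStep` decodes the denoised latents into images with the VAE
+# (upcasting the VAE to float32 when `force_upcast` is set), and
+# `StableDiffusionXLInpaintOverlayMaskStep` applies the mask overlay to the generated
+# images when `padding_mask_crop` was used during preprocessing.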
+ +from typing import Any, List, Tuple, Union + +import numpy as np +import PIL +import torch + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL +from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor +from ...utils import logging +from ..modular_pipeline import ( + PipelineBlock, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableDiffusionXLDecodeStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "Step that decodes the denoised latents into images" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("output_type", default="pil"), + ] + + @property + def intermediate_inputs(self) -> List[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents from the denoising step", + ) + ] + + @property + def intermediate_outputs(self) -> List[str]: + return [ + OutputParam( + "images", + type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], + description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array", + ) + ] + + @staticmethod + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self->components + def upcast_vae(components): + dtype = components.vae.dtype + components.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + components.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + components.vae.post_quant_conv.to(dtype) + components.vae.decoder.conv_in.to(dtype) + components.vae.decoder.mid_block.to(dtype) + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if not block_state.output_type == "latent": + latents = block_state.latents + # make sure the VAE is in float32 mode, as it overflows in float16 + block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast + + if block_state.needs_upcasting: + self.upcast_vae(components) + latents = latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != components.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + components.vae = components.vae.to(latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + block_state.has_latents_mean = ( + hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None + ) + block_state.has_latents_std = ( + hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None + ) + if block_state.has_latents_mean and block_state.has_latents_std: + block_state.latents_mean = ( + torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + block_state.latents_std = ( + torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = ( + latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean + ) + else: + latents = latents / components.vae.config.scaling_factor + + block_state.images = components.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if block_state.needs_upcasting: + components.vae.to(dtype=torch.float16) + else: + block_state.images = block_state.latents + + # apply watermark if available + if hasattr(components, "watermark") and components.watermark is not None: + block_state.images = components.watermark.apply_watermark(block_state.images) + + block_state.images = components.image_processor.postprocess( + block_state.images, output_type=block_state.output_type + ) + + self.set_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "A post-processing step that overlays the mask on the image (inpainting task only).\n" + + "only needed when you are using the `padding_mask_crop` option when pre-processing the image and mask" + ) + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("image"), + InputParam("mask_image"), + InputParam("padding_mask_crop"), + ] + + @property + def intermediate_inputs(self) -> List[str]: + return [ + InputParam( + "images", + type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], + description="The generated images from the decode step", + ), + InputParam( + "crops_coords", + type_hint=Tuple[int, int], + description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. 
Can be generated in vae_encode step.", + ), + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if block_state.padding_mask_crop is not None and block_state.crops_coords is not None: + block_state.images = [ + components.image_processor.apply_overlay( + block_state.mask_image, block_state.image, i, block_state.crops_coords + ) + for i in block_state.images + ] + + self.set_block_state(state, block_state) + + return components, state diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py new file mode 100644 index 000000000000..7fe4a472eec3 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py @@ -0,0 +1,791 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, List, Optional, Tuple + +import torch + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...models import ControlNetModel, UNet2DConditionModel +from ...schedulers import EulerDiscreteScheduler +from ...utils import logging +from ..modular_pipeline import ( + BlockState, + LoopSequentialPipelineBlocks, + PipelineBlock, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import StableDiffusionXLModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# YiYi experimenting composible denoise loop +# loop step (1): prepare latent input for denoiser +class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "step within the denoising loop that prepare the latent input for the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" + ) + + @property + def intermediate_inputs(self) -> List[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. 
Can be generated in prepare_latent step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int): + block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + + return components, block_state + + +# loop step (1): prepare latent input for denoiser (with inpainting) +class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def description(self) -> str: + return ( + "step within the denoising loop that prepare the latent input for the denoiser (for inpainting workflow only). " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object" + ) + + @property + def intermediate_inputs(self) -> List[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "mask", + type_hint=Optional[torch.Tensor], + description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step.", + ), + InputParam( + "masked_image_latents", + type_hint=Optional[torch.Tensor], + description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step.", + ), + ] + + @staticmethod + def check_inputs(components, block_state): + num_channels_unet = components.num_channels_unet + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + if block_state.mask is None or block_state.masked_image_latents is None: + raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") + num_channels_latents = block_state.latents.shape[1] + num_channels_mask = block_state.mask.shape[1] + num_channels_masked_image = block_state.masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: + raise ValueError( + f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" + f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" + " `components.unet` or your `mask_image` or `image` input." 
+ ) + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int): + self.check_inputs(components, block_state) + + block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + if components.num_channels_unet == 9: + block_state.scaled_latents = torch.cat( + [block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1 + ) + + return components, block_state + + +# loop step (2): denoise the latents with guidance +class StableDiffusionXLLoopDenoiser(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config", + ), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoise the latents with guidance. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("cross_attention_kwargs"), + ] + + @property + def intermediate_inputs(self) -> List[str]: + return [ + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam( + "timestep_cond", + type_hint=Optional[torch.Tensor], + description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step.", + ), + InputParam( + kwargs_type="guider_input_fields", + description=( + "All conditional model inputs that need to be prepared with guider. " + "It should contain prompt_embeds/negative_prompt_embeds, " + "add_time_ids/negative_add_time_ids, " + "pooled_prompt_embeds/negative_pooled_prompt_embeds, " + "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)." + "please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ), + ), + ] + + @torch.no_grad() + def __call__( + self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int + ) -> PipelineState: + # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) + # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) + guider_input_fields = { + "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"), + "time_ids": ("add_time_ids", "negative_add_time_ids"), + "text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), + "image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"), + } + + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # Prepare mini‐batches according to guidance method and `guider_input_fields` + # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds. + # e.g. 
for CFG, we prepare two batches: one for uncond, one for cond + # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds + # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds + guider_state = components.guider.prepare_inputs(block_state, guider_input_fields) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + components.guider.prepare_models(components.unet) + cond_kwargs = guider_state_batch.as_dict() + cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields} + prompt_embeds = cond_kwargs.pop("prompt_embeds") + + # Predict the noise residual + # store the noise_pred in guider_state_batch so that we can apply guidance across all batches + guider_state_batch.noise_pred = components.unet( + block_state.scaled_latents, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=block_state.timestep_cond, + cross_attention_kwargs=block_state.cross_attention_kwargs, + added_cond_kwargs=cond_kwargs, + return_dict=False, + )[0] + components.guider.cleanup_models(components.unet) + + # Perform guidance + block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state) + + return components, block_state + + +# loop step (2): denoise the latents with guidance (with controlnet) +class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config", + ), + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec("controlnet", ControlNetModel), + ] + + @property + def description(self) -> str: + return ( + "step within the denoising loop that denoise the latents with guidance (with controlnet). " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("cross_attention_kwargs"), + ] + + @property + def intermediate_inputs(self) -> List[str]: + return [ + InputParam( + "controlnet_cond", + required=True, + type_hint=torch.Tensor, + description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.", + ), + InputParam( + "conditioning_scale", + type_hint=float, + description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step.", + ), + InputParam( + "guess_mode", + required=True, + type_hint=bool, + description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step.", + ), + InputParam( + "controlnet_keep", + required=True, + type_hint=List[float], + description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step.", + ), + InputParam( + "timestep_cond", + type_hint=Optional[torch.Tensor], + description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step", + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. 
Can be generated in set_timesteps step.", + ), + InputParam( + kwargs_type="guider_input_fields", + description=( + "All conditional model inputs that need to be prepared with guider. " + "It should contain prompt_embeds/negative_prompt_embeds, " + "add_time_ids/negative_add_time_ids, " + "pooled_prompt_embeds/negative_pooled_prompt_embeds, " + "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)." + "please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ), + ), + InputParam( + kwargs_type="controlnet_kwargs", + description=( + "additional kwargs for controlnet (e.g. control_type_idx and control_type from the controlnet union input step )" + "please add `kwargs_type=controlnet_kwargs` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ), + ), + ] + + @staticmethod + def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + accepted_kwargs = set(inspect.signature(func).parameters.keys()) + extra_kwargs = {} + for key, value in kwargs.items(): + if key in accepted_kwargs and key not in exclude_kwargs: + extra_kwargs[key] = value + + return extra_kwargs + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int): + extra_controlnet_kwargs = self.prepare_extra_kwargs( + components.controlnet.forward, **block_state.controlnet_kwargs + ) + + # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) + # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) + guider_input_fields = { + "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"), + "time_ids": ("add_time_ids", "negative_add_time_ids"), + "text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), + "image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"), + } + + # cond_scale for the timestep (controlnet input) + if isinstance(block_state.controlnet_keep[i], list): + block_state.cond_scale = [ + c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i]) + ] + else: + controlnet_cond_scale = block_state.conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + block_state.cond_scale = controlnet_cond_scale * block_state.controlnet_keep[i] + + # default controlnet output/unet input for guess mode + conditional path + block_state.down_block_res_samples_zeros = None + block_state.mid_block_res_sample_zeros = None + + # guided denoiser step + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # Prepare mini‐batches according to guidance method and `guider_input_fields` + # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds. + # e.g. 
for CFG, we prepare two batches: one for uncond, one for cond + # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds + # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds + guider_state = components.guider.prepare_inputs(block_state, guider_input_fields) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + components.guider.prepare_models(components.unet) + + # Prepare additional conditionings + added_cond_kwargs = { + "text_embeds": guider_state_batch.text_embeds, + "time_ids": guider_state_batch.time_ids, + } + if hasattr(guider_state_batch, "image_embeds") and guider_state_batch.image_embeds is not None: + added_cond_kwargs["image_embeds"] = guider_state_batch.image_embeds + + # Prepare controlnet additional conditionings + controlnet_added_cond_kwargs = { + "text_embeds": guider_state_batch.text_embeds, + "time_ids": guider_state_batch.time_ids, + } + # run controlnet for the guidance batch + if block_state.guess_mode and not components.guider.is_conditional: + # guider always run uncond batch first, so these tensors should be set already + down_block_res_samples = block_state.down_block_res_samples_zeros + mid_block_res_sample = block_state.mid_block_res_sample_zeros + else: + down_block_res_samples, mid_block_res_sample = components.controlnet( + block_state.scaled_latents, + t, + encoder_hidden_states=guider_state_batch.prompt_embeds, + controlnet_cond=block_state.controlnet_cond, + conditioning_scale=block_state.cond_scale, + guess_mode=block_state.guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + **extra_controlnet_kwargs, + ) + + # assign it to block_state so it will be available for the uncond guidance batch + if block_state.down_block_res_samples_zeros is None: + block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in down_block_res_samples] + if block_state.mid_block_res_sample_zeros is None: + block_state.mid_block_res_sample_zeros = torch.zeros_like(mid_block_res_sample) + + # Predict the noise + # store the noise_pred in guider_state_batch so we can apply guidance across all batches + guider_state_batch.noise_pred = components.unet( + block_state.scaled_latents, + t, + encoder_hidden_states=guider_state_batch.prompt_embeds, + timestep_cond=block_state.timestep_cond, + cross_attention_kwargs=block_state.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + )[0] + components.guider.cleanup_models(components.unet) + + # Perform guidance + block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state) + + return components, block_state + + +# loop step (3): scheduler step to update latents +class StableDiffusionXLLoopAfterDenoiser(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "step within the denoising loop that update the latents. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. 
`StableDiffusionXLDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("eta", default=0.0), + ] + + @property + def intermediate_inputs(self) -> List[str]: + return [ + InputParam("generator"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + # YiYi TODO: move this out of here + @staticmethod + def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + accepted_kwargs = set(inspect.signature(func).parameters.keys()) + extra_kwargs = {} + for key, value in kwargs.items(): + if key in accepted_kwargs and key not in exclude_kwargs: + extra_kwargs[key] = value + + return extra_kwargs + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int): + # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + block_state.extra_step_kwargs = self.prepare_extra_kwargs( + components.scheduler.step, generator=block_state.generator, eta=block_state.eta + ) + + # Perform scheduler step using the predicted output + block_state.latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred, + t, + block_state.latents, + **block_state.extra_step_kwargs, + **block_state.scheduler_step_kwargs, + return_dict=False, + )[0] + + if block_state.latents.dtype != block_state.latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + block_state.latents = block_state.latents.to(block_state.latents_dtype) + + return components, block_state + + +# loop step (3): scheduler step to update latents (with inpainting) +class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def description(self) -> str: + return ( + "step within the denoising loop that update the latents (for inpainting workflow only). " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("eta", default=0.0), + ] + + @property + def intermediate_inputs(self) -> List[str]: + return [ + InputParam("generator"), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam( + "mask", + type_hint=Optional[torch.Tensor], + description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step.", + ), + InputParam( + "noise", + type_hint=Optional[torch.Tensor], + description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step.", + ), + InputParam( + "image_latents", + type_hint=Optional[torch.Tensor], + description="The image latents to use for the denoising process, for inpainting/image-to-image task only. 
Can be generated in vae_encode or prepare_latent step.", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + @staticmethod + def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + accepted_kwargs = set(inspect.signature(func).parameters.keys()) + extra_kwargs = {} + for key, value in kwargs.items(): + if key in accepted_kwargs and key not in exclude_kwargs: + extra_kwargs[key] = value + + return extra_kwargs + + def check_inputs(self, components, block_state): + if components.num_channels_unet == 4: + if block_state.image_latents is None: + raise ValueError(f"image_latents is required for this step {self.__class__.__name__}") + if block_state.mask is None: + raise ValueError(f"mask is required for this step {self.__class__.__name__}") + if block_state.noise is None: + raise ValueError(f"noise is required for this step {self.__class__.__name__}") + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int): + self.check_inputs(components, block_state) + + # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + block_state.extra_step_kwargs = self.prepare_extra_kwargs( + components.scheduler.step, generator=block_state.generator, eta=block_state.eta + ) + + # Perform scheduler step using the predicted output + block_state.latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred, + t, + block_state.latents, + **block_state.extra_step_kwargs, + **block_state.scheduler_step_kwargs, + return_dict=False, + )[0] + + if block_state.latents.dtype != block_state.latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + block_state.latents = block_state.latents.to(block_state.latents_dtype) + + # adjust latent for inpainting + if components.num_channels_unet == 4: + block_state.init_latents_proper = block_state.image_latents + if i < len(block_state.timesteps) - 1: + block_state.noise_timestep = block_state.timesteps[i + 1] + block_state.init_latents_proper = components.scheduler.add_noise( + block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) + ) + + block_state.latents = ( + 1 - block_state.mask + ) * block_state.init_latents_proper + block_state.mask * block_state.latents + + return components, block_state + + +# the loop wrapper that iterates over the timesteps +class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "Pipeline block that iteratively denoise the latents over `timesteps`. " + "The specific steps with each iteration can be customized with `sub_blocks` attributes" + ) + + @property + def loop_expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config", + ), + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def loop_intermediate_inputs(self) -> List[InputParam]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. 
Can be generated in set_timesteps step.", + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False + if block_state.disable_guidance: + components.guider.disable() + else: + components.guider.enable() + + block_state.num_warmup_steps = max( + len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 + ) + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + if i == len(block_state.timesteps) - 1 or ( + (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0 + ): + progress_bar.update() + + self.set_block_state(state, block_state) + + return components, state + + +# composing the denoising loops +class StableDiffusionXLDenoiseStep(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [ + StableDiffusionXLLoopBeforeDenoiser, + StableDiffusionXLLoopDenoiser, + StableDiffusionXLLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" + " - `StableDiffusionXLLoopBeforeDenoiser`\n" + " - `StableDiffusionXLLoopDenoiser`\n" + " - `StableDiffusionXLLoopAfterDenoiser`\n" + "This block supports both text2img and img2img tasks." + ) + + +# control_cond +class StableDiffusionXLControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [ + StableDiffusionXLLoopBeforeDenoiser, + StableDiffusionXLControlNetLoopDenoiser, + StableDiffusionXLLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents with controlnet. \n" + "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" + " - `StableDiffusionXLLoopBeforeDenoiser`\n" + " - `StableDiffusionXLControlNetLoopDenoiser`\n" + " - `StableDiffusionXLLoopAfterDenoiser`\n" + "This block supports using controlnet for both text2img and img2img tasks." + ) + + +# mask +class StableDiffusionXLInpaintDenoiseStep(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [ + StableDiffusionXLInpaintLoopBeforeDenoiser, + StableDiffusionXLLoopDenoiser, + StableDiffusionXLInpaintLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents(for inpainting task only). 
\n" + "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" + " - `StableDiffusionXLInpaintLoopBeforeDenoiser`\n" + " - `StableDiffusionXLLoopDenoiser`\n" + " - `StableDiffusionXLInpaintLoopAfterDenoiser`\n" + "This block onlysupports inpainting tasks." + ) + + +# control_cond + mask +class StableDiffusionXLInpaintControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [ + StableDiffusionXLInpaintLoopBeforeDenoiser, + StableDiffusionXLControlNetLoopDenoiser, + StableDiffusionXLInpaintLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents(for inpainting task only) with controlnet. \n" + "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" + " - `StableDiffusionXLInpaintLoopBeforeDenoiser`\n" + " - `StableDiffusionXLControlNetLoopDenoiser`\n" + " - `StableDiffusionXLInpaintLoopAfterDenoiser`\n" + "This block only supports using controlnet for inpainting tasks." + ) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py new file mode 100644 index 000000000000..bd0e962140e8 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py @@ -0,0 +1,902 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List, Optional, Tuple + +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...utils import ( + USE_PEFT_BACKEND, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from ..modular_pipeline import PipelineBlock, PipelineState +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from .modular_pipeline import StableDiffusionXLModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class StableDiffusionXLIPAdapterStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "IP Adapter step that prepares ip adapter image embeddings.\n" + "Note that this step only prepares the embeddings - in order for it to work correctly, " + "you need to load ip adapter weights into unet via ModularPipeline.load_ip_adapter() and pipeline.set_ip_adapter_scale().\n" + "See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)" + " for more details" + ) + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("image_encoder", CLIPVisionModelWithProjection), + ComponentSpec( + "feature_extractor", + CLIPImageProcessor, + config=FrozenDict({"size": 224, "crop_size": 224}), + default_creation_method="from_config", + ), + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + "ip_adapter_image", + PipelineImageInput, + required=True, + description="The image(s) to be used as ip adapter", + ) + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"), + OutputParam( + "negative_ip_adapter_embeds", + type_hint=torch.Tensor, + description="Negative IP adapter image embeddings", + ), + ] + + @staticmethod + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self->components + def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = 
next(components.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = components.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = components.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = components.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = components.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, + components, + ip_adapter_image, + ip_adapter_image_embeds, + device, + num_images_per_prompt, + prepare_unconditional_embeds, + ): + image_embeds = [] + if prepare_unconditional_embeds: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." 
+ ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + components, single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if prepare_unconditional_embeds: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if prepare_unconditional_embeds: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if prepare_unconditional_embeds: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 + block_state.device = components._execution_device + + block_state.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds( + components, + ip_adapter_image=block_state.ip_adapter_image, + ip_adapter_image_embeds=None, + device=block_state.device, + num_images_per_prompt=1, + prepare_unconditional_embeds=block_state.prepare_unconditional_embeds, + ) + if block_state.prepare_unconditional_embeds: + block_state.negative_ip_adapter_embeds = [] + for i, image_embeds in enumerate(block_state.ip_adapter_embeds): + negative_image_embeds, image_embeds = image_embeds.chunk(2) + block_state.negative_ip_adapter_embeds.append(negative_image_embeds) + block_state.ip_adapter_embeds[i] = image_embeds + + self.set_block_state(state, block_state) + return components, state + + +class StableDiffusionXLTextEncoderStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return "Text Encoder step that generate text_embeddings to guide the image generation" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("text_encoder", CLIPTextModel), + ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), + ComponentSpec("tokenizer", CLIPTokenizer), + ComponentSpec("tokenizer_2", CLIPTokenizer), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config", + ), + ] + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [ConfigSpec("force_zeros_for_empty_prompt", True)] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("prompt"), + InputParam("prompt_2"), + InputParam("negative_prompt"), + InputParam("negative_prompt_2"), + InputParam("cross_attention_kwargs"), + InputParam("clip_skip"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "prompt_embeds", 
+ type_hint=torch.Tensor, + kwargs_type="guider_input_fields", + description="text embeddings used to guide the image generation", + ), + OutputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="guider_input_fields", + description="negative text embeddings used to guide the image generation", + ), + OutputParam( + "pooled_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="guider_input_fields", + description="pooled text embeddings used to guide the image generation", + ), + OutputParam( + "negative_pooled_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="guider_input_fields", + description="negative pooled text embeddings used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + if block_state.prompt is not None and ( + not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list) + ): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") + elif block_state.prompt_2 is not None and ( + not isinstance(block_state.prompt_2, str) and not isinstance(block_state.prompt_2, list) + ): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(block_state.prompt_2)}") + + @staticmethod + def encode_prompt( + components, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prepare_unconditional_embeds: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prepare_unconditional_embeds (`bool`): + whether to use prepare unconditional embeddings or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
+ If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or components._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin): + components._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if components.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(components.text_encoder, lora_scale) + else: + scale_lora_layers(components.text_encoder, lora_scale) + + if components.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(components.text_encoder_2, lora_scale) + else: + scale_lora_layers(components.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = ( + [components.tokenizer, components.tokenizer_2] + if components.tokenizer is not None + else [components.tokenizer_2] + ) + text_encoders = ( + [components.text_encoder, components.text_encoder_2] + if components.text_encoder is not None + else [components.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(components, TextualInversionLoaderMixin): + prompt = components.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. 
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt + if prepare_unconditional_embeds and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif prepare_unconditional_embeds and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(components, TextualInversionLoaderMixin): + negative_prompt = components.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if components.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=components.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if prepare_unconditional_embeds: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if components.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=components.text_encoder_2.dtype, device=device + ) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 
1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if prepare_unconditional_embeds: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if components.text_encoder is not None: + if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(components.text_encoder, lora_scale) + + if components.text_encoder_2 is not None: + if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(components.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + # Get inputs and intermediates + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 + block_state.device = components._execution_device + + # Encode input prompt + block_state.text_encoder_lora_scale = ( + block_state.cross_attention_kwargs.get("scale", None) + if block_state.cross_attention_kwargs is not None + else None + ) + ( + block_state.prompt_embeds, + block_state.negative_prompt_embeds, + block_state.pooled_prompt_embeds, + block_state.negative_pooled_prompt_embeds, + ) = self.encode_prompt( + components, + block_state.prompt, + block_state.prompt_2, + block_state.device, + 1, + block_state.prepare_unconditional_embeds, + block_state.negative_prompt, + block_state.negative_prompt_2, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + lora_scale=block_state.text_encoder_lora_scale, + clip_skip=block_state.clip_skip, + ) + # Add outputs + self.set_block_state(state, block_state) + return components, state + + +class StableDiffusionXLVaeEncoderStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return "Vae Encoder step that encode the input image into a latent representation" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("image", required=True), + InputParam("height"), + InputParam("width"), + ] + + @property + def intermediate_inputs(self) -> List[InputParam]: + return [ + InputParam("generator"), + InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + InputParam( + "preprocess_kwargs", + type_hint=Optional[dict], + description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "image_latents", + type_hint=torch.Tensor, + description="The latents representing 
the reference image for image-to-image/inpainting generation", + ) + ] + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components + # YiYi TODO: update the _encode_vae_image so that we can use #Coped from + def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): + latents_mean = latents_std = None + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) + + dtype = image.dtype + if components.vae.config.force_upcast: + image = image.float() + components.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(components.vae.encode(image), generator=generator) + + if components.vae.config.force_upcast: + components.vae.to(dtype) + + image_latents = image_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) + latents_std = latents_std.to(device=image_latents.device, dtype=dtype) + image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std + else: + image_latents = components.vae.config.scaling_factor * image_latents + + return image_latents + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.preprocess_kwargs = block_state.preprocess_kwargs or {} + block_state.device = components._execution_device + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + + block_state.image = components.image_processor.preprocess( + block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs + ) + block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) + + block_state.batch_size = block_state.image.shape[0] + + # if generator is a list, make sure the length of it matches the length of images (both should be batch_size) + if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch" + f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + block_state.image_latents = self._encode_vae_image( + components, image=block_state.image, generator=block_state.generator + ) + + self.set_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config", + ), + ComponentSpec( + "mask_processor", + VaeImageProcessor, + config=FrozenDict( + {"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True} + ), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "Vae encoder step that prepares the image and mask for the inpainting process" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("height"), + InputParam("width"), + InputParam("image", required=True), + InputParam("mask_image", required=True), + InputParam("padding_mask_crop"), + ] + + @property + def intermediate_inputs(self) -> List[InputParam]: + return [ + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + InputParam("generator"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "image_latents", type_hint=torch.Tensor, description="The latents representation of the input image" + ), + OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"), + OutputParam( + "masked_image_latents", + type_hint=torch.Tensor, + description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)", + ), + OutputParam( + "crops_coords", + type_hint=Optional[Tuple[int, int]], + description="The crop coordinates to use for the preprocess/postprocess of the image and mask", + ), + ] + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components + # YiYi TODO: update the _encode_vae_image so that we can use #Coped from + def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): + latents_mean = latents_std = None + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) + + dtype = image.dtype + if components.vae.config.force_upcast: + image = image.float() + components.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(components.vae.encode(image), generator=generator) + + if components.vae.config.force_upcast: + components.vae.to(dtype) + + image_latents = image_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) + latents_std = 
latents_std.to(device=image_latents.device, dtype=dtype) + image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std + else: + image_latents = components.vae.config.scaling_factor * image_latents + + return image_latents + + # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents + # do not accept do_classifier_free_guidance + def prepare_mask_latents( + self, components, mask, masked_image, batch_size, height, width, dtype, device, generator + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + + if masked_image is not None and masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = None + + if masked_image is not None: + if masked_image_latents is None: + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator) + + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." 
+ ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + block_state.device = components._execution_device + + if block_state.height is None: + block_state.height = components.default_height + if block_state.width is None: + block_state.width = components.default_width + + if block_state.padding_mask_crop is not None: + block_state.crops_coords = components.mask_processor.get_crop_region( + block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop + ) + block_state.resize_mode = "fill" + else: + block_state.crops_coords = None + block_state.resize_mode = "default" + + block_state.image = components.image_processor.preprocess( + block_state.image, + height=block_state.height, + width=block_state.width, + crops_coords=block_state.crops_coords, + resize_mode=block_state.resize_mode, + ) + block_state.image = block_state.image.to(dtype=torch.float32) + + block_state.mask = components.mask_processor.preprocess( + block_state.mask_image, + height=block_state.height, + width=block_state.width, + resize_mode=block_state.resize_mode, + crops_coords=block_state.crops_coords, + ) + block_state.masked_image = block_state.image * (block_state.mask < 0.5) + + block_state.batch_size = block_state.image.shape[0] + block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) + block_state.image_latents = self._encode_vae_image( + components, image=block_state.image, generator=block_state.generator + ) + + # 7. Prepare mask latent variables + block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents( + components, + block_state.mask, + block_state.masked_image, + block_state.batch_size, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + ) + + self.set_block_state(state, block_state) + + return components, state diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py new file mode 100644 index 000000000000..c9033856bcc0 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py @@ -0,0 +1,380 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
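+# Module overview: this file wires the individual SDXL blocks into auto/sequential presets and
+# exposes them as InsertableDict mappings (TEXT2IMAGE_BLOCKS, IMAGE2IMAGE_BLOCKS, INPAINT_BLOCKS,
+# CONTROLNET_BLOCKS, IP_ADAPTER_BLOCKS, AUTO_BLOCKS, collected in ALL_BLOCKS).
+# Rough usage sketch, following the convention from the Modular Diffusers guides (block classes are
+# definitions and are converted into a runnable pipeline with `init_pipeline()`); loading of the
+# actual model components is intentionally omitted here:
+#
+#     blocks = StableDiffusionXLAutoBlocks()
+#     pipeline = blocks.init_pipeline()
+#     # text2img: pass `prompt` only; img2img: also pass `image`; inpainting: also pass
+#     # `mask_image` and `image` (see `StableDiffusionXLAutoBlocks.description` below)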
+ +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict +from .before_denoise import ( + StableDiffusionXLControlNetInputStep, + StableDiffusionXLControlNetUnionInputStep, + StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, + StableDiffusionXLImg2ImgPrepareLatentsStep, + StableDiffusionXLImg2ImgSetTimestepsStep, + StableDiffusionXLInpaintPrepareLatentsStep, + StableDiffusionXLInputStep, + StableDiffusionXLPrepareAdditionalConditioningStep, + StableDiffusionXLPrepareLatentsStep, + StableDiffusionXLSetTimestepsStep, +) +from .decoders import ( + StableDiffusionXLDecodeStep, + StableDiffusionXLInpaintOverlayMaskStep, +) +from .denoise import ( + StableDiffusionXLControlNetDenoiseStep, + StableDiffusionXLDenoiseStep, + StableDiffusionXLInpaintControlNetDenoiseStep, + StableDiffusionXLInpaintDenoiseStep, +) +from .encoders import ( + StableDiffusionXLInpaintVaeEncoderStep, + StableDiffusionXLIPAdapterStep, + StableDiffusionXLTextEncoderStep, + StableDiffusionXLVaeEncoderStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# auto blocks & sequential blocks & mappings + + +# vae encoder (run before before_denoise) +class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep] + block_names = ["inpaint", "img2img"] + block_trigger_inputs = ["mask_image", "image"] + + @property + def description(self): + return ( + "Vae encoder step that encode the image inputs into their latent representations.\n" + + "This is an auto pipeline block that works for both inpainting and img2img tasks.\n" + + " - `StableDiffusionXLInpaintVaeEncoderStep` (inpaint) is used when `mask_image` is provided.\n" + + " - `StableDiffusionXLVaeEncoderStep` (img2img) is used when only `image` is provided." + + " - if neither `mask_image` nor `image` is provided, step will be skipped." + ) + + +# optional ip-adapter (run before input step) +class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLIPAdapterStep] + block_names = ["ip_adapter"] + block_trigger_inputs = ["ip_adapter_image"] + + @property + def description(self): + return "Run IP Adapter step if `ip_adapter_image` is provided. 
This step should be placed before the 'input' step.\n" + + +# before_denoise: text2img +class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks): + block_classes = [ + StableDiffusionXLInputStep, + StableDiffusionXLSetTimestepsStep, + StableDiffusionXLPrepareLatentsStep, + StableDiffusionXLPrepareAdditionalConditioningStep, + ] + block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"] + + @property + def description(self): + return ( + "Before denoise step that prepare the inputs for the denoise step.\n" + + "This is a sequential pipeline blocks:\n" + + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + + " - `StableDiffusionXLSetTimestepsStep` is used to set the timesteps\n" + + " - `StableDiffusionXLPrepareLatentsStep` is used to prepare the latents\n" + + " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + ) + + +# before_denoise: img2img +class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks): + block_classes = [ + StableDiffusionXLInputStep, + StableDiffusionXLImg2ImgSetTimestepsStep, + StableDiffusionXLImg2ImgPrepareLatentsStep, + StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, + ] + block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"] + + @property + def description(self): + return ( + "Before denoise step that prepare the inputs for the denoise step for img2img task.\n" + + "This is a sequential pipeline blocks:\n" + + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + + " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + + " - `StableDiffusionXLImg2ImgPrepareLatentsStep` is used to prepare the latents\n" + + " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + ) + + +# before_denoise: inpainting +class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks): + block_classes = [ + StableDiffusionXLInputStep, + StableDiffusionXLImg2ImgSetTimestepsStep, + StableDiffusionXLInpaintPrepareLatentsStep, + StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, + ] + block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"] + + @property + def description(self): + return ( + "Before denoise step that prepare the inputs for the denoise step for inpainting task.\n" + + "This is a sequential pipeline blocks:\n" + + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + + " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + + " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n" + + " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + ) + + +# before_denoise: all task (text2img, img2img, inpainting) +class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks): + block_classes = [ + StableDiffusionXLInpaintBeforeDenoiseStep, + StableDiffusionXLImg2ImgBeforeDenoiseStep, + StableDiffusionXLBeforeDenoiseStep, + ] + block_names = ["inpaint", "img2img", "text2img"] + block_trigger_inputs = ["mask", "image_latents", None] + + @property + def description(self): + return ( + "Before denoise step that prepare the inputs for the denoise step.\n" + + "This is an auto pipeline block that works for text2img, img2img and inpainting tasks as well as controlnet, controlnet_union.\n" + + " - 
`StableDiffusionXLInpaintBeforeDenoiseStep` (inpaint) is used when both `mask` and `image_latents` are provided.\n" + " - `StableDiffusionXLImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n" + " - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when neither `image_latents` nor `mask` is provided.\n" + ) + + +# optional controlnet input step (after before_denoise, before denoise) +# works for both controlnet and controlnet_union +class StableDiffusionXLAutoControlNetInputStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetInputStep] + block_names = ["controlnet_union", "controlnet"] + block_trigger_inputs = ["control_mode", "control_image"] + + @property + def description(self): + return ( + "Controlnet Input step that prepares the controlnet input.\n" + + "This is an auto pipeline block that works for both controlnet and controlnet_union.\n" + + " (it should be called right before the denoise step)" + + " - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.\n" + + " - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided." + + " - if neither `control_mode` nor `control_image` is provided, step will be skipped." + ) + + +# denoise: controlnet (text2img, img2img, inpainting) +class StableDiffusionXLAutoControlNetDenoiseStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintControlNetDenoiseStep, StableDiffusionXLControlNetDenoiseStep] + block_names = ["inpaint_controlnet_denoise", "controlnet_denoise"] + block_trigger_inputs = ["mask", "controlnet_cond"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoises the latents with controlnet. " + "This is an auto pipeline block that uses controlnet for text2img, img2img and inpainting tasks." + "This block should not be used without a controlnet_cond input." + " - `StableDiffusionXLInpaintControlNetDenoiseStep` (inpaint_controlnet_denoise) is used when mask is provided." + " - `StableDiffusionXLControlNetDenoiseStep` (controlnet_denoise) is used when mask is not provided but controlnet_cond is provided." + " - If neither mask nor controlnet_cond are provided, step will be skipped." + ) + + +# denoise: all task with or without controlnet (text2img, img2img, inpainting) +class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): + block_classes = [ + StableDiffusionXLAutoControlNetDenoiseStep, + StableDiffusionXLInpaintDenoiseStep, + StableDiffusionXLDenoiseStep, + ] + block_names = ["controlnet_denoise", "inpaint_denoise", "denoise"] + block_trigger_inputs = ["controlnet_cond", "mask", None] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoises the latents. " + "This is an auto pipeline block that works for text2img, img2img and inpainting tasks, and can be used with or without controlnet." + " - `StableDiffusionXLAutoControlNetDenoiseStep` (controlnet_denoise) is used when controlnet_cond is provided (supports controlnet with text2img, img2img and inpainting tasks)." + " - `StableDiffusionXLInpaintDenoiseStep` (inpaint_denoise) is used when mask is provided (supports inpainting tasks)." + " - `StableDiffusionXLDenoiseStep` (denoise) is used when neither mask nor controlnet_cond are provided (supports text2img and img2img tasks)." 
+ ) + + +# decode: inpaint +class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLDecodeStep, StableDiffusionXLInpaintOverlayMaskStep] + block_names = ["decode", "mask_overlay"] + + @property + def description(self): + return ( + "Inpaint decode step that decode the denoised latents into images outputs.\n" + + "This is a sequential pipeline blocks:\n" + + " - `StableDiffusionXLDecodeStep` is used to decode the denoised latents into images\n" + + " - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image" + ) + + +# decode: all task (text2img, img2img, inpainting) +class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep] + block_names = ["inpaint", "non-inpaint"] + block_trigger_inputs = ["padding_mask_crop", None] + + @property + def description(self): + return ( + "Decode step that decode the denoised latents into images outputs.\n" + + "This is an auto pipeline block that works for inpainting and non-inpainting tasks.\n" + + " - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n" + + " - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided." + ) + + +# ip-adapter, controlnet, text2img, img2img, inpainting +class StableDiffusionXLAutoBlocks(SequentialPipelineBlocks): + block_classes = [ + StableDiffusionXLTextEncoderStep, + StableDiffusionXLAutoIPAdapterStep, + StableDiffusionXLAutoVaeEncoderStep, + StableDiffusionXLAutoBeforeDenoiseStep, + StableDiffusionXLAutoControlNetInputStep, + StableDiffusionXLAutoDenoiseStep, + StableDiffusionXLAutoDecodeStep, + ] + block_names = [ + "text_encoder", + "ip_adapter", + "image_encoder", + "before_denoise", + "controlnet_input", + "denoise", + "decoder", + ] + + @property + def description(self): + return ( + "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n" + + "- for image-to-image generation, you need to provide either `image` or `image_latents`\n" + + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n" + + "- to run the controlnet workflow, you need to provide `control_image`\n" + + "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n" + + "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n" + + "- for text-to-image generation, all you need to provide is `prompt`" + ) + + +# controlnet (input + denoise step) +class StableDiffusionXLAutoControlnetStep(SequentialPipelineBlocks): + block_classes = [ + StableDiffusionXLAutoControlNetInputStep, + StableDiffusionXLAutoControlNetDenoiseStep, + ] + block_names = ["controlnet_input", "controlnet_denoise"] + + @property + def description(self): + return ( + "Controlnet auto step that prepare the controlnet input and denoise the latents. " + + "It works for both controlnet and controlnet_union and supports text2img, img2img and inpainting tasks." 
+ + " (it should be replace at 'denoise' step)" + ) + + +TEXT2IMAGE_BLOCKS = InsertableDict( + [ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("input", StableDiffusionXLInputStep), + ("set_timesteps", StableDiffusionXLSetTimestepsStep), + ("prepare_latents", StableDiffusionXLPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLDenoiseStep), + ("decode", StableDiffusionXLDecodeStep), + ] +) + +IMAGE2IMAGE_BLOCKS = InsertableDict( + [ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("image_encoder", StableDiffusionXLVaeEncoderStep), + ("input", StableDiffusionXLInputStep), + ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), + ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLDenoiseStep), + ("decode", StableDiffusionXLDecodeStep), + ] +) + +INPAINT_BLOCKS = InsertableDict( + [ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep), + ("input", StableDiffusionXLInputStep), + ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), + ("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLInpaintDenoiseStep), + ("decode", StableDiffusionXLInpaintDecodeStep), + ] +) + +CONTROLNET_BLOCKS = InsertableDict( + [ + ("denoise", StableDiffusionXLAutoControlnetStep), + ] +) + + +IP_ADAPTER_BLOCKS = InsertableDict( + [ + ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), + ] +) + +AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), + ("image_encoder", StableDiffusionXLAutoVaeEncoderStep), + ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), + ("controlnet_input", StableDiffusionXLAutoControlNetInputStep), + ("denoise", StableDiffusionXLAutoDenoiseStep), + ("decode", StableDiffusionXLAutoDecodeStep), + ] +) + + +ALL_BLOCKS = { + "text2img": TEXT2IMAGE_BLOCKS, + "img2img": IMAGE2IMAGE_BLOCKS, + "inpaint": INPAINT_BLOCKS, + "controlnet": CONTROLNET_BLOCKS, + "ip_adapter": IP_ADAPTER_BLOCKS, + "auto": AUTO_BLOCKS, +} diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py new file mode 100644 index 000000000000..fc030fae56fb --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py @@ -0,0 +1,376 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch + +from ...image_processor import PipelineImageInput +from ...loaders import ModularIPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from ...pipelines.pipeline_utils import StableDiffusionMixin +from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from ...utils import logging +from ..modular_pipeline import ModularPipeline +from ..modular_pipeline_utils import InputParam, OutputParam + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# YiYi TODO: move to a different file? stable_diffusion_xl_module should have its own folder? +# YiYi Notes: model specific components: +## (1) it should inherit from ModularPipeline +## (2) acts like a container that holds components and configs +## (3) define default config (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents +## (4) inherit from model-specic loader class (e.g. StableDiffusionXLLoraLoaderMixin) +## (5) how to use together with Components_manager? +class StableDiffusionXLModularPipeline( + ModularPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + ModularIPAdapterMixin, +): + """ + A ModularPipeline for Stable Diffusion XL. + + + + This is an experimental feature and is likely to change in the future. + + + """ + + @property + def default_height(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_width(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_sample_size(self): + default_sample_size = 128 + if hasattr(self, "unet") and self.unet is not None: + default_sample_size = self.unet.config.sample_size + return default_sample_size + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor + + @property + def num_channels_unet(self): + num_channels_unet = 4 + if hasattr(self, "unet") and self.unet is not None: + num_channels_unet = self.unet.config.in_channels + return num_channels_unet + + @property + def num_channels_latents(self): + num_channels_latents = 4 + if hasattr(self, "vae") and self.vae is not None: + num_channels_latents = self.vae.config.latent_channels + return num_channels_latents + + +# YiYi/Sayak TODO: not used yet, maintain a list of schema that can be used across all pipeline blocks +# auto_docstring +SDXL_INPUTS_SCHEMA = { + "prompt": InputParam( + "prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation" + ), + "prompt_2": InputParam( + "prompt_2", + type_hint=Union[str, List[str]], + description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2", + ), + "negative_prompt": InputParam( + "negative_prompt", + type_hint=Union[str, List[str]], + description="The prompt or prompts not to guide the image generation", + ), + "negative_prompt_2": InputParam( + "negative_prompt_2", + type_hint=Union[str, List[str]], + description="The negative prompt or prompts for text_encoder_2", + ), + "cross_attention_kwargs": InputParam( + "cross_attention_kwargs", + type_hint=Optional[dict], + description="Kwargs dictionary passed to the AttentionProcessor", + ), + "clip_skip": InputParam( + "clip_skip", type_hint=Optional[int], 
description="Number of layers to skip in CLIP text encoder" + ), + "image": InputParam( + "image", + type_hint=PipelineImageInput, + required=True, + description="The image(s) to modify for img2img or inpainting", + ), + "mask_image": InputParam( + "mask_image", + type_hint=PipelineImageInput, + required=True, + description="Mask image for inpainting, white pixels will be repainted", + ), + "generator": InputParam( + "generator", + type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], + description="Generator(s) for deterministic generation", + ), + "height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"), + "width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"), + "num_images_per_prompt": InputParam( + "num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt" + ), + "num_inference_steps": InputParam( + "num_inference_steps", type_hint=int, default=50, description="Number of denoising steps" + ), + "timesteps": InputParam( + "timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process" + ), + "sigmas": InputParam( + "sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process" + ), + "denoising_end": InputParam( + "denoising_end", + type_hint=Optional[float], + description="Fraction of denoising process to complete before termination", + ), + # YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999 + "strength": InputParam( + "strength", type_hint=float, default=0.3, description="How much to transform the reference image" + ), + "denoising_start": InputParam( + "denoising_start", type_hint=Optional[float], description="Starting point of the denoising process" + ), + "latents": InputParam( + "latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation" + ), + "padding_mask_crop": InputParam( + "padding_mask_crop", + type_hint=Optional[Tuple[int, int]], + description="Size of margin in crop for image and mask", + ), + "original_size": InputParam( + "original_size", + type_hint=Optional[Tuple[int, int]], + description="Original size of the image for SDXL's micro-conditioning", + ), + "target_size": InputParam( + "target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning" + ), + "negative_original_size": InputParam( + "negative_original_size", + type_hint=Optional[Tuple[int, int]], + description="Negative conditioning based on image resolution", + ), + "negative_target_size": InputParam( + "negative_target_size", + type_hint=Optional[Tuple[int, int]], + description="Negative conditioning based on target resolution", + ), + "crops_coords_top_left": InputParam( + "crops_coords_top_left", + type_hint=Tuple[int, int], + default=(0, 0), + description="Top-left coordinates for SDXL's micro-conditioning", + ), + "negative_crops_coords_top_left": InputParam( + "negative_crops_coords_top_left", + type_hint=Tuple[int, int], + default=(0, 0), + description="Negative conditioning crop coordinates", + ), + "aesthetic_score": InputParam( + "aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image" + ), + "negative_aesthetic_score": InputParam( + "negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score" + ), + "eta": InputParam("eta", type_hint=float, default=0.0, 
description="Parameter η in the DDIM paper"), + "output_type": InputParam( + "output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)" + ), + "ip_adapter_image": InputParam( + "ip_adapter_image", + type_hint=PipelineImageInput, + required=True, + description="Image(s) to be used as IP adapter", + ), + "control_image": InputParam( + "control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition" + ), + "control_guidance_start": InputParam( + "control_guidance_start", + type_hint=Union[float, List[float]], + default=0.0, + description="When ControlNet starts applying", + ), + "control_guidance_end": InputParam( + "control_guidance_end", + type_hint=Union[float, List[float]], + default=1.0, + description="When ControlNet stops applying", + ), + "controlnet_conditioning_scale": InputParam( + "controlnet_conditioning_scale", + type_hint=Union[float, List[float]], + default=1.0, + description="Scale factor for ControlNet outputs", + ), + "guess_mode": InputParam( + "guess_mode", + type_hint=bool, + default=False, + description="Enables ControlNet encoder to recognize input without prompts", + ), + "control_mode": InputParam( + "control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet" + ), +} + + +SDXL_INTERMEDIATE_INPUTS_SCHEMA = { + "prompt_embeds": InputParam( + "prompt_embeds", + type_hint=torch.Tensor, + required=True, + description="Text embeddings used to guide image generation", + ), + "negative_prompt_embeds": InputParam( + "negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings" + ), + "pooled_prompt_embeds": InputParam( + "pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings" + ), + "negative_pooled_prompt_embeds": InputParam( + "negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings" + ), + "batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"), + "dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + "preprocess_kwargs": InputParam( + "preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor" + ), + "latents": InputParam( + "latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process" + ), + "timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"), + "num_inference_steps": InputParam( + "num_inference_steps", type_hint=int, required=True, description="Number of denoising steps" + ), + "latent_timestep": InputParam( + "latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep" + ), + "image_latents": InputParam( + "image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image" + ), + "mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"), + "masked_image_latents": InputParam( + "masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting" + ), + "add_time_ids": InputParam( + "add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning" + ), + "negative_add_time_ids": InputParam( + "negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids" + ), + "timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, 
description="Timestep conditioning for LCM"), + "noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), + "crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"), + "ip_adapter_embeds": InputParam( + "ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter" + ), + "negative_ip_adapter_embeds": InputParam( + "negative_ip_adapter_embeds", + type_hint=List[torch.Tensor], + description="Negative image embeddings for IP-Adapter", + ), + "images": InputParam( + "images", + type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], + required=True, + description="Generated images", + ), +} + + +SDXL_INTERMEDIATE_OUTPUTS_SCHEMA = { + "prompt_embeds": OutputParam( + "prompt_embeds", type_hint=torch.Tensor, description="Text embeddings used to guide image generation" + ), + "negative_prompt_embeds": OutputParam( + "negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings" + ), + "pooled_prompt_embeds": OutputParam( + "pooled_prompt_embeds", type_hint=torch.Tensor, description="Pooled text embeddings" + ), + "negative_pooled_prompt_embeds": OutputParam( + "negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings" + ), + "batch_size": OutputParam("batch_size", type_hint=int, description="Number of prompts"), + "dtype": OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + "image_latents": OutputParam( + "image_latents", type_hint=torch.Tensor, description="Latents representing reference image" + ), + "mask": OutputParam("mask", type_hint=torch.Tensor, description="Mask for inpainting"), + "masked_image_latents": OutputParam( + "masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting" + ), + "crops_coords": OutputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"), + "timesteps": OutputParam("timesteps", type_hint=torch.Tensor, description="Timesteps for inference"), + "num_inference_steps": OutputParam("num_inference_steps", type_hint=int, description="Number of denoising steps"), + "latent_timestep": OutputParam( + "latent_timestep", type_hint=torch.Tensor, description="Initial noise level timestep" + ), + "add_time_ids": OutputParam("add_time_ids", type_hint=torch.Tensor, description="Time ids for conditioning"), + "negative_add_time_ids": OutputParam( + "negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids" + ), + "timestep_cond": OutputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), + "latents": OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents"), + "noise": OutputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), + "ip_adapter_embeds": OutputParam( + "ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter" + ), + "negative_ip_adapter_embeds": OutputParam( + "negative_ip_adapter_embeds", + type_hint=List[torch.Tensor], + description="Negative image embeddings for IP-Adapter", + ), + "images": OutputParam( + "images", + type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], + description="Generated images", + ), +} + + +SDXL_OUTPUTS_SCHEMA = { + "images": OutputParam( + "images", + type_hint=Union[ + Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], 
StableDiffusionXLPipelineOutput + ], + description="The final generated images", + ) +} diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index b1a7ffaaea9c..8ca60d9f631f 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -248,14 +248,15 @@ def _get_connected_pipeline(pipeline_cls): return _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, pipeline_cls.__name__, throw_error_if_not_exist=False) -def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool = True): - def get_model(pipeline_class_name): - for task_mapping in SUPPORTED_TASKS_MAPPINGS: - for model_name, pipeline in task_mapping.items(): - if pipeline.__name__ == pipeline_class_name: - return model_name +def _get_model(pipeline_class_name): + for task_mapping in SUPPORTED_TASKS_MAPPINGS: + for model_name, pipeline in task_mapping.items(): + if pipeline.__name__ == pipeline_class_name: + return model_name + - model_name = get_model(pipeline_class_name) +def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool = True): + model_name = _get_model(pipeline_class_name) if model_name is not None: task_class = mapping.get(model_name, None) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index d1c2c2adb4c3..b5ac6cc3012f 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -371,6 +371,22 @@ def maybe_raise_or_warn( ) +# a simpler version of get_class_obj_and_candidates, it won't work with custom code +def simple_get_class_obj(library_name, class_name): + from diffusers import pipelines + + is_pipeline_module = hasattr(pipelines, library_name) + + if is_pipeline_module: + pipeline_module = getattr(pipelines, library_name) + class_obj = getattr(pipeline_module, class_name) + else: + library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) + + return class_obj + + def get_class_obj_and_candidates( library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None ): @@ -452,7 +468,7 @@ def _get_pipeline_class( revision=revision, ) - if class_obj.__name__ != "DiffusionPipeline": + if class_obj.__name__ != "DiffusionPipeline" and class_obj.__name__ != "ModularPipeline": return class_obj diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0]) @@ -892,7 +908,10 @@ def _fetch_class_library_tuple(module): library = not_compiled_module.__module__ # retrieve class_name - class_name = not_compiled_module.__class__.__name__ + if isinstance(not_compiled_module, type): + class_name = not_compiled_module.__name__ + else: + class_name = not_compiled_module.__class__.__name__ return (library, class_name) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 2c03811b51ab..0375fbb0856a 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1986,11 +1986,13 @@ def from_pipe(cls, pipeline, **kwargs): f"{'' if k.startswith('_') else '_'}{k}": v for k, v in original_config.items() if k not in pipeline_kwargs } + optional_components = ( + pipeline._optional_components + if hasattr(pipeline, "_optional_components") and pipeline._optional_components + else [] + ) missing_modules = ( - set(expected_modules) - - set(pipeline._optional_components) - - set(pipeline_kwargs.keys()) - - 
set(true_optional_modules) + set(expected_modules) - set(optional_components) - set(pipeline_kwargs.keys()) - set(true_optional_modules) ) if len(missing_modules) > 0: diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 6d25047a0f1c..247769306b53 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -2,6 +2,126 @@ from ..utils import DummyObject, requires_backends +class AdaptiveProjectedGuidance(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class AutoGuidance(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class ClassifierFreeGuidance(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class ClassifierFreeZeroStarGuidance(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class PerturbedAttentionGuidance(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class SkipLayerGuidance(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class SmoothedEnergyGuidance(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class TangentialClassifierFreeGuidance(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class FasterCacheConfig(metaclass=DummyObject): _backends = ["torch"] @@ -47,6 +167,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class LayerSkipConfig(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, 
*args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class PyramidAttentionBroadcastConfig(metaclass=DummyObject): _backends = ["torch"] @@ -62,6 +197,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class SmoothedEnergyGuidanceConfig(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + def apply_faster_cache(*args, **kwargs): requires_backends(apply_faster_cache, ["torch"]) @@ -70,6 +220,10 @@ def apply_first_block_cache(*args, **kwargs): requires_backends(apply_first_block_cache, ["torch"]) +def apply_layer_skip(*args, **kwargs): + requires_backends(apply_layer_skip, ["torch"]) + + def apply_pyramid_attention_broadcast(*args, **kwargs): requires_backends(apply_pyramid_attention_broadcast, ["torch"]) @@ -1199,6 +1353,66 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class ComponentsManager(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class ComponentSpec(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class ModularPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class ModularPipelineBlocks(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + def get_constant_schedule(*args, **kwargs): requires_backends(get_constant_schedule, ["torch"]) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 9cb869c67a3e..e9c732d16415 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2,6 +2,36 @@ from ..utils import DummyObject, requires_backends +class StableDiffusionXLAutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + 
requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionXLModularPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class AllegroPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/src/diffusers/utils/dynamic_modules_utils.py b/src/diffusers/utils/dynamic_modules_utils.py index 4878937ab202..8eb99038c172 100644 --- a/src/diffusers/utils/dynamic_modules_utils.py +++ b/src/diffusers/utils/dynamic_modules_utils.py @@ -20,8 +20,11 @@ import os import re import shutil +import signal import sys +import threading from pathlib import Path +from types import ModuleType from typing import Dict, Optional, Union from urllib import request @@ -37,6 +40,8 @@ # See https://huggingface.co/datasets/diffusers/community-pipelines-mirror COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror" +TIME_OUT_REMOTE_CODE = int(os.getenv("DIFFUSERS_TIMEOUT_REMOTE_CODE", 15)) +_HF_REMOTE_CODE_LOCK = threading.Lock() def get_diffusers_versions(): @@ -154,33 +159,87 @@ def check_imports(filename): return get_relative_imports(filename) -def get_class_in_module(class_name, module_path, pretrained_model_name_or_path=None): +def _raise_timeout_error(signum, frame): + raise ValueError( + "Loading this model requires you to execute custom code contained in the model repository on your local " + "machine. Please set the option `trust_remote_code=True` to permit loading of this model." + ) + + +def resolve_trust_remote_code(trust_remote_code, model_name, has_remote_code): + if trust_remote_code is None: + if has_remote_code and TIME_OUT_REMOTE_CODE > 0: + prev_sig_handler = None + try: + prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error) + signal.alarm(TIME_OUT_REMOTE_CODE) + while trust_remote_code is None: + answer = input( + f"The repository for {model_name} contains custom code which must be executed to correctly " + f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n" + f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n" + f"Do you wish to run the custom code? [y/N] " + ) + if answer.lower() in ["yes", "y", "1"]: + trust_remote_code = True + elif answer.lower() in ["no", "n", "0", ""]: + trust_remote_code = False + signal.alarm(0) + except Exception: + # OS which does not support signal.SIGALRM + raise ValueError( + f"The repository for {model_name} contains custom code which must be executed to correctly " + f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n" + f"Please pass the argument `trust_remote_code=True` to allow custom code to be run." + ) + finally: + if prev_sig_handler is not None: + signal.signal(signal.SIGALRM, prev_sig_handler) + signal.alarm(0) + elif has_remote_code: + # For the CI which puts the timeout at 0 + _raise_timeout_error(None, None) + + if has_remote_code and not trust_remote_code: + raise ValueError( + f"Loading {model_name} requires you to execute the configuration file in that" + " repo on your local machine. Make sure you have read the code there to avoid malicious use, then" + " set the option `trust_remote_code=True` to remove this error." 
+ ) + + return trust_remote_code + + +def get_class_in_module(class_name, module_path, force_reload=False): """ Import a module on the cache directory for modules and extract a class from it. """ - module_path = module_path.replace(os.path.sep, ".") - try: - module = importlib.import_module(module_path) - except ModuleNotFoundError as e: - # This can happen when the repo id contains ".", which Python's import machinery interprets as a directory - # separator. We do a bit of monkey patching to detect and fix this case. - if not ( - pretrained_model_name_or_path is not None - and "." in pretrained_model_name_or_path - and module_path.startswith("diffusers_modules") - and pretrained_model_name_or_path.replace("/", "--") in module_path - ): - raise e # We can't figure this one out, just reraise the original error + name = os.path.normpath(module_path) + if name.endswith(".py"): + name = name[:-3] + name = name.replace(os.path.sep, ".") + module_file: Path = Path(HF_MODULES_CACHE) / module_path + + with _HF_REMOTE_CODE_LOCK: + if force_reload: + sys.modules.pop(name, None) + importlib.invalidate_caches() + cached_module: Optional[ModuleType] = sys.modules.get(name) + module_spec = importlib.util.spec_from_file_location(name, location=module_file) + + module: ModuleType + if cached_module is None: + module = importlib.util.module_from_spec(module_spec) + # insert it into sys.modules before any loading begins + sys.modules[name] = module + else: + module = cached_module - corrected_path = os.path.join(HF_MODULES_CACHE, module_path.replace(".", "/")) + ".py" - corrected_path = corrected_path.replace( - pretrained_model_name_or_path.replace("/", "--").replace(".", "/"), - pretrained_model_name_or_path.replace("/", "--"), - ) - module = importlib.machinery.SourceFileLoader(module_path, corrected_path).load_module() + module_spec.loader.exec_module(module) if class_name is None: return find_pipeline_class(module) + return getattr(module, class_name) @@ -472,4 +531,4 @@ def get_class_from_dynamic_module( revision=revision, local_files_only=local_files_only, ) - return get_class_in_module(class_name, final_module.replace(".py", ""), pretrained_model_name_or_path) + return get_class_in_module(class_name, final_module) diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index f80f96a3425d..8aaee5b75d93 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -467,6 +467,7 @@ def _upload_folder( token: Optional[str] = None, commit_message: Optional[str] = None, create_pr: bool = False, + subfolder: Optional[str] = None, ): """ Uploads all files in `working_dir` to `repo_id`. @@ -481,7 +482,12 @@ def _upload_folder( logger.info(f"Uploading the files of {working_dir} to {repo_id}.") return upload_folder( - repo_id=repo_id, folder_path=working_dir, token=token, commit_message=commit_message, create_pr=create_pr + repo_id=repo_id, + folder_path=working_dir, + token=token, + commit_message=commit_message, + create_pr=create_pr, + path_in_repo=subfolder, ) def push_to_hub( @@ -493,6 +499,7 @@ def push_to_hub( create_pr: bool = False, safe_serialization: bool = True, variant: Optional[str] = None, + subfolder: Optional[str] = None, ) -> str: """ Upload model, scheduler, or pipeline files to the 🤗 Hugging Face Hub. 
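A minimal usage sketch for the new `subfolder` argument threaded through `_upload_folder` and `push_to_hub` above (the object and repo id are hypothetical; any class that inherits this `push_to_hub` behaves the same way):

```py
# Hypothetical example: `text_encoder` stands in for any saveable diffusers object
# that exposes this push_to_hub. With subfolder set, the saved files are uploaded
# under that path in the repo (via path_in_repo) instead of the repo root.
text_encoder.push_to_hub(
    "your-username/modular-demo",   # hypothetical repo id
    subfolder="text_encoder",       # upload into text_encoder/ inside the repo
    commit_message="add text encoder",
)
```

As the following hunk shows, the model card is only created and saved when no `subfolder` is given, so pushes into a sub-path leave the repo's top-level README untouched.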
@@ -534,8 +541,9 @@ def push_to_hub( repo_id = create_repo(repo_id, private=private, token=token, exist_ok=True).repo_id # Create a new empty model card and eventually tag it - model_card = load_or_create_model_card(repo_id, token=token) - model_card = populate_model_card(model_card) + if not subfolder: + model_card = load_or_create_model_card(repo_id, token=token) + model_card = populate_model_card(model_card) # Save all files. save_kwargs = {"safe_serialization": safe_serialization} @@ -546,7 +554,8 @@ def push_to_hub( self.save_pretrained(tmpdir, **save_kwargs) # Update model card if needed: - model_card.save(os.path.join(tmpdir, "README.md")) + if not subfolder: + model_card.save(os.path.join(tmpdir, "README.md")) return self._upload_folder( tmpdir, @@ -554,4 +563,5 @@ def push_to_hub( token=token, commit_message=commit_message, create_pr=create_pr, + subfolder=subfolder, )
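To close, a small sketch of the trust-remote-code flow introduced in `dynamic_modules_utils.py` above (the repo id is hypothetical):

```py
# Sketch only. resolve_trust_remote_code is the helper added in this diff.
from diffusers.utils.dynamic_modules_utils import resolve_trust_remote_code

trust = resolve_trust_remote_code(
    trust_remote_code=None,                       # caller did not decide
    model_name="some-user/custom-modular-block",  # hypothetical repo id
    has_remote_code=True,                         # repo ships a custom code file
)
# On an interactive terminal this prompts "Do you wish to run the custom code? [y/N]"
# and raises a ValueError if no answer arrives within DIFFUSERS_TIMEOUT_REMOTE_CODE
# seconds (default 15). Passing trust_remote_code=True or False explicitly skips the
# prompt; untrusted remote code always raises rather than loading silently.
```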