diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 80c78b8a96d5..1414d0fc690a 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -364,6 +364,8 @@ else: _import_structure["modular_pipelines"].extend( [ + "FluxAutoBlocks", + "FluxModularPipeline", "StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline", "WanAutoBlocks", @@ -999,6 +1001,8 @@ from .utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .modular_pipelines import ( + FluxAutoBlocks, + FluxModularPipeline, StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline, WanAutoBlocks, diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py index 5fa047257f09..9b558ddb2123 100644 --- a/src/diffusers/hooks/_helpers.py +++ b/src/diffusers/hooks/_helpers.py @@ -107,6 +107,7 @@ def _register(cls): def _register_attention_processors_metadata(): from ..models.attention_processor import AttnProcessor2_0 from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor + from ..models.transformers.transformer_flux import FluxAttnProcessor from ..models.transformers.transformer_wan import WanAttnProcessor2_0 # AttnProcessor2_0 @@ -132,6 +133,11 @@ def _register_attention_processors_metadata(): skip_processor_output_fn=_skip_proc_output_fn_Attention_WanAttnProcessor2_0, ), ) + # FluxAttnProcessor + AttentionProcessorRegistry.register( + model_class=FluxAttnProcessor, + metadata=AttentionProcessorMetadata(skip_processor_output_fn=_skip_proc_output_fn_Attention_FluxAttnProcessor), + ) def _register_transformer_blocks_metadata(): @@ -271,4 +277,6 @@ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, * _skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states _skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states _skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states +# Skipping FluxAttnProcessor amounts to returning `hidden_states` unchanged.
+_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states # fmt: on diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index b3025bf4d3ab..e0f2e31388ce 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -41,6 +41,7 @@ ] _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"] _import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"] + _import_structure["flux"] = ["FluxAutoBlocks", "FluxModularPipeline"] _import_structure["components_manager"] = ["ComponentsManager"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -51,6 +52,7 @@ from ..utils.dummy_pt_objects import * # noqa F403 else: from .components_manager import ComponentsManager + from .flux import FluxAutoBlocks, FluxModularPipeline from .modular_pipeline import ( AutoPipelineBlocks, BlockState, diff --git a/src/diffusers/modular_pipelines/flux/__init__.py b/src/diffusers/modular_pipelines/flux/__init__.py new file mode 100644 index 000000000000..2891edf79041 --- /dev/null +++ b/src/diffusers/modular_pipelines/flux/__init__.py @@ -0,0 +1,66 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["encoders"] = ["FluxTextEncoderStep"] + _import_structure["modular_blocks"] = [ + "ALL_BLOCKS", + "AUTO_BLOCKS", + "TEXT2IMAGE_BLOCKS", + "FluxAutoBeforeDenoiseStep", + "FluxAutoBlocks", + "FluxAutoDecodeStep", + "FluxAutoDenoiseStep", + ] + _import_structure["modular_pipeline"] = ["FluxModularPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .encoders import FluxTextEncoderStep + from .modular_blocks import ( + ALL_BLOCKS, + AUTO_BLOCKS, + TEXT2IMAGE_BLOCKS, + FluxAutoBeforeDenoiseStep, + FluxAutoBlocks, + FluxAutoDecodeStep, + FluxAutoDenoiseStep, + ) + from .modular_pipeline import FluxModularPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/flux/before_denoise.py b/src/diffusers/modular_pipelines/flux/before_denoise.py new file mode 100644 index 000000000000..ffc77bb24fdb --- /dev/null +++ b/src/diffusers/modular_pipelines/flux/before_denoise.py @@ -0,0 +1,420 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import List, Optional, Union + +import numpy as np +import torch + +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ...utils.torch_utils import randn_tensor +from ..modular_pipeline import PipelineBlock, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import FluxModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + +def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + +class FluxInputStep(PipelineBlock): + model_name = "flux" + + @property + def description(self) -> str: + return ( + "Input processing step that:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n" + "All input tensors are expected to have either batch_size=1 or match the batch_size\n" + "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" + "have a final batch_size of batch_size * num_images_per_prompt." + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediate_inputs(self) -> List[str]: + return [ + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Pre-generated text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "pooled_prompt_embeds", + type_hint=torch.Tensor, + description="Pre-generated pooled text embeddings. Can be generated from text_encoder step.", + ), + # TODO: support negative embeddings? + ] + + @property + def intermediate_outputs(self) -> List[str]: + return [ + OutputParam( + "batch_size", + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", + ), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="Data type of model tensor inputs (determined by `prompt_embeds`)", + ), + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + description="text embeddings used to guide the image generation", + ), + OutputParam( + "pooled_prompt_embeds", + type_hint=torch.Tensor, + description="pooled text embeddings used to guide the image generation", + ), + # TODO: support negative embeddings? 
+ ] + + def check_inputs(self, components, block_state): + if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is not None: + if block_state.prompt_embeds.shape[0] != block_state.pooled_prompt_embeds.shape[0]: + raise ValueError( + "`prompt_embeds` and `pooled_prompt_embeds` must have the same batch size when passed directly, but" + f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `pooled_prompt_embeds`" + f" {block_state.pooled_prompt_embeds.shape}." + ) + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + # TODO: consider adding negative embeddings? + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + if block_state.pooled_prompt_embeds is not None: + # `pooled_prompt_embeds` is 2D (batch_size, dim), so expand it across + # `num_images_per_prompt` the same way as `prompt_embeds`. + block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat( + 1, block_state.num_images_per_prompt + ) + block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, -1 + ) + self.set_block_state(state, block_state) + + return components, state + + +class FluxSetTimestepsStep(PipelineBlock): + model_name = "flux" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for inference" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("guidance_scale", default=3.5), + InputParam("latents", type_hint=torch.Tensor), + ] + + @property + def intermediate_inputs(self) -> List[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process.
Can be generated in prepare_latent step.", + ) + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam( + "num_inference_steps", + type_hint=int, + description="The number of denoising steps to perform at inference time", + ), + OutputParam("guidance", type_hint=torch.Tensor, description="Optional guidance to be used."), + ] + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + scheduler = components.scheduler + + latents = block_state.latents + image_seq_len = latents.shape[1] + + num_inference_steps = block_state.num_inference_steps + sigmas = block_state.sigmas + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas: + sigmas = None + block_state.sigmas = sigmas + mu = calculate_shift( + image_seq_len, + scheduler.config.get("base_image_seq_len", 256), + scheduler.config.get("max_image_seq_len", 4096), + scheduler.config.get("base_shift", 0.5), + scheduler.config.get("max_shift", 1.15), + ) + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + scheduler, block_state.num_inference_steps, block_state.device, sigmas=block_state.sigmas, mu=mu + ) + if components.transformer.config.guidance_embeds: + guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + block_state.guidance = guidance + + self.set_block_state(state, block_state) + return components, state + + +class FluxPrepareLatentsStep(PipelineBlock): + model_name = "flux" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [] + + @property + def description(self) -> str: + return "Prepare latents step that prepares the latents for the text-to-video generation process" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam("latents", type_hint=Optional[torch.Tensor]), + InputParam("num_images_per_prompt", type_hint=int, default=1), + ] + + @property + def intermediate_inputs(self) -> List[InputParam]: + return [ + InputParam("generator"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. 
Can be generated in input step.", + ), + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" + ), + OutputParam( + "latent_image_ids", + type_hint=torch.Tensor, + description="IDs computed from the image sequence needed for RoPE", + ), + ] + + @staticmethod + def check_inputs(components, block_state): + if (block_state.height is not None and block_state.height % (components.vae_scale_factor * 2) != 0) or ( + block_state.width is not None and block_state.width % (components.vae_scale_factor * 2) != 0 + ): + logger.warning( + f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {block_state.height} and {block_state.width}." + ) + + @staticmethod + def prepare_latents( + comp, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # Couldn't use the `prepare_latents` method directly from Flux because I decided to copy over + # the packing methods here. So, for example, `comp._pack_latents()` won't work if we were + # to go with the "# Copied from ..." approach. Or maybe there's a way? + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (comp.vae_scale_factor * 2)) + width = 2 * (int(width) // (comp.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = _pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + block_state.device = components._execution_device + block_state.dtype = torch.bfloat16 # TODO: okay to hardcode this? 
+ block_state.num_channels_latents = components.num_channels_latents + + self.check_inputs(components, block_state) + + block_state.latents, block_state.latent_image_ids = self.prepare_latents( + components, + block_state.batch_size * block_state.num_images_per_prompt, + block_state.num_channels_latents, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.latents, + ) + + self.set_block_state(state, block_state) + + return components, state diff --git a/src/diffusers/modular_pipelines/flux/decoders.py b/src/diffusers/modular_pipelines/flux/decoders.py new file mode 100644 index 000000000000..8d561d38c6f2 --- /dev/null +++ b/src/diffusers/modular_pipelines/flux/decoders.py @@ -0,0 +1,114 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List, Tuple, Union + +import numpy as np +import PIL +import torch + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL +from ...utils import logging +from ..modular_pipeline import PipelineBlock, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + +class FluxDecodeStep(PipelineBlock): + model_name = "flux" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "Step that decodes the denoised latents into images" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("output_type", default="pil"), + InputParam("height", default=1024), + InputParam("width", default=1024), + ] + + @property + def intermediate_inputs(self) -> List[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents from the denoising step", + ) + ] + + @property + def intermediate_outputs(self) -> List[str]: + return [ + OutputParam( + "images", + type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray], + description="The generated images; can be a list of PIL.Image.Image, a torch.Tensor, or a numpy array", + ) + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + vae = components.vae + + if not block_state.output_type == "latent": + latents = block_state.latents + latents = _unpack_latents(latents, block_state.height, block_state.width, components.vae_scale_factor) + latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor + block_state.images = vae.decode(latents, return_dict=False)[0] + block_state.images = components.image_processor.postprocess( + block_state.images, output_type=block_state.output_type + ) + else: + block_state.images = block_state.latents + + self.set_block_state(state, block_state) + + return components, state diff --git a/src/diffusers/modular_pipelines/flux/denoise.py b/src/diffusers/modular_pipelines/flux/denoise.py new file mode 100644 index 000000000000..c4619c17fb0e --- /dev/null +++ b/src/diffusers/modular_pipelines/flux/denoise.py @@ -0,0 +1,230 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +from typing import Any, List, Tuple + +import torch + +from ...models import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ..modular_pipeline import ( + BlockState, + LoopSequentialPipelineBlocks, + PipelineBlock, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import FluxModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class FluxLoopDenoiser(PipelineBlock): + model_name = "flux" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ComponentSpec("transformer", FluxTransformer2DModel)] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoises the latents. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `FluxDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [InputParam("joint_attention_kwargs")] + + @property + def intermediate_inputs(self) -> List[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "guidance", + required=True, + type_hint=torch.Tensor, + description="Guidance scale as a tensor", + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Prompt embeddings", + ), + InputParam( + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Pooled prompt embeddings", + ), + InputParam( + "text_ids", + required=True, + type_hint=torch.Tensor, + description="IDs computed from text sequence needed for RoPE", + ), + InputParam( + "latent_image_ids", + required=True, + type_hint=torch.Tensor, + description="IDs computed from image sequence needed for RoPE", + ), + # TODO: guidance + ] + + @torch.no_grad() + def __call__( + self, components: FluxModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + noise_pred = components.transformer( + hidden_states=block_state.latents, + timestep=t.flatten() / 1000, + guidance=block_state.guidance, + encoder_hidden_states=block_state.prompt_embeds, + pooled_projections=block_state.pooled_prompt_embeds, + joint_attention_kwargs=block_state.joint_attention_kwargs, + txt_ids=block_state.text_ids, + img_ids=block_state.latent_image_ids, + return_dict=False, + )[0] + block_state.noise_pred = noise_pred + + return components, block_state + + +class FluxLoopAfterDenoiser(PipelineBlock): + model_name = "flux" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that updates the latents. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g.
`FluxDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [] + + @property + def intermediate_inputs(self) -> List[str]: + return [InputParam("generator")] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + # Perform scheduler step using the predicted output + latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred, + t, + block_state.latents, + return_dict=False, + )[0] + + if block_state.latents.dtype != latents_dtype: + block_state.latents = block_state.latents.to(latents_dtype) + + return components, block_state + + +class FluxDenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "flux" + + @property + def description(self) -> str: + return ( + "Pipeline block that iteratively denoise the latents over `timesteps`. " + "The specific steps with each iteration can be customized with `sub_blocks` attributes" + ) + + @property + def loop_expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ComponentSpec("transformer", FluxTransformer2DModel), + ] + + @property + def loop_intermediate_inputs(self) -> List[InputParam]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.num_warmup_steps = max( + len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 + ) + # We set the index here to remove DtoH sync, helpful especially during compilation. + # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + components.scheduler.set_begin_index(0) + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + if i == len(block_state.timesteps) - 1 or ( + (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0 + ): + progress_bar.update() + + self.set_block_state(state, block_state) + + return components, state + + +class FluxDenoiseStep(FluxDenoiseLoopWrapper): + block_classes = [FluxLoopDenoiser, FluxLoopAfterDenoiser] + block_names = ["denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `FluxDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" + " - `FluxLoopDenoiser`\n" + " - `FluxLoopAfterDenoiser`\n" + "This block supports text2image tasks." 
+ ) diff --git a/src/diffusers/modular_pipelines/flux/encoders.py b/src/diffusers/modular_pipelines/flux/encoders.py new file mode 100644 index 000000000000..9bf2f54eece3 --- /dev/null +++ b/src/diffusers/modular_pipelines/flux/encoders.py @@ -0,0 +1,306 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +from typing import List, Optional, Union + +import regex as re +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin +from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers +from ..modular_pipeline import PipelineBlock, PipelineState +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from .modular_pipeline import FluxModularPipeline + + +if is_ftfy_available(): + import ftfy + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +class FluxTextEncoderStep(PipelineBlock): + model_name = "flux" + + @property + def description(self) -> str: + return "Text Encoder step that generates text embeddings to guide the image generation" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("text_encoder", CLIPTextModel), + ComponentSpec("tokenizer", CLIPTokenizer), + ComponentSpec("text_encoder_2", T5EncoderModel), + ComponentSpec("tokenizer_2", T5TokenizerFast), + ] + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("prompt"), + InputParam("prompt_2"), + InputParam("joint_attention_kwargs"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + description="text embeddings used to guide the image generation", + ), + OutputParam( + "pooled_prompt_embeds", + type_hint=torch.Tensor, + description="pooled text embeddings used to guide the image generation", + ), + OutputParam( + "text_ids", + type_hint=torch.Tensor, + description="ids from the text sequence for RoPE", + ), + ] + + @staticmethod + def check_inputs(block_state): + for prompt in [block_state.prompt, block_state.prompt_2]: + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` or `prompt_2` has to be of type `str` or `list` but is {type(prompt)}") + + @staticmethod + def _get_t5_prompt_embeds( + components, + prompt: Union[str, List[str]], + num_images_per_prompt: int, + max_sequence_length: int, + device: torch.device, + ): +
dtype = components.text_encoder_2.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(components, TextualInversionLoaderMixin): + prompt = components.maybe_convert_prompt(prompt, components.tokenizer_2) + + text_inputs = components.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + untruncated_ids = components.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = components.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = components.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + @staticmethod + def _get_clip_prompt_embeds( + components, + prompt: Union[str, List[str]], + num_images_per_prompt: int, + device: torch.device, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(components, TextualInversionLoaderMixin): + prompt = components.maybe_convert_prompt(prompt, components.tokenizer) + + text_inputs = components.tokenizer( + prompt, + padding="max_length", + max_length=components.tokenizer.model_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + tokenizer_max_length = components.tokenizer.model_max_length + untruncated_ids = components.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = components.tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = components.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=components.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + @staticmethod + def encode_prompt( + components, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 
512, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or components._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(components, FluxLoraLoaderMixin): + components._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if components.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(components.text_encoder, lora_scale) + if components.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(components.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = FluxTextEncoderStep._get_clip_prompt_embeds( + components, + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = FluxTextEncoderStep._get_t5_prompt_embeds( + components, + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if components.text_encoder is not None: + if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(components.text_encoder, lora_scale) + + if components.text_encoder_2 is not None: + if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(components.text_encoder_2, lora_scale) + + dtype = components.text_encoder.dtype if components.text_encoder is not None else torch.bfloat16 + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + # Get inputs and intermediates + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + block_state.device = components._execution_device + + # Encode input prompt + block_state.text_encoder_lora_scale = ( + block_state.joint_attention_kwargs.get("scale", None) + if block_state.joint_attention_kwargs is not None + else None + ) + (block_state.prompt_embeds, 
block_state.pooled_prompt_embeds, block_state.text_ids) = self.encode_prompt( + components, + prompt=block_state.prompt, + prompt_2=block_state.prompt_2, + prompt_embeds=None, + pooled_prompt_embeds=None, + device=block_state.device, + num_images_per_prompt=1, # hardcoded for now. + lora_scale=block_state.text_encoder_lora_scale, + ) + + # Add outputs + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/flux/modular_blocks.py b/src/diffusers/modular_pipelines/flux/modular_blocks.py new file mode 100644 index 000000000000..b17067303785 --- /dev/null +++ b/src/diffusers/modular_pipelines/flux/modular_blocks.py @@ -0,0 +1,125 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict +from .before_denoise import FluxInputStep, FluxPrepareLatentsStep, FluxSetTimestepsStep +from .decoders import FluxDecodeStep +from .denoise import FluxDenoiseStep +from .encoders import FluxTextEncoderStep + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# before_denoise: text2image +class FluxBeforeDenoiseStep(SequentialPipelineBlocks): + block_classes = [ + FluxInputStep, + FluxPrepareLatentsStep, + FluxSetTimestepsStep, + ] + block_names = ["input", "prepare_latents", "set_timesteps"] + + @property + def description(self): + return ( + "Before denoise step that prepares the inputs for the denoise step.\n" + + "This is a sequential pipeline block:\n" + + " - `FluxInputStep` is used to adjust the batch size of the model inputs\n" + + " - `FluxPrepareLatentsStep` is used to prepare the latents\n" + + " - `FluxSetTimestepsStep` is used to set the timesteps\n" + ) + + +# before_denoise: all tasks (text2image,) +class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks): + block_classes = [FluxBeforeDenoiseStep] + block_names = ["text2image"] + block_trigger_inputs = [None] + + @property + def description(self): + return ( + "Before denoise step that prepares the inputs for the denoise step.\n" + + "This is an auto pipeline block that works for text2image.\n" + + " - `FluxBeforeDenoiseStep` (text2image) is used.\n" + ) + + +# denoise: text2image
class FluxAutoDenoiseStep(AutoPipelineBlocks): + block_classes = [FluxDenoiseStep] + block_names = ["denoise"] + block_trigger_inputs = [None] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoises the latents. " + "This is an auto pipeline block that works for text2image tasks.\n" + " - `FluxDenoiseStep` (denoise) for text2image tasks."
+ ) + + +# decode: all tasks (text2img, img2img, inpainting) +class FluxAutoDecodeStep(AutoPipelineBlocks): + block_classes = [FluxDecodeStep] + block_names = ["non-inpaint"] + block_trigger_inputs = [None] + + @property + def description(self): + return "Decode step that decodes the denoised latents into image outputs.\n - `FluxDecodeStep`" + + +# text2image
class FluxAutoBlocks(SequentialPipelineBlocks): + block_classes = [FluxTextEncoderStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep, FluxAutoDecodeStep] + block_names = ["text_encoder", "before_denoise", "denoise", "decoder"] + + @property + def description(self): + return ( + "Auto Modular pipeline for text-to-image using Flux.\n" + + "- for text-to-image generation, all you need to provide is `prompt`" + ) + + +TEXT2IMAGE_BLOCKS = InsertableDict( + [ + ("text_encoder", FluxTextEncoderStep), + ("input", FluxInputStep), + ("prepare_latents", FluxPrepareLatentsStep), + # Setting it after preparation of latents because we rely on `latents` + # to calculate `image_seq_len` for `shift`. + ("set_timesteps", FluxSetTimestepsStep), + ("denoise", FluxDenoiseStep), + ("decode", FluxDecodeStep), + ] +) + + +AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", FluxTextEncoderStep), + ("before_denoise", FluxAutoBeforeDenoiseStep), + ("denoise", FluxAutoDenoiseStep), + ("decode", FluxAutoDecodeStep), + ] +) + + +ALL_BLOCKS = {"text2image": TEXT2IMAGE_BLOCKS, "auto": AUTO_BLOCKS} diff --git a/src/diffusers/modular_pipelines/flux/modular_pipeline.py b/src/diffusers/modular_pipelines/flux/modular_pipeline.py new file mode 100644 index 000000000000..3cd5df0c70ee --- /dev/null +++ b/src/diffusers/modular_pipelines/flux/modular_pipeline.py @@ -0,0 +1,59 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...loaders import FluxLoraLoaderMixin +from ...utils import logging +from ..modular_pipeline import ModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class FluxModularPipeline(ModularPipeline, FluxLoraLoaderMixin): + """ + A ModularPipeline for Flux. + + <Tip warning={true}> + + This is an experimental feature and is likely to change in the future.
+ + </Tip> + """ + + @property + def default_height(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_width(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_sample_size(self): + return 128 + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if getattr(self, "vae", None) is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor + + @property + def num_channels_latents(self): + num_channels_latents = 16 + if getattr(self, "transformer", None): + num_channels_latents = self.transformer.config.in_channels // 4 + return num_channels_latents diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 8838a1cb5942..0ef1d59f4d65 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -61,6 +61,7 @@ [ ("stable-diffusion-xl", "StableDiffusionXLModularPipeline"), ("wan", "WanModularPipeline"), + ("flux", "FluxModularPipeline"), ] ) @@ -68,6 +69,7 @@ [ ("StableDiffusionXLModularPipeline", "StableDiffusionXLAutoBlocks"), ("WanModularPipeline", "WanAutoBlocks"), + ("FluxModularPipeline", "FluxAutoBlocks"), ] ) @@ -1663,7 +1665,7 @@ def get_block_state(self, state: PipelineState) -> dict: if input_param.name: value = state.get_intermediate(input_param.name) if input_param.required and value is None: - raise ValueError(f"Required intermediate input '{input_param.name}' is missing") + raise ValueError(f"Required intermediate input '{input_param.name}' is missing.") elif value is not None or (value is None and input_param.name not in data): data[input_param.name] = value elif input_param.kwargs_type: diff --git a/src/diffusers/pipelines/flux/pipeline_output.py b/src/diffusers/pipelines/flux/pipeline_output.py index 388824e89f87..69e742d3e026 100644 --- a/src/diffusers/pipelines/flux/pipeline_output.py +++ b/src/diffusers/pipelines/flux/pipeline_output.py @@ -11,12 +11,14 @@ @dataclass class FluxPipelineOutput(BaseOutput): """ - Output class for Stable Diffusion pipelines. + Output class for Flux image generation pipelines. Args: - images (`List[PIL.Image.Image]` or `np.ndarray`) - List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, - num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + images (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size, + height, width, num_channels)`. PIL images or numpy array represent the denoised images of the diffusion + pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be + passed to the decoder.
""" images: Union[List[PIL.Image.Image], np.ndarray] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 7538635c808e..20382eafea84 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2,6 +2,36 @@ from ..utils import DummyObject, requires_backends +class FluxAutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class FluxModularPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionXLAutoBlocks(metaclass=DummyObject): _backends = ["torch", "transformers"]