Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/diffusers/modular_pipelines/flux/before_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,13 @@ def inputs(self) -> List[InputParam]:
InputParam(
"prompt_embeds",
required=True,
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Pre-generated text embeddings. Can be generated from text_encoder step.",
),
InputParam(
"pooled_prompt_embeds",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Pre-generated pooled text embeddings. Can be generated from text_encoder step.",
),
Expand All @@ -279,11 +281,13 @@ def intermediate_outputs(self) -> List[str]:
OutputParam(
"prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="denoiser_input_fields",
description="text embeddings used to guide the image generation",
),
OutputParam(
"pooled_prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="denoiser_input_fields",
description="pooled text embeddings used to guide the image generation",
),
# TODO: support negative embeddings?
Expand Down
5 changes: 5 additions & 0 deletions src/diffusers/modular_pipelines/flux/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("prompt_2"),
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Took the liberty of supporting this input from the user so that Schnell can also work. In another PR, I will harmonize the steps in Flux Modular that include repetition along the batch size dimension (similar to Qwen).

Cc: @yiyixuxu

InputParam("joint_attention_kwargs"),
]

Expand All @@ -189,16 +190,19 @@ def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"prompt_embeds",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="text embeddings used to guide the image generation",
),
OutputParam(
"pooled_prompt_embeds",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="pooled text embeddings used to guide the image generation",
),
OutputParam(
"text_ids",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="ids from the text sequence for RoPE",
),
Expand Down Expand Up @@ -404,6 +408,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
pooled_prompt_embeds=None,
device=block_state.device,
num_images_per_prompt=1, # TODO: hardcoded for now.
max_sequence_length=block_state.max_sequence_length,
lora_scale=block_state.text_encoder_lora_scale,
)

Expand Down
31 changes: 23 additions & 8 deletions src/diffusers/modular_pipelines/flux/modular_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ def description(self):

# before_denoise: all task (text2img, img2img)
class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
block_classes = [FluxBeforeDenoiseStep, FluxImg2ImgBeforeDenoiseStep]
block_names = ["text2image", "img2img"]
block_trigger_inputs = [None, "image_latents"]
block_classes = [FluxImg2ImgBeforeDenoiseStep, FluxBeforeDenoiseStep]
block_names = ["img2img", "text2image"]
block_trigger_inputs = ["image_latents", None]

@property
def description(self):
Expand Down Expand Up @@ -124,16 +124,32 @@ def description(self):
return "Decode step that decode the denoised latents into image outputs.\n - `FluxDecodeStep`"


class FluxCoreDenoiseStep(SequentialPipelineBlocks):
block_classes = [FluxInputStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep]
block_names = ["input", "before_denoise", "denoise"]

@property
def description(self):
return (
"Core step that performs the denoising process. \n"
+ " - `FluxInputStep` (input) standardizes the inputs for the denoising step.\n"
+ " - `FluxAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+ " - `FluxAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
+ "This step support text-to-image and image-to-image tasks for Flux:\n"
+ " - for image-to-image generation, you need to provide `image_latents`\n"
+ " - for text-to-image generation, all you need to provide is prompt embeddings"
)


# text2image
class FluxAutoBlocks(SequentialPipelineBlocks):
block_classes = [
FluxTextEncoderStep,
FluxAutoVaeEncoderStep,
FluxAutoBeforeDenoiseStep,
FluxAutoDenoiseStep,
FluxCoreDenoiseStep,
FluxAutoDecodeStep,
]
block_names = ["text_encoder", "image_encoder", "before_denoise", "denoise", "decoder"]
block_names = ["text_encoder", "image_encoder", "denoise", "decode"]

@property
def description(self):
Expand Down Expand Up @@ -171,8 +187,7 @@ def description(self):
[
("text_encoder", FluxTextEncoderStep),
("image_encoder", FluxAutoVaeEncoderStep),
("before_denoise", FluxAutoBeforeDenoiseStep),
("denoise", FluxAutoDenoiseStep),
("denoise", FluxCoreDenoiseStep),
("decode", FluxAutoDecodeStep),
]
)
Expand Down
Loading