Skip to content

Commit 7f3e9b8

Browse files
sayakpaulasomoza
andauthored
make flux ready for mellon (huggingface#12419)
* make flux ready for mellon * up * Apply suggestions from code review Co-authored-by: Álvaro Somoza <[email protected]> --------- Co-authored-by: Álvaro Somoza <[email protected]>
1 parent ce90f9b commit 7f3e9b8

File tree

3 files changed

+32
-8
lines changed

3 files changed

+32
-8
lines changed

src/diffusers/modular_pipelines/flux/before_denoise.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,11 +252,13 @@ def inputs(self) -> List[InputParam]:
252252
InputParam(
253253
"prompt_embeds",
254254
required=True,
255+
kwargs_type="denoiser_input_fields",
255256
type_hint=torch.Tensor,
256257
description="Pre-generated text embeddings. Can be generated from text_encoder step.",
257258
),
258259
InputParam(
259260
"pooled_prompt_embeds",
261+
kwargs_type="denoiser_input_fields",
260262
type_hint=torch.Tensor,
261263
description="Pre-generated pooled text embeddings. Can be generated from text_encoder step.",
262264
),
@@ -279,11 +281,13 @@ def intermediate_outputs(self) -> List[str]:
279281
OutputParam(
280282
"prompt_embeds",
281283
type_hint=torch.Tensor,
284+
kwargs_type="denoiser_input_fields",
282285
description="text embeddings used to guide the image generation",
283286
),
284287
OutputParam(
285288
"pooled_prompt_embeds",
286289
type_hint=torch.Tensor,
290+
kwargs_type="denoiser_input_fields",
287291
description="pooled text embeddings used to guide the image generation",
288292
),
289293
# TODO: support negative embeddings?

src/diffusers/modular_pipelines/flux/encoders.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ def inputs(self) -> List[InputParam]:
181181
return [
182182
InputParam("prompt"),
183183
InputParam("prompt_2"),
184+
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
184185
InputParam("joint_attention_kwargs"),
185186
]
186187

@@ -189,16 +190,19 @@ def intermediate_outputs(self) -> List[OutputParam]:
189190
return [
190191
OutputParam(
191192
"prompt_embeds",
193+
kwargs_type="denoiser_input_fields",
192194
type_hint=torch.Tensor,
193195
description="text embeddings used to guide the image generation",
194196
),
195197
OutputParam(
196198
"pooled_prompt_embeds",
199+
kwargs_type="denoiser_input_fields",
197200
type_hint=torch.Tensor,
198201
description="pooled text embeddings used to guide the image generation",
199202
),
200203
OutputParam(
201204
"text_ids",
205+
kwargs_type="denoiser_input_fields",
202206
type_hint=torch.Tensor,
203207
description="ids from the text sequence for RoPE",
204208
),
@@ -404,6 +408,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
404408
pooled_prompt_embeds=None,
405409
device=block_state.device,
406410
num_images_per_prompt=1, # TODO: hardcoded for now.
411+
max_sequence_length=block_state.max_sequence_length,
407412
lora_scale=block_state.text_encoder_lora_scale,
408413
)
409414

src/diffusers/modular_pipelines/flux/modular_blocks.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@ def description(self):
8484

8585
# before_denoise: all task (text2img, img2img)
8686
class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
87-
block_classes = [FluxBeforeDenoiseStep, FluxImg2ImgBeforeDenoiseStep]
88-
block_names = ["text2image", "img2img"]
89-
block_trigger_inputs = [None, "image_latents"]
87+
block_classes = [FluxImg2ImgBeforeDenoiseStep, FluxBeforeDenoiseStep]
88+
block_names = ["img2img", "text2image"]
89+
block_trigger_inputs = ["image_latents", None]
9090

9191
@property
9292
def description(self):
@@ -124,16 +124,32 @@ def description(self):
124124
return "Decode step that decode the denoised latents into image outputs.\n - `FluxDecodeStep`"
125125

126126

127+
class FluxCoreDenoiseStep(SequentialPipelineBlocks):
128+
block_classes = [FluxInputStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep]
129+
block_names = ["input", "before_denoise", "denoise"]
130+
131+
@property
132+
def description(self):
133+
return (
134+
"Core step that performs the denoising process. \n"
135+
+ " - `FluxInputStep` (input) standardizes the inputs for the denoising step.\n"
136+
+ " - `FluxAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
137+
+ " - `FluxAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
138+
+ "This step support text-to-image and image-to-image tasks for Flux:\n"
139+
+ " - for image-to-image generation, you need to provide `image_latents`\n"
140+
+ " - for text-to-image generation, all you need to provide is prompt embeddings"
141+
)
142+
143+
127144
# text2image
128145
class FluxAutoBlocks(SequentialPipelineBlocks):
129146
block_classes = [
130147
FluxTextEncoderStep,
131148
FluxAutoVaeEncoderStep,
132-
FluxAutoBeforeDenoiseStep,
133-
FluxAutoDenoiseStep,
149+
FluxCoreDenoiseStep,
134150
FluxAutoDecodeStep,
135151
]
136-
block_names = ["text_encoder", "image_encoder", "before_denoise", "denoise", "decoder"]
152+
block_names = ["text_encoder", "image_encoder", "denoise", "decode"]
137153

138154
@property
139155
def description(self):
@@ -171,8 +187,7 @@ def description(self):
171187
[
172188
("text_encoder", FluxTextEncoderStep),
173189
("image_encoder", FluxAutoVaeEncoderStep),
174-
("before_denoise", FluxAutoBeforeDenoiseStep),
175-
("denoise", FluxAutoDenoiseStep),
190+
("denoise", FluxCoreDenoiseStep),
176191
("decode", FluxAutoDecodeStep),
177192
]
178193
)

0 commit comments

Comments
 (0)