Skip to content

Commit 7c7e8a4

Browse files
committed
fix
1 parent da1096e commit 7c7e8a4

File tree

4 files changed

+136
-77
lines changed

4 files changed

+136
-77
lines changed

src/diffusers/modular_pipelines/flux/before_denoise.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,7 @@ def prepare_latents(
398398
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
399399
)
400400

401+
# TODO: move packing latents code to a patchifier
401402
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
402403
latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)
403404

@@ -436,12 +437,13 @@ class FluxImg2ImgPrepareLatentsStep(ModularPipelineBlocks):
436437

437438
@property
438439
def description(self) -> str:
439-
return "Step that adds noise to image latents for image-to-image/inpainting. Should be run after set_timesteps, prepare_latents. Both noise and image latents should alreadybe patchified."
440+
return "Step that adds noise to image latents for image-to-image. Should be run after `set_timesteps`,"
441+
" `prepare_latents`. Both noise and image latents should already be patchified."
440442

441443
@property
442444
def expected_components(self) -> List[ComponentSpec]:
443445
return [
444-
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
446+
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)
445447
]
446448

447449
@property
@@ -521,9 +523,9 @@ def description(self) -> str:
521523
@property
522524
def inputs(self) -> List[InputParam]:
523525
return [
524-
InputParam(name="image_height", required=True),
525-
InputParam(name="image_width", required=True),
526-
InputParam(name="prompt_embeds"),
526+
InputParam(name="height", required=True),
527+
InputParam(name="width", required=True),
528+
InputParam(name="prompt_embeds")
527529
]
528530

529531
@property
@@ -552,8 +554,8 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
552554
device=prompt_embeds.device, dtype=prompt_embeds.dtype
553555
)
554556

555-
height = 2 * (int(block_state.image_height) // (components.vae_scale_factor * 2))
556-
width = 2 * (int(block_state.image_width) // (components.vae_scale_factor * 2))
557+
height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
558+
width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
557559
block_state.img_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype)
558560

559561
self.set_block_state(state, block_state)

src/diffusers/modular_pipelines/flux/denoise.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,18 +76,17 @@ def inputs(self) -> List[Tuple[str, Any]]:
7676
description="Pooled prompt embeddings",
7777
),
7878
InputParam(
79-
"text_ids",
79+
"txt_ids",
8080
required=True,
8181
type_hint=torch.Tensor,
8282
description="IDs computed from text sequence needed for RoPE",
8383
),
8484
InputParam(
85-
"latent_image_ids",
85+
"img_ids",
8686
required=True,
8787
type_hint=torch.Tensor,
8888
description="IDs computed from image sequence needed for RoPE",
8989
),
90-
# TODO: guidance
9190
]
9291

9392
@torch.no_grad()
@@ -101,8 +100,8 @@ def __call__(
101100
encoder_hidden_states=block_state.prompt_embeds,
102101
pooled_projections=block_state.pooled_prompt_embeds,
103102
joint_attention_kwargs=block_state.joint_attention_kwargs,
104-
txt_ids=block_state.text_ids,
105-
img_ids=block_state.latent_image_ids,
103+
txt_ids=block_state.txt_ids,
104+
img_ids=block_state.img_ids,
106105
return_dict=False,
107106
)[0]
108107
block_state.noise_pred = noise_pred

src/diffusers/modular_pipelines/flux/encoders.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -204,15 +204,13 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
204204
dtype = components.vae.dtype
205205

206206
image = getattr(block_state, self._image_input_name)
207+
image = image.to(device=device, dtype=dtype)
207208

208209
# Encode image into latents
209210
image_latents = encode_vae_image(
210211
image=image,
211212
vae=components.vae,
212-
generator=block_state.generator,
213-
device=device,
214-
dtype=dtype,
215-
latent_channels=components.num_channels_latents,
213+
generator=block_state.generator
216214
)
217215
setattr(block_state, self._image_latents_output_name, image_latents)
218216

@@ -412,7 +410,6 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
412410
prompt_embeds=None,
413411
pooled_prompt_embeds=None,
414412
device=block_state.device,
415-
num_images_per_prompt=1, # TODO: hardcoded for now.
416413
max_sequence_length=block_state.max_sequence_length,
417414
lora_scale=block_state.text_encoder_lora_scale,
418415
)

src/diffusers/modular_pipelines/flux/modular_blocks.py

Lines changed: 121 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,43 @@
1818
from .before_denoise import (
1919
FluxImg2ImgPrepareLatentsStep,
2020
FluxImg2ImgSetTimestepsStep,
21-
FluxInputStep,
2221
FluxPrepareLatentsStep,
2322
FluxSetTimestepsStep,
2423
)
2524
from .decoders import FluxDecodeStep
2625
from .denoise import FluxDenoiseStep
27-
from .encoders import FluxTextEncoderStep, FluxVaeEncoderStep
26+
from .encoders import FluxTextEncoderStep, FluxVaeEncoderDynamicStep
27+
from .before_denoise import FluxRoPEInputsStep
28+
from .inputs import FluxTextInputStep, FluxInputsDynamicStep
29+
2830

2931

3032
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
3133

3234

3335
# vae encoder (run before before_denoise)
36+
from .encoders import FluxProcessImagesInputStep
37+
38+
FluxImg2ImgVaeEncoderBlocks = InsertableDict(
39+
[
40+
("preprocess", FluxProcessImagesInputStep()),
41+
("encode", FluxVaeEncoderDynamicStep()),
42+
]
43+
)
44+
45+
class FluxImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
46+
model_name = "flux"
47+
48+
block_classes = FluxImg2ImgVaeEncoderBlocks.values()
49+
block_names = FluxImg2ImgVaeEncoderBlocks.keys()
50+
51+
@property
52+
def description(self) -> str:
53+
return "Vae encoder step that preprocess andencode the image inputs into their latent representations."
54+
55+
3456
class FluxAutoVaeEncoderStep(AutoPipelineBlocks):
35-
block_classes = [FluxVaeEncoderStep]
57+
block_classes = [FluxImg2ImgVaeEncoderStep]
3658
block_names = ["img2img"]
3759
block_trigger_inputs = ["image"]
3860

@@ -41,44 +63,49 @@ def description(self):
4163
return (
4264
"Vae encoder step that encode the image inputs into their latent representations.\n"
4365
+ "This is an auto pipeline block that works for img2img tasks.\n"
44-
+ " - `FluxVaeEncoderStep` (img2img) is used when only `image` is provided."
45-
+ " - if `image` is provided, step will be skipped."
66+
+ " - `FluxImg2ImgVaeEncoderStep` (img2img) is used when only `image` is provided."
67+
+ " - if `image` is not provided, step will be skipped."
4668
)
4769

4870

49-
# before_denoise: text2img, img2img
50-
class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
51-
block_classes = [
52-
FluxInputStep,
53-
FluxPrepareLatentsStep,
54-
FluxSetTimestepsStep,
71+
72+
# before_denoise: text2img
73+
FluxBeforeDenoiseBlocks = InsertableDict(
74+
[
75+
("prepare_latents", FluxPrepareLatentsStep()),
76+
("set_timesteps", FluxSetTimestepsStep()),
77+
("prepare_rope_inputs", FluxRoPEInputsStep())
5578
]
56-
block_names = ["input", "prepare_latents", "set_timesteps"]
79+
)
80+
81+
class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
82+
block_classes = FluxBeforeDenoiseBlocks.values()
83+
block_names = FluxBeforeDenoiseBlocks.keys()
5784

5885
@property
5986
def description(self):
6087
return (
61-
"Before denoise step that prepare the inputs for the denoise step.\n"
62-
+ "This is a sequential pipeline blocks:\n"
63-
+ " - `FluxInputStep` is used to adjust the batch size of the model inputs\n"
64-
+ " - `FluxPrepareLatentsStep` is used to prepare the latents\n"
65-
+ " - `FluxSetTimestepsStep` is used to set the timesteps\n"
88+
"Before denoise step that prepares the inputs for the denoise step in text-to-image generation."
6689
)
6790

6891

6992
# before_denoise: img2img
93+
FluxImg2ImgBeforeDenoiseBlocks = InsertableDict(
94+
[
95+
("prepare_latents", FluxPrepareLatentsStep()),
96+
("set_timesteps", FluxImg2ImgSetTimestepsStep()),
97+
("prepare_img2img_latents", FluxImg2ImgPrepareLatentsStep()),
98+
("prepare_rope_inputs", FluxRoPEInputsStep())
99+
]
100+
)
70101
class FluxImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
71-
block_classes = [FluxInputStep, FluxImg2ImgSetTimestepsStep, FluxImg2ImgPrepareLatentsStep]
72-
block_names = ["input", "set_timesteps", "prepare_latents"]
102+
block_classes = FluxImg2ImgBeforeDenoiseBlocks.values()
103+
block_names = FluxImg2ImgBeforeDenoiseBlocks.keys()
73104

74105
@property
75106
def description(self):
76107
return (
77-
"Before denoise step that prepare the inputs for the denoise step for img2img task.\n"
78-
+ "This is a sequential pipeline blocks:\n"
79-
+ " - `FluxInputStep` is used to adjust the batch size of the model inputs\n"
80-
+ " - `FluxImg2ImgSetTimestepsStep` is used to set the timesteps\n"
81-
+ " - `FluxImg2ImgPrepareLatentsStep` is used to prepare the latents\n"
108+
"Before denoise step that prepare the inputs for the denoise step for img2img task."
82109
)
83110

84111

@@ -113,7 +140,7 @@ def description(self) -> str:
113140
)
114141

115142

116-
# decode: all task (text2img, img2img, inpainting)
143+
# decode: all tasks (text2img, img2img)
117144
class FluxAutoDecodeStep(AutoPipelineBlocks):
118145
block_classes = [FluxDecodeStep]
119146
block_names = ["non-inpaint"]
@@ -124,32 +151,73 @@ def description(self):
124151
return "Decode step that decode the denoised latents into image outputs.\n - `FluxDecodeStep`"
125152

126153

154+
# inputs: text2image/img2img
155+
FluxImg2ImgBlocks = InsertableDict(
156+
[
157+
("text_inputs", FluxTextInputStep()),
158+
("additional_inputs", FluxInputsDynamicStep())
159+
]
160+
)
161+
162+
class FluxImg2ImgInputStep(SequentialPipelineBlocks):
163+
model_name = "flux"
164+
block_classes = FluxImg2ImgBlocks.values()
165+
block_names = FluxImg2ImgBlocks.keys()
166+
167+
@property
168+
def description(self):
169+
return "Input step that prepares the inputs for the img2img denoising step. It:\n"
170+
" - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n"
171+
" - update height/width based `image_latents`, patchify `image_latents`."
172+
173+
174+
class FluxImageAutoInputStep(AutoPipelineBlocks):
175+
block_classes = [FluxImg2ImgInputStep, FluxTextInputStep]
176+
block_names = ["img2img", "text2image"]
177+
block_trigger_inputs = [ "image_latents", None]
178+
179+
@property
180+
def description(self):
181+
return (
182+
"Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, and patchified. \n"
183+
" This is an auto pipeline block that works for text2image/img2img tasks.\n"
184+
+ " - `FluxImg2ImgInputStep` (img2img) is used when `image_latents` is provided.\n"
185+
+ " - `FluxTextInputStep` (text2image) is used when `image_latents` are not provided.\n"
186+
)
187+
188+
127189
class FluxCoreDenoiseStep(SequentialPipelineBlocks):
128-
block_classes = [FluxInputStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep]
190+
model_name = "flux"
191+
block_classes = [FluxImageAutoInputStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep]
129192
block_names = ["input", "before_denoise", "denoise"]
130193

131194
@property
132195
def description(self):
133196
return (
134197
"Core step that performs the denoising process. \n"
135-
+ " - `FluxInputStep` (input) standardizes the inputs for the denoising step.\n"
198+
+ " - `FluxImageAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
136199
+ " - `FluxAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
137200
+ " - `FluxAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
138-
+ "This step support text-to-image and image-to-image tasks for Flux:\n"
201+
+ "This step supports text-to-image and image-to-image tasks for Flux:\n"
139202
+ " - for image-to-image generation, you need to provide `image_latents`\n"
140-
+ " - for text-to-image generation, all you need to provide is prompt embeddings"
203+
+ " - for text-to-image generation, all you need to provide is prompt embeddings."
141204
)
142205

143206

144-
# text2image
145-
class FluxAutoBlocks(SequentialPipelineBlocks):
146-
block_classes = [
147-
FluxTextEncoderStep,
148-
FluxAutoVaeEncoderStep,
149-
FluxCoreDenoiseStep,
150-
FluxAutoDecodeStep,
207+
# Auto blocks (text2image and img2img)
208+
AUTO_BLOCKS = InsertableDict(
209+
[
210+
("text_encoder", FluxTextEncoderStep()),
211+
("image_encoder", FluxAutoVaeEncoderStep()),
212+
("denoise", FluxCoreDenoiseStep()),
213+
("decode", FluxDecodeStep())
151214
]
152-
block_names = ["text_encoder", "image_encoder", "denoise", "decode"]
215+
)
216+
class FluxAutoBlocks(SequentialPipelineBlocks):
217+
model_name = "flux"
218+
219+
block_classes = AUTO_BLOCKS.values()
220+
block_names = AUTO_BLOCKS.keys()
153221

154222
@property
155223
def description(self):
@@ -162,35 +230,28 @@ def description(self):
162230

163231
TEXT2IMAGE_BLOCKS = InsertableDict(
164232
[
165-
("text_encoder", FluxTextEncoderStep),
166-
("input", FluxInputStep),
167-
("prepare_latents", FluxPrepareLatentsStep),
168-
("set_timesteps", FluxSetTimestepsStep),
169-
("denoise", FluxDenoiseStep),
170-
("decode", FluxDecodeStep),
233+
("text_encoder", FluxTextEncoderStep()),
234+
("input", FluxTextInputStep()),
235+
("prepare_latents", FluxPrepareLatentsStep()),
236+
("set_timesteps", FluxSetTimestepsStep()),
237+
("prepare_rope_inputs", FluxRoPEInputsStep()),
238+
("denoise", FluxDenoiseStep()),
239+
("decode", FluxDecodeStep()),
171240
]
172241
)
173242

174243
IMAGE2IMAGE_BLOCKS = InsertableDict(
175244
[
176-
("text_encoder", FluxTextEncoderStep),
177-
("image_encoder", FluxVaeEncoderStep),
178-
("input", FluxInputStep),
179-
("set_timesteps", FluxImg2ImgSetTimestepsStep),
180-
("prepare_latents", FluxImg2ImgPrepareLatentsStep),
181-
("denoise", FluxDenoiseStep),
182-
("decode", FluxDecodeStep),
245+
("text_encoder", FluxTextEncoderStep()),
246+
("vae_encoder", FluxVaeEncoderDynamicStep()),
247+
("input", FluxImg2ImgInputStep()),
248+
("prepare_latents", FluxPrepareLatentsStep()),
249+
("set_timesteps", FluxImg2ImgSetTimestepsStep()),
250+
("prepare_img2img_latents", FluxImg2ImgPrepareLatentsStep()),
251+
("prepare_rope_inputs", FluxRoPEInputsStep()),
252+
("denoise", FluxDenoiseStep()),
253+
("decode", FluxDecodeStep()),
183254
]
184255
)
185256

186-
AUTO_BLOCKS = InsertableDict(
187-
[
188-
("text_encoder", FluxTextEncoderStep),
189-
("image_encoder", FluxAutoVaeEncoderStep),
190-
("denoise", FluxCoreDenoiseStep),
191-
("decode", FluxAutoDecodeStep),
192-
]
193-
)
194-
195-
196257
ALL_BLOCKS = {"text2image": TEXT2IMAGE_BLOCKS, "img2img": IMAGE2IMAGE_BLOCKS, "auto": AUTO_BLOCKS}

0 commit comments

Comments
 (0)