Skip to content

Commit 0484e77

Browse files
committed
up
1 parent 0636e9d commit 0484e77

File tree

5 files changed

+79
-24
lines changed

5 files changed

+79
-24
lines changed

src/diffusers/modular_pipelines/flux/before_denoise.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,24 @@ def description(self) -> str:
233233

234234
@property
235235
def inputs(self) -> List[InputParam]:
236-
return [InputParam("num_inference_steps", default=50), InputParam("timesteps"), InputParam("sigmas")]
236+
return [
237+
InputParam("num_inference_steps", default=50),
238+
InputParam("timesteps"),
239+
InputParam("sigmas"),
240+
InputParam("guidance_scale", default=3.5),
241+
InputParam("latents", type_hint=torch.Tensor)
242+
]
243+
244+
@property
245+
def intermediate_inputs(self) -> List[str]:
246+
return [
247+
InputParam(
248+
"latents",
249+
required=True,
250+
type_hint=torch.Tensor,
251+
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
252+
)
253+
]
237254

238255
@property
239256
def intermediate_outputs(self) -> List[OutputParam]:
@@ -244,6 +261,7 @@ def intermediate_outputs(self) -> List[OutputParam]:
244261
type_hint=int,
245262
description="The number of denoising steps to perform at inference time",
246263
),
264+
OutputParam("guidance", type_hint=torch.Tensor, description="Optional guidance to be used.")
247265
]
248266

249267
@torch.no_grad()
@@ -271,6 +289,12 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
271289
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
272290
scheduler, block_state.num_inference_steps, block_state.device, sigmas=block_state.sigmas, mu=mu
273291
)
292+
if components.transformer.config.guidance_embeds:
293+
guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32)
294+
guidance = guidance.expand(latents.shape[0])
295+
else:
296+
guidance = None
297+
block_state.guidance = guidance
274298

275299
self.set_block_state(state, block_state)
276300
return components, state
@@ -314,8 +338,12 @@ def intermediate_outputs(self) -> List[OutputParam]:
314338
return [
315339
OutputParam(
316340
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
341+
),
342+
OutputParam(
343+
"latent_image_ids", type_hint=torch.Tensor, description="IDs computed from the image sequence needed for RoPE"
317344
)
318345
]
346+
319347

320348
@staticmethod
321349
def check_inputs(components, block_state):
@@ -378,7 +406,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
378406

379407
self.check_inputs(components, block_state)
380408

381-
block_state.latents = self.prepare_latents(
409+
block_state.latents, block_state.latent_image_ids = self.prepare_latents(
382410
components,
383411
block_state.batch_size * block_state.num_images_per_prompt,
384412
block_state.num_channels_latents,
@@ -389,7 +417,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
389417
block_state.generator,
390418
block_state.latents,
391419
)
392-
420+
393421
self.set_block_state(state, block_state)
394422

395423
return components, state

src/diffusers/modular_pipelines/flux/denoise.py

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from ...models import FluxTransformer2DModel
2020
from ...schedulers import FlowMatchEulerDiscreteScheduler
2121
from ...utils import logging
22+
from ...configuration_utils import FrozenDict
23+
from ...guiders import ClassifierFreeGuidance
2224
from ..modular_pipeline import (
2325
BlockState,
2426
LoopSequentialPipelineBlocks,
@@ -37,7 +39,9 @@ class FluxLoopDenoiser(PipelineBlock):
3739

3840
@property
3941
def expected_components(self) -> List[ComponentSpec]:
40-
return [ComponentSpec("transformer", FluxTransformer2DModel)]
42+
return [
43+
ComponentSpec("transformer", FluxTransformer2DModel)
44+
]
4145

4246
@property
4347
def description(self) -> str:
@@ -49,9 +53,7 @@ def description(self) -> str:
4953

5054
@property
5155
def inputs(self) -> List[Tuple[str, Any]]:
52-
return [
53-
InputParam("attention_kwargs"),
54-
]
56+
return [InputParam("joint_attention_kwargs")]
5557

5658
@property
5759
def intermediate_inputs(self) -> List[str]:
@@ -63,10 +65,34 @@ def intermediate_inputs(self) -> List[str]:
6365
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
6466
),
6567
InputParam(
66-
"num_inference_steps",
68+
"guidance",
6769
required=True,
68-
type_hint=int,
69-
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
70+
type_hint=torch.Tensor,
71+
description="Guidance scale as a tensor",
72+
),
73+
InputParam(
74+
"prompt_embeds",
75+
required=True,
76+
type_hint=torch.Tensor,
77+
description="Prompt embeddings",
78+
),
79+
InputParam(
80+
"pooled_prompt_embeds",
81+
required=True,
82+
type_hint=torch.Tensor,
83+
description="Pooled prompt embeddings",
84+
),
85+
InputParam(
86+
"text_ids",
87+
required=True,
88+
type_hint=torch.Tensor,
89+
description="IDs computed from text sequence needed for RoPE",
90+
),
91+
InputParam(
92+
"latent_image_ids",
93+
required=True,
94+
type_hint=torch.Tensor,
95+
description="IDs computed from image sequence needed for RoPE",
7096
),
7197
# TODO: guidance
7298
]
@@ -78,9 +104,10 @@ def __call__(
78104
noise_pred = components.transformer(
79105
hidden_states=block_state.latents,
80106
timestep=t.flatten() / 1000,
107+
guidance=block_state.guidance,
81108
encoder_hidden_states=block_state.prompt_embeds,
82109
pooled_projections=block_state.pooled_prompt_embeds,
83-
attention_kwargs=block_state.attention_kwargs,
110+
joint_attention_kwargs=block_state.joint_attention_kwargs,
84111
txt_ids=block_state.text_ids,
85112
img_ids=block_state.latent_image_ids,
86113
return_dict=False,
@@ -96,7 +123,7 @@ class FluxLoopAfterDenoiser(PipelineBlock):
96123
@property
97124
def expected_components(self) -> List[ComponentSpec]:
98125
return [
99-
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
126+
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)
100127
]
101128

102129
@property
@@ -113,9 +140,7 @@ def inputs(self) -> List[Tuple[str, Any]]:
113140

114141
@property
115142
def intermediate_inputs(self) -> List[str]:
116-
return [
117-
InputParam("generator"),
118-
]
143+
return [InputParam("generator")]
119144

120145
@property
121146
def intermediate_outputs(self) -> List[OutputParam]:
@@ -129,7 +154,6 @@ def __call__(self, components: FluxModularPipeline, block_state: BlockState, i:
129154
block_state.noise_pred,
130155
t,
131156
block_state.latents,
132-
**block_state.scheduler_step_kwargs,
133157
return_dict=False,
134158
)[0]
135159

@@ -199,9 +223,9 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
199223
class FluxDenoiseStep(FluxDenoiseLoopWrapper):
200224
block_classes = [
201225
FluxLoopDenoiser,
202-
FluxLoopAfterDenoiser,
226+
FluxLoopAfterDenoiser
203227
]
204-
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
228+
block_names = ["denoiser", "after_denoiser"]
205229

206230
@property
207231
def description(self) -> str:

src/diffusers/modular_pipelines/flux/encoders.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
294294
else None
295295
)
296296
(block_state.prompt_embeds, block_state.pooled_prompt_embeds, block_state.text_ids) = self.encode_prompt(
297+
components,
297298
prompt=block_state.prompt,
298299
prompt_2=None,
299300
prompt_embeds=None,

src/diffusers/modular_pipelines/flux/modular_blocks.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,26 +28,26 @@
2828
class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
2929
block_classes = [
3030
FluxInputStep,
31-
FluxSetTimestepsStep,
3231
FluxPrepareLatentsStep,
32+
FluxSetTimestepsStep,
3333
]
34-
block_names = ["input", "set_timesteps", "prepare_latents"]
34+
block_names = ["input", "prepare_latents", "set_timesteps"]
3535

3636
@property
3737
def description(self):
3838
return (
3939
"Before denoise step that prepare the inputs for the denoise step.\n"
4040
+ "This is a sequential pipeline blocks:\n"
4141
+ " - `FluxInputStep` is used to adjust the batch size of the model inputs\n"
42-
+ " - `FluxSetTimestepsStep` is used to set the timesteps\n"
4342
+ " - `FluxPrepareLatentsStep` is used to prepare the latents\n"
43+
+ " - `FluxSetTimestepsStep` is used to set the timesteps\n"
4444
)
4545

4646

4747
# before_denoise: all task (text2vid,)
4848
class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
4949
block_classes = [
50-
FluxBeforeDenoiseStep,
50+
FluxBeforeDenoiseStep
5151
]
5252
block_names = ["text2image"]
5353
block_trigger_inputs = [None]
@@ -114,8 +114,10 @@ def description(self):
114114
[
115115
("text_encoder", FluxTextEncoderStep),
116116
("input", FluxInputStep),
117-
("set_timesteps", FluxSetTimestepsStep),
118117
("prepare_latents", FluxPrepareLatentsStep),
118+
# Setting it after preparation of latents because we rely on `latents`
119+
# to calculate `img_seq_len` for `shift`.
120+
("set_timesteps", FluxSetTimestepsStep),
119121
("denoise", FluxDenoiseStep),
120122
("decode", FluxDecodeStep),
121123
]

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1665,7 +1665,7 @@ def get_block_state(self, state: PipelineState) -> dict:
16651665
if input_param.name:
16661666
value = state.get_intermediate(input_param.name)
16671667
if input_param.required and value is None:
1668-
raise ValueError(f"Required intermediate input '{input_param.name}' is missing")
1668+
raise ValueError(f"Required intermediate input '{input_param.name}' is missing.")
16691669
elif value is not None or (value is None and input_param.name not in data):
16701670
data[input_param.name] = value
16711671
elif input_param.kwargs_type:

0 commit comments

Comments
 (0)