
Commit 52bd9d7 ("up")
Parent: fba8f3b

3 files changed: +67 −31 lines

src/diffusers/modular_pipelines/flux/before_denoise.py (+63 −26)
@@ -104,6 +104,7 @@ def calculate_shift(
     return mu
 
 
+# Adapted from the original implementation.
 def prepare_latents_img2img(
     vae, scheduler, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator
 ):
@@ -196,8 +197,19 @@ def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator):
     return image_latents
 
 
-def _get_timesteps_and_optionals(transformer, scheduler, latents, num_inference_steps, guidance_scale, sigmas, device):
-    image_seq_len = latents.shape[1]
+def _get_initial_timesteps_and_optionals(
+    transformer,
+    scheduler,
+    batch_size,
+    height,
+    width,
+    vae_scale_factor,
+    num_inference_steps,
+    guidance_scale,
+    sigmas,
+    device,
+):
+    image_seq_len = (int(height) // vae_scale_factor // 2) * (int(width) // vae_scale_factor // 2)
 
     sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
     if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas:
@@ -212,7 +224,7 @@ def _get_timesteps_and_optionals(transformer, scheduler, latents, num_inference_steps, guidance_scale, sigmas, device):
     timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)
     if transformer.config.guidance_embeds:
         guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
-        guidance = guidance.expand(latents.shape[0])
+        guidance = guidance.expand(batch_size)
     else:
         guidance = None
 
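
The substance of this refactor: `image_seq_len` was previously read off the already-prepared latents (`latents.shape[1]`), which forced the timestep block to run after latent preparation. It is now derived directly from `height`, `width`, and the VAE scale factor, so the timestep block needs only scalars from the input step. A minimal sketch of why the two computations agree for Flux's packed latents (the concrete numbers are illustrative assumptions, not taken from this commit):

# Flux packs the VAE latent grid into 2x2 patches, so one sequence token
# covers (vae_scale_factor * 2) pixels per side. Illustrative values only:
height, width = 1024, 1024   # requested pixel resolution
vae_scale_factor = 8         # typical for the Flux VAE

tokens_h = int(height) // vae_scale_factor // 2   # 64
tokens_w = int(width) // vae_scale_factor // 2    # 64
image_seq_len = tokens_h * tokens_w               # 4096 == latents.shape[1]
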
@@ -328,18 +340,20 @@ def inputs(self) -> List[InputParam]:
             InputParam("timesteps"),
             InputParam("sigmas"),
             InputParam("guidance_scale", default=3.5),
-            InputParam("latents", type_hint=torch.Tensor),
+            InputParam("num_images_per_prompt", default=1),
+            InputParam("height", type_hint=int),
+            InputParam("width", type_hint=int),
         ]
 
     @property
     def intermediate_inputs(self) -> List[str]:
         return [
             InputParam(
-                "latents",
+                "batch_size",
                 required=True,
-                type_hint=torch.Tensor,
-                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
-            )
+                type_hint=int,
+                description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.",
+            ),
         ]
 
     @property
@@ -362,10 +376,14 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
         scheduler = components.scheduler
         transformer = components.transformer
 
-        timesteps, num_inference_steps, sigmas, guidance = _get_timesteps_and_optionals(
+        batch_size = block_state.batch_size * block_state.num_images_per_prompt
+        timesteps, num_inference_steps, sigmas, guidance = _get_initial_timesteps_and_optionals(
             transformer,
             scheduler,
-            block_state.latents,
+            batch_size,
+            block_state.height,
+            block_state.width,
+            components.vae_scale_factor,
             block_state.num_inference_steps,
             block_state.guidance_scale,
             block_state.sigmas,
@@ -397,20 +415,22 @@ def inputs(self) -> List[InputParam]:
             InputParam("num_inference_steps", default=50),
             InputParam("timesteps"),
             InputParam("sigmas"),
+            InputParam("strength", default=0.6),
             InputParam("guidance_scale", default=3.5),
-            InputParam("latents", type_hint=torch.Tensor),
             InputParam("num_images_per_prompt", default=1),
+            InputParam("height", type_hint=int),
+            InputParam("width", type_hint=int),
         ]
 
     @property
     def intermediate_inputs(self) -> List[str]:
         return [
             InputParam(
-                "latents",
+                "batch_size",
                 required=True,
-                type_hint=torch.Tensor,
-                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
-            )
+                type_hint=int,
+                description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.",
+            ),
         ]
 
     @property
@@ -430,30 +450,48 @@ def intermediate_outputs(self) -> List[OutputParam]:
             OutputParam("guidance", type_hint=torch.Tensor, description="Optional guidance to be used."),
         ]
 
+    @staticmethod
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps with self->scheduler
+    def get_timesteps(scheduler, num_inference_steps, strength, device):
+        # get the original timestep using init_timestep
+        init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+        t_start = int(max(num_inference_steps - init_timestep, 0))
+        timesteps = scheduler.timesteps[t_start * scheduler.order :]
+        if hasattr(scheduler, "set_begin_index"):
+            scheduler.set_begin_index(t_start * scheduler.order)
+
+        return timesteps, num_inference_steps - t_start
+
     @torch.no_grad()
     def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
         block_state.device = components._execution_device
 
         scheduler = components.scheduler
         transformer = components.transformer
-
-        timesteps, num_inference_steps, sigmas, guidance = _get_timesteps_and_optionals(
+        batch_size = block_state.batch_size * block_state.num_images_per_prompt
+        timesteps, num_inference_steps, sigmas, guidance = _get_initial_timesteps_and_optionals(
            transformer,
             scheduler,
-            block_state.latents,
+            batch_size,
+            block_state.height,
+            block_state.width,
+            components.vae_scale_factor,
             block_state.num_inference_steps,
             block_state.guidance_scale,
             block_state.sigmas,
             block_state.device,
         )
+        timesteps, num_inference_steps = self.get_timesteps(
+            scheduler, num_inference_steps, block_state.strength, block_state.device
+        )
         block_state.timesteps = timesteps
         block_state.num_inference_steps = num_inference_steps
         block_state.sigmas = sigmas
         block_state.guidance = guidance
 
-        batch_size = block_state.latents.shape[0]
-        block_state.latent_timestep = timesteps[:1].repeat(batch_size * block_state.num_images_per_prompt)
+        block_state.latent_timestep = timesteps[:1].repeat(batch_size)
 
         self.set_block_state(state, block_state)
         return components, state
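
The img2img timestep block now applies the standard strength-based truncation (copied from the SD3 img2img pipeline) after building the full schedule. A worked example using the block's own defaults, assuming a first-order scheduler (`scheduler.order == 1`):

# Standalone sketch of get_timesteps' arithmetic; no scheduler object needed.
num_inference_steps = 50   # InputParam default above
strength = 0.6             # InputParam default above

init_timestep = min(num_inference_steps * strength, num_inference_steps)  # 30.0
t_start = int(max(num_inference_steps - init_timestep, 0))                # 20

# With scheduler.order == 1 the schedule keeps timesteps[20:]: the input image
# is noised to the 30-steps-remaining level and only those 30 steps are run.
print(t_start, num_inference_steps - t_start)  # 20 30
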
@@ -468,7 +506,7 @@ def expected_components(self) -> List[ComponentSpec]:
 
     @property
     def description(self) -> str:
-        return "Prepare latents step that prepares the latents for the text-to-video generation process"
+        return "Prepare latents step that prepares the latents for the text-to-image generation process"
 
     @property
     def inputs(self) -> List[InputParam]:
@@ -565,10 +603,10 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
         block_state.num_channels_latents = components.num_channels_latents
 
         self.check_inputs(components, block_state)
-
+        batch_size = block_state.batch_size * block_state.num_images_per_prompt
         block_state.latents, block_state.latent_image_ids = self.prepare_latents(
             components,
-            block_state.batch_size * block_state.num_images_per_prompt,
+            batch_size,
             block_state.num_channels_latents,
             block_state.height,
             block_state.width,
@@ -601,7 +639,6 @@ def inputs(self) -> List[Tuple[str, Any]]:
             InputParam("width", type_hint=int),
             InputParam("latents", type_hint=Optional[torch.Tensor]),
             InputParam("num_images_per_prompt", type_hint=int, default=1),
-            InputParam("latents"),
         ]
 
     @property
@@ -655,14 +692,14 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
         block_state.device = components._execution_device
 
         # TODO: implement `check_inputs`
-
+        batch_size = block_state.batch_size * block_state.num_images_per_prompt
         if block_state.latents is None:
             block_state.latents, block_state.latent_image_ids = prepare_latents_img2img(
                 components.vae,
                 components.scheduler,
                 block_state.image_latents,
                 block_state.latent_timestep,
-                block_state.batch_size * block_state.num_images_per_prompt,
+                batch_size,
                 block_state.num_channels_latents,
                 block_state.height,
                 block_state.width,
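
All of the `__call__` changes in this file standardize one convention: `block_state.batch_size` counts prompts, and the model batch is that count times `num_images_per_prompt`. A toy illustration of the resulting sizes (numbers are assumptions for the example, not defaults):

prompt_count = 2            # block_state.batch_size, produced by the input step
num_images_per_prompt = 4   # default is 1; 4 here only for illustration

batch_size = prompt_count * num_images_per_prompt   # 8 samples per forward pass
# guidance.expand(batch_size) and timesteps[:1].repeat(batch_size) therefore
# both yield tensors with 8 entries, one per generated image.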

src/diffusers/modular_pipelines/flux/encoders.py (+2 −1)
@@ -112,6 +112,7 @@ def intermediate_outputs(self) -> List[OutputParam]:
             )
         ]
 
+    @staticmethod
     # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image with self.vae->vae
     def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator):
         if isinstance(generator, list):
@@ -148,7 +149,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
         )
 
         block_state.image_latents = self._encode_vae_image(
-            components, image=block_state.image, generator=block_state.generator
+            components.vae, image=block_state.image, generator=block_state.generator
         )
 
         self.set_block_state(state, block_state)
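
The encoders.py change is a bug fix: `_encode_vae_image` takes `vae` as its first parameter, but the call site passed `components`, and without `@staticmethod` Python also injected `self`. A toy sketch of why the decorator matters (class and names are hypothetical):

class Toy:
    @staticmethod
    def encode(vae, image):
        return (vae, image)

t = Toy()
print(t.encode("vae", "img"))  # ('vae', 'img') -- arguments land where written
# Without @staticmethod, t.encode("vae", "img") would resolve to
# encode(t, "vae", "img") and raise:
# TypeError: encode() takes 2 positional arguments but 3 were given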

src/diffusers/modular_pipelines/flux/modular_blocks.py (+2 −4)
@@ -148,10 +148,8 @@ def description(self):
     [
         ("text_encoder", FluxTextEncoderStep),
         ("input", FluxInputStep),
-        ("prepare_latents", FluxPrepareLatentsStep),
-        # Setting it after preparation of latents because we rely on `latents`
-        # to calculate `img_seq_len` for `shift`.
         ("set_timesteps", FluxSetTimestepsStep),
+        ("prepare_latents", FluxPrepareLatentsStep),
         ("denoise", FluxDenoiseStep),
         ("decode", FluxDecodeStep),
     ]
@@ -162,8 +160,8 @@ def description(self):
         ("text_encoder", FluxTextEncoderStep),
         ("image_encoder", FluxVaeEncoderStep),
         ("input", FluxInputStep),
-        ("prepare_latents", FluxImg2ImgPrepareLatentsStep),
         ("set_timesteps", FluxImg2ImgSetTimestepsStep),
+        ("prepare_latents", FluxImg2ImgPrepareLatentsStep),
         ("denoise", FluxDenoiseStep),
         ("decode", FluxDecodeStep),
     ]
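
Reordering is safe now that `set_timesteps` no longer reads `image_seq_len` from the prepared latents, so the old justification comment is deleted along with the ordering it defended. For img2img the new order is also required: `prepare_latents` consumes the `latent_timestep` that `set_timesteps` produces. A sketch of the resulting data flow (outputs paraphrased from this diff, not an exhaustive list):

# Key intermediates each step hands to later steps in the img2img preset:
img2img_flow = {
    "input": ["batch_size"],                             # used by set_timesteps
    "set_timesteps": ["timesteps", "latent_timestep"],   # used by prepare_latents
    "prepare_latents": ["latents", "latent_image_ids"],  # used by denoise
}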
