
Commit 046bf9e

start

committed
1 parent ba2ba90 commit 046bf9e

File tree

1 file changed: +279 −23 lines

src/diffusers/modular_pipelines/flux/before_denoise.py

Lines changed: 279 additions & 23 deletions
@@ -13,11 +13,12 @@
 # limitations under the License.
 
 import inspect
-from typing import List, Optional, Union
+from typing import Any, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 
+from ...models import AutoencoderKL
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import logging
 from ...utils.torch_utils import randn_tensor
@@ -103,6 +104,61 @@ def calculate_shift(
     return mu
 
 
+def prepare_latents_img2img(
+    vae, scheduler, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator
+):
+    if isinstance(generator, list) and len(generator) != batch_size:
+        raise ValueError(
+            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+        )
+
+    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
+    latent_channels = vae.config.latent_channels
+
+    # VAE applies 8x compression on images but we must also account for packing which requires
+    # latent height and width to be divisible by 2.
+    height = 2 * (int(height) // (vae_scale_factor * 2))
+    width = 2 * (int(width) // (vae_scale_factor * 2))
+    shape = (batch_size, num_channels_latents, height, width)
+    latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+    image = image.to(device=device, dtype=dtype)
+    if image.shape[1] != latent_channels:
+        image_latents = _encode_vae_image(vae=vae, image=image, generator=generator)
+    else:
+        image_latents = image
+    if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+        # expand init_latents for batch_size
+        additional_image_per_prompt = batch_size // image_latents.shape[0]
+        image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+    elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+        raise ValueError(
+            f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+        )
+    else:
+        image_latents = torch.cat([image_latents], dim=0)
+
+    noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+    latents = scheduler.scale_noise(image_latents, timestep, noise)
+    latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)
+    return latents, latent_image_ids
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 def _pack_latents(latents, batch_size, num_channels_latents, height, width):
     latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
     latents = latents.permute(0, 2, 4, 1, 3, 5)
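
Note on the sizing logic in prepare_latents_img2img above: the height/width rounding keeps the latent grid divisible by 2 so that _pack_latents can fold 2x2 latent patches into tokens. A minimal standalone sketch (values assumed: 1024x1024 input, vae_scale_factor 8, 16 latent channels, the Flux defaults):

import torch

vae_scale_factor = 8
height_px, width_px = 1024, 1024

# Round latent dims down to a multiple of 2, as in prepare_latents_img2img.
latent_h = 2 * (int(height_px) // (vae_scale_factor * 2))  # 128
latent_w = 2 * (int(width_px) // (vae_scale_factor * 2))   # 128

# _pack_latents folds each 2x2 latent patch into one token of 4 * channels features.
channels = 16
latents = torch.randn(1, channels, latent_h, latent_w)
packed = latents.view(1, channels, latent_h // 2, 2, latent_w // 2, 2)
packed = packed.permute(0, 2, 4, 1, 3, 5).reshape(1, (latent_h // 2) * (latent_w // 2), channels * 4)
print(packed.shape)  # torch.Size([1, 4096, 64]); 4096 is the packed image_seq_len used for the shift computation
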
@@ -125,6 +181,44 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
     return latent_image_ids.to(device=device, dtype=dtype)
 
 
+# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image with self.vae->vae
+def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator):
+    if isinstance(generator, list):
+        image_latents = [
+            retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0])
+        ]
+        image_latents = torch.cat(image_latents, dim=0)
+    else:
+        image_latents = retrieve_latents(vae.encode(image), generator=generator)
+
+    image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
+
+    return image_latents
+
+
+def _get_timesteps_and_optionals(transformer, scheduler, latents, num_inference_steps, guidance_scale, sigmas, device):
+    image_seq_len = latents.shape[1]
+
+    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+    if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas:
+        sigmas = None
+    mu = calculate_shift(
+        image_seq_len,
+        scheduler.config.get("base_image_seq_len", 256),
+        scheduler.config.get("max_image_seq_len", 4096),
+        scheduler.config.get("base_shift", 0.5),
+        scheduler.config.get("max_shift", 1.15),
+    )
+    timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)
+    if transformer.config.guidance_embeds:
+        guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+        guidance = guidance.expand(latents.shape[0])
+    else:
+        guidance = None
+
+    return timesteps, num_inference_steps, sigmas, guidance
+
+
 class FluxInputStep(PipelineBlock):
     model_name = "flux"
 
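
_get_timesteps_and_optionals above feeds the packed sequence length into calculate_shift (defined earlier in this file), which linearly interpolates the flow-matching shift mu between base_shift and max_shift. A rough standalone sketch of that interpolation, using the default values the scheduler-config lookups above fall back to:

def sketch_calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15):
    # Linear interpolation: 256 packed tokens -> mu 0.5, 4096 packed tokens -> mu 1.15.
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept

print(sketch_calculate_shift(4096))  # ~1.15 for a 1024x1024 image (4096 packed tokens)
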
@@ -264,34 +358,103 @@ def intermediate_outputs(self) -> List[OutputParam]:
     def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
         block_state.device = components._execution_device
-        scheduler = components.scheduler
 
-        latents = block_state.latents
-        image_seq_len = latents.shape[1]
+        scheduler = components.scheduler
+        transformer = components.transformer
 
-        num_inference_steps = block_state.num_inference_steps
-        sigmas = block_state.sigmas
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
-        if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas:
-            sigmas = None
-        block_state.sigmas = sigmas
-        mu = calculate_shift(
-            image_seq_len,
-            scheduler.config.get("base_image_seq_len", 256),
-            scheduler.config.get("max_image_seq_len", 4096),
-            scheduler.config.get("base_shift", 0.5),
-            scheduler.config.get("max_shift", 1.15),
+        timesteps, num_inference_steps, sigmas, guidance = _get_timesteps_and_optionals(
+            transformer,
+            scheduler,
+            block_state.latents,
+            block_state.num_inference_steps,
+            block_state.guidance_scale,
+            block_state.sigmas,
+            block_state.device,
         )
-        block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
-            scheduler, block_state.num_inference_steps, block_state.device, sigmas=block_state.sigmas, mu=mu
+        block_state.timesteps = timesteps
+        block_state.num_inference_steps = num_inference_steps
+        block_state.sigmas = sigmas
+        block_state.guidance = guidance
+
+        self.set_block_state(state, block_state)
+        return components, state
+
+
+class FluxImg2ImgSetTimestepsStep(PipelineBlock):
+    model_name = "flux"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
+
+    @property
+    def description(self) -> str:
+        return "Step that sets the scheduler's timesteps for inference"
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("num_inference_steps", default=50),
+            InputParam("timesteps"),
+            InputParam("sigmas"),
+            InputParam("guidance_scale", default=3.5),
+            InputParam("latents", type_hint=torch.Tensor),
+            InputParam("num_images_per_prompt", default=1),
+        ]
+
+    @property
+    def intermediate_inputs(self) -> List[str]:
+        return [
+            InputParam(
+                "latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+            )
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
+            OutputParam(
+                "num_inference_steps",
+                type_hint=int,
+                description="The number of denoising steps to perform at inference time",
+            ),
+            OutputParam(
+                "latent_timestep",
+                type_hint=torch.Tensor,
+                description="The timestep that represents the initial noise level for image-to-image generation",
+            ),
+            OutputParam("guidance", type_hint=torch.Tensor, description="Optional guidance to be used."),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+        block_state.device = components._execution_device
+
+        scheduler = components.scheduler
+        transformer = components.transformer
+
+        timesteps, num_inference_steps, sigmas, guidance = _get_timesteps_and_optionals(
+            transformer,
+            scheduler,
+            block_state.latents,
+            block_state.num_inference_steps,
+            block_state.guidance_scale,
+            block_state.sigmas,
+            block_state.device,
         )
-        if components.transformer.config.guidance_embeds:
-            guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32)
-            guidance = guidance.expand(latents.shape[0])
-        else:
-            guidance = None
+        block_state.timesteps = timesteps
+        block_state.num_inference_steps = num_inference_steps
+        block_state.sigmas = sigmas
         block_state.guidance = guidance
 
+        batch_size = block_state.latents.shape[0]
+        block_state.latent_timestep = timesteps[:1].repeat(batch_size * block_state.num_images_per_prompt)
+
         self.set_block_state(state, block_state)
         return components, state
 
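
latent_timestep here is simply the first scheduler timestep repeated per generated image; the img2img prepare-latents block added below hands it to scheduler.scale_noise together with fresh noise. For flow-matching schedulers that call amounts to a sigma-weighted blend of noise and the encoded reference image; a rough illustration with an assumed starting sigma:

import torch

sigma = 0.7                                    # assumed starting noise level, for illustration only
image_latents = torch.randn(1, 16, 128, 128)   # stand-in for the encoded reference image
noise = torch.randn_like(image_latents)

# Roughly what FlowMatchEulerDiscreteScheduler.scale_noise does at that timestep:
noisy_latents = sigma * noise + (1.0 - sigma) * image_latents
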
@@ -418,3 +581,96 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
         self.set_block_state(state, block_state)
 
         return components, state
+
+
+class FluxLImg2ImgPrepareLatentsStep(PipelineBlock):
+    model_name = "flux"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [ComponentSpec("vae", AutoencoderKL), ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
+
+    @property
+    def description(self) -> str:
+        return "Step that prepares the latents for the image-to-image generation process"
+
+    @property
+    def inputs(self) -> List[Tuple[str, Any]]:
+        return [
+            InputParam("height", type_hint=int),
+            InputParam("width", type_hint=int),
+            InputParam("latents", type_hint=Optional[torch.Tensor]),
+            InputParam("num_images_per_prompt", type_hint=int, default=1),
+        ]
+
+    @property
+    def intermediate_inputs(self) -> List[InputParam]:
+        return [
+            InputParam("generator"),
+            InputParam(
+                "image_latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.",
+            ),
+            InputParam(
+                "latent_timestep",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.",
+            ),
+            InputParam(
+                "batch_size",
+                required=True,
+                type_hint=int,
+                description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
+            ),
+            InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs"),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
+            ),
+            OutputParam(
+                "latent_image_ids",
+                type_hint=torch.Tensor,
+                description="IDs computed from the image sequence needed for RoPE",
+            ),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        block_state.height = block_state.height or components.default_height
+        block_state.width = block_state.width or components.default_width
+        block_state.device = components._execution_device
+        block_state.dtype = torch.bfloat16  # TODO: okay to hardcode this?
+        block_state.num_channels_latents = components.num_channels_latents
+        block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
+
+        # TODO: implement `check_inputs`
+
+        if block_state.latents is None:
+            block_state.latents, block_state.latent_image_ids = prepare_latents_img2img(
+                components.vae,
+                components.scheduler,
+                block_state.image_latents,
+                block_state.latent_timestep,
+                block_state.batch_size * block_state.num_images_per_prompt,
+                block_state.num_channels_latents,
+                block_state.height,
+                block_state.width,
+                block_state.dtype,
+                block_state.device,
+                block_state.generator,
+            )
+
+        self.set_block_state(state, block_state)
+
+        return components, state
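
Taken together, the img2img path added by this commit is: a set_timesteps block that also emits latent_timestep, followed by a prepare-latents block that noises the encoded reference image. A loose wiring sketch; everything other than the two block classes above (the import path, the from_blocks_dict helper, and the step names) is an assumption about the modular-pipelines API, not something this diff adds:

# Hypothetical composition sketch; scaffolding names are assumptions, not part of this commit.
from diffusers.modular_pipelines import SequentialPipelineBlocks  # assumed import path

blocks = SequentialPipelineBlocks.from_blocks_dict(
    {
        "set_timesteps": FluxImg2ImgSetTimestepsStep(),        # emits timesteps, sigmas, guidance, latent_timestep
        "prepare_latents": FluxLImg2ImgPrepareLatentsStep(),   # consumes image_latents + latent_timestep
    }
)
# A vae_encode step (not part of this diff) would still need to populate image_latents first.
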
