
Commit 13b042f

Commit message: up

1 parent 9e4a75b

File tree

5 files changed, +441 -11 lines changed


src/diffusers/modular_pipelines/flux/__init__.py

Lines changed: 10 additions & 2 deletions
@@ -24,15 +24,19 @@
     _import_structure["encoders"] = ["FluxTextEncoderStep"]
     _import_structure["modular_blocks"] = [
         "ALL_BLOCKS",
+        "ALL_BLOCKS_KONTEXT",
         "AUTO_BLOCKS",
+        "AUTO_BLOCKS_KONTEXT",
         "TEXT2IMAGE_BLOCKS",
         "FluxAutoBeforeDenoiseStep",
         "FluxAutoBlocks",
-        "FluxAutoBlocks",
         "FluxAutoDecodeStep",
         "FluxAutoDenoiseStep",
+        "FluxKontextAutoBeforeDenoiseStep",
+        "FluxKontextAutoBlocks",
+        "FluxKontextAutoDenoiseStep",
     ]
-    _import_structure["modular_pipeline"] = ["FluxModularPipeline"]
+    _import_structure["modular_pipeline"] = ["FluxKontextModularPipeline", "FluxModularPipeline"]
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try:
@@ -44,12 +48,16 @@
         from .encoders import FluxTextEncoderStep
         from .modular_blocks import (
             ALL_BLOCKS,
+            ALL_BLOCKS_KONTEXT,
             AUTO_BLOCKS,
+            AUTO_BLOCKS_KONTEXT,
             TEXT2IMAGE_BLOCKS,
             FluxAutoBeforeDenoiseStep,
             FluxAutoBlocks,
             FluxAutoDecodeStep,
             FluxAutoDenoiseStep,
+            FluxKontextAutoBeforeDenoiseStep,
+            FluxKontextAutoDenoiseStep,
         )
         from .modular_pipeline import FluxModularPipeline
 else:
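With these exports registered, the new Kontext entry points become importable from the subpackage. A minimal usage sketch, not part of the commit: it assumes FluxKontextAutoBlocks and FluxKontextModularPipeline are defined in the sibling modules touched elsewhere in this commit, and that the lazy-import machinery resolves the new names the same way as the existing Flux ones.

# Hypothetical sketch; names and behavior assumed to mirror the non-Kontext Flux blocks.
from diffusers.modular_pipelines.flux import (
    AUTO_BLOCKS_KONTEXT,
    FluxKontextAutoBlocks,
    FluxKontextModularPipeline,
)

blocks = FluxKontextAutoBlocks()
print(blocks)                      # block composition of the Kontext auto workflow
print(list(AUTO_BLOCKS_KONTEXT))   # registered Kontext block presets (assumed dict-like)
print(FluxKontextModularPipeline)  # pipeline class registered in modular_pipeline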

src/diffusers/modular_pipelines/flux/before_denoise.py

Lines changed: 216 additions & 4 deletions
@@ -18,6 +18,8 @@
 import numpy as np
 import torch
 
+from ...configuration_utils import FrozenDict
+from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import logging
@@ -182,15 +184,15 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
     return latent_image_ids.to(device=device, dtype=dtype)
 
 
-# Cannot use "# Copied from" because it introduces weird indentation errors.
-def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator):
+def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator, sample_mode: str = "sample"):
     if isinstance(generator, list):
         image_latents = [
-            retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0])
+            retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=sample_mode)
+            for i in range(image.shape[0])
         ]
         image_latents = torch.cat(image_latents, dim=0)
     else:
-        image_latents = retrieve_latents(vae.encode(image), generator=generator)
+        image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode=sample_mode)
 
     image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
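The new `sample_mode` argument is forwarded to `retrieve_latents`, which either draws a stochastic sample from the VAE posterior or returns its deterministic mode; the Kontext step added below encodes reference images with `sample_mode="argmax"`. A minimal sketch of the distinction, with a toy distribution standing in for the VAE's `latent_dist` (this is not the actual diffusers helper):

import torch

class ToyLatentDist:
    """Stand-in for the VAE's DiagonalGaussianDistribution."""

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def sample(self, generator=None):
        # sample_mode="sample": stochastic draw, varies with the generator seed
        return self.mean + self.std * torch.randn(self.mean.shape, generator=generator)

    def mode(self):
        # sample_mode="argmax": deterministic mode of the Gaussian (its mean)
        return self.mean

dist = ToyLatentDist(torch.zeros(1, 16, 4, 4), torch.ones(1, 16, 4, 4))
gen = torch.Generator().manual_seed(0)
stochastic = dist.sample(gen)   # what sample_mode="sample" yields
deterministic = dist.mode()     # what sample_mode="argmax" yields for Kontext image latents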

@@ -687,3 +689,213 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
         self.set_block_state(state, block_state)
 
         return components, state
+
+
+class FluxKontextPrepareLatentsStep(ModularPipelineBlocks):
+    model_name = "flux_kontext"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("vae", AutoencoderKL),
+            ComponentSpec(
+                "image_processor",
+                VaeImageProcessor,
+                config=FrozenDict({"vae_scale_factor": 16}),
+                default_creation_method="from_config",
+            ),
+        ]
+
+    @property
+    def description(self) -> str:
+        return "Prepare latents step that prepares the latents for the image-to-image generation process with Flux Kontext"
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("height", type_hint=int),
+            InputParam("width", type_hint=int),
+            InputParam("max_area", type_hint=int, default=1024**2),
+            InputParam("latents", type_hint=Optional[torch.Tensor]),
+            InputParam("num_images_per_prompt", type_hint=int, default=1),
+            InputParam("generator"),
+            InputParam(
+                "batch_size",
+                required=True,
+                type_hint=int,
+                description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.",
+            ),
+            InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
+            ),
+            OutputParam(
+                "image_latents", type_hint=torch.Tensor, description="Latents computed from the input image(s)."
+            ),
+            OutputParam(
+                "latent_ids",
+                type_hint=torch.Tensor,
+                description="IDs computed from the latent sequence needed for RoPE",
+            ),
+            OutputParam(
+                "image_ids",
+                type_hint=torch.Tensor,
+                description="IDs computed from the image sequence needed for RoPE",
+            ),
+        ]
+
+    @staticmethod
+    def check_inputs(components, block_state):
+        if (block_state.height is not None and block_state.height % (components.vae_scale_factor * 2) != 0) or (
+            block_state.width is not None and block_state.width % (components.vae_scale_factor * 2) != 0
+        ):
+            logger.warning(
+                f"`height` and `width` have to be divisible by {components.vae_scale_factor * 2} but are {block_state.height} and {block_state.width}."
+            )
+
+    @staticmethod
+    def preprocess_image(
+        image, image_processor: VaeImageProcessor, vae_scale_factor: int, latent_channels: int, _auto_resize=True
+    ):
+        from ...pipelines.flux.pipeline_flux_kontext import PREFERRED_KONTEXT_RESOLUTIONS
+
+        if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == latent_channels):
+            multiple_of = vae_scale_factor * 2
+            img = image[0] if isinstance(image, list) else image
+            image_height, image_width = image_processor.get_default_height_width(img)
+            aspect_ratio = image_width / image_height
+            if _auto_resize:
+                # Kontext is trained on specific resolutions, using one of them is recommended
+                _, image_width, image_height = min(
+                    (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
+                )
+            image_width = image_width // multiple_of * multiple_of
+            image_height = image_height // multiple_of * multiple_of
+            image = image_processor.resize(image, image_height, image_width)
+            image = image_processor.preprocess(image, image_height, image_width)
+        return image
+
+    @staticmethod
+    def prepare_latents(
+        comp,
+        image,
+        batch_size,
+        num_channels_latents,
+        height,
+        width,
+        dtype,
+        device,
+        generator,
+        latents=None,
+    ):
+        # Couldn't use the `prepare_latents` method directly from Flux because I decided to copy over
+        # the packing methods here. So, for example, `comp._pack_latents()` won't work if we were
+        # to go with the "# Copied from ..." approach. Or maybe there's a way?
+
+        # VAE applies 8x compression on images but we must also account for packing which requires
+        # latent height and width to be divisible by 2.
+        height = 2 * (int(height) // (comp.vae_scale_factor * 2))
+        width = 2 * (int(width) // (comp.vae_scale_factor * 2))
+        shape = (batch_size, num_channels_latents, height, width)
+
+        image_latents = image_ids = None
+        if image is not None:
+            image = image.to(device=device, dtype=dtype)
+            if image.shape[1] != num_channels_latents:
+                image_latents = _encode_vae_image(vae=comp.vae, image=image, generator=generator, sample_mode="argmax")
+            else:
+                image_latents = image
+            if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+                # expand init_latents for batch_size
+                additional_image_per_prompt = batch_size // image_latents.shape[0]
+                image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+            elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+                raise ValueError(
+                    f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+                )
+            else:
+                image_latents = torch.cat([image_latents], dim=0)
+
+            image_latent_height, image_latent_width = image_latents.shape[2:]
+            image_latents = _pack_latents(
+                image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
+            )
+            image_ids = _prepare_latent_image_ids(
+                batch_size, image_latent_height // 2, image_latent_width // 2, device, dtype
+            )
+            # image ids are the same as latent ids with the first dimension set to 1 instead of 0
+            image_ids[..., 0] = 1
+
+        latent_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+        if latents is None:
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)
+        else:
+            latents = latents.to(device=device, dtype=dtype)
+
+        return latents, image_latents, latent_ids, image_ids
+
+    @torch.no_grad()
+    def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        block_state.height = block_state.height or components.default_height
+        block_state.width = block_state.width or components.default_width
+        block_state.device = components._execution_device
+        block_state.dtype = torch.bfloat16  # TODO: okay to hardcode this?
+        block_state.num_channels_latents = components.num_channels_latents
+
+        self.check_inputs(components, block_state)
+
+        # Adjust height and width if needed.
+        max_area = block_state.max_area
+        original_height, original_width = block_state.height, block_state.width
+        aspect_ratio = original_width / original_height
+        width = round((max_area * aspect_ratio) ** 0.5)
+        height = round((max_area / aspect_ratio) ** 0.5)
+
+        multiple_of = components.vae_scale_factor * 2
+        width = width // multiple_of * multiple_of
+        height = height // multiple_of * multiple_of
+
+        if height != original_height or width != original_width:
+            logger.warning(
+                f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
+            )
+        block_state.height = height
+        block_state.width = width
+
+        # Process input image(s).
+        # `_auto_resize` is currently forced to True. Since it's private anyway, I thought of not adding it.
+        image = block_state.image
+        block_state.image = self.preprocess_image(
+            image=image,
+            image_processor=components.image_processor,
+            vae_scale_factor=components.vae_scale_factor,
+            latent_channels=components.num_channels_latents,
+        )
+
+        batch_size = block_state.batch_size * block_state.num_images_per_prompt
+        block_state.latents, block_state.image_latents, block_state.latent_ids, block_state.image_ids = (
+            self.prepare_latents(
+                components,
+                block_state.image,
+                batch_size,
+                block_state.num_channels_latents,
+                block_state.height,
+                block_state.width,
+                block_state.dtype,
+                block_state.device,
+                block_state.generator,
+                block_state.latents,
+            )
+        )
+
+        self.set_block_state(state, block_state)
+
+        return components, state
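Two pieces of arithmetic drive the resolution handling in this block: the generation size is rescaled so its area does not exceed `max_area` while keeping the requested aspect ratio, and both the target size and the reference image size are snapped down to a multiple of `vae_scale_factor * 2` (with the reference image optionally snapped to the nearest preferred Kontext aspect ratio first). A standalone sketch of that logic; the resolution list here is a truncated, illustrative stand-in for `PREFERRED_KONTEXT_RESOLUTIONS`, and the default `vae_scale_factor=8` is an assumption:

# Illustrative stand-in for PREFERRED_KONTEXT_RESOLUTIONS ((width, height) pairs); truncated.
PREFERRED = [(672, 1568), (880, 1184), (1024, 1024), (1184, 880), (1568, 672)]


def adjust_target_size(height, width, max_area=1024**2, vae_scale_factor=8):
    """Cap the generation size at max_area, keep the aspect ratio, snap to multiple_of."""
    multiple_of = vae_scale_factor * 2
    aspect_ratio = width / height
    width = round((max_area * aspect_ratio) ** 0.5)
    height = round((max_area / aspect_ratio) ** 0.5)
    return height // multiple_of * multiple_of, width // multiple_of * multiple_of


def pick_preferred_resolution(image_height, image_width):
    """Pick the preferred (width, height) whose aspect ratio is closest to the input image's."""
    aspect_ratio = image_width / image_height
    _, w, h = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED)
    return h, w


print(adjust_target_size(720, 1600))         # landscape request capped to ~1 MP, e.g. (672, 1520)
print(pick_preferred_resolution(720, 1600))  # nearest preferred aspect ratio, e.g. (672, 1568)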
