Commit 1045988

encoder.

1 parent 046bf9e

1 file changed: +111 −1

src/diffusers/modular_pipelines/flux/encoders.py

Lines changed: 111 additions & 1 deletion
@@ -19,7 +19,10 @@
 import torch
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
 
+from ...configuration_utils import FrozenDict
+from ...image_processor import VaeImageProcessor
 from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL
 from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers
 from ..modular_pipeline import PipelineBlock, PipelineState
 from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
@@ -50,6 +53,113 @@ def prompt_clean(text):
     return text
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class FluxVaeEncoderStep(PipelineBlock):
+    model_name = "flux"
+
+    @property
+    def description(self) -> str:
+        return "VAE encoder step that encodes the input image into a latent representation"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("vae", AutoencoderKL),
+            ComponentSpec(
+                "image_processor",
+                VaeImageProcessor,
+                config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 16}),
+                default_creation_method="from_config",
+            ),
+        ]
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("image", required=True),
+            InputParam("height"),
+            InputParam("width"),
+        ]
+
+    @property
+    def intermediate_inputs(self) -> List[InputParam]:
+        return [
+            InputParam("generator"),
+            InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
+            InputParam(
+                "preprocess_kwargs",
+                type_hint=Optional[dict],
+                description="A kwargs dictionary that, if specified, is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]",
+            ),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                "image_latents",
+                type_hint=torch.Tensor,
+                description="The latents representing the reference image for image-to-image/inpainting generation",
+            )
+        ]
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image with self.vae->vae
+    def _encode_vae_image(self, vae, image: torch.Tensor, generator: torch.Generator):
+        if isinstance(generator, list):
+            image_latents = [
+                retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0])
+            ]
+            image_latents = torch.cat(image_latents, dim=0)
+        else:
+            image_latents = retrieve_latents(vae.encode(image), generator=generator)
+
+        image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
+
+        return image_latents
+
+    @torch.no_grad()
+    def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+        block_state.preprocess_kwargs = block_state.preprocess_kwargs or {}
+        block_state.device = components._execution_device
+        block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
+
+        block_state.image = components.image_processor.preprocess(
+            block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs
+        )
+        block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
+
+        block_state.batch_size = block_state.image.shape[0]
+
+        # if generator is a list, make sure its length matches the number of images (both should equal batch_size)
+        if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
+                f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        block_state.image_latents = self._encode_vae_image(
+            components.vae, image=block_state.image, generator=block_state.generator
+        )
+
+        self.set_block_state(state, block_state)
+
+        return components, state
+
+
 class FluxTextEncoderStep(PipelineBlock):
     model_name = "flux"
 
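For context, `retrieve_latents` in the hunk above accepts three shapes of encoder output: a `latent_dist` it can sample from, a `latent_dist` whose mode is taken when `sample_mode="argmax"`, or a plain `latents` attribute. A minimal sketch of those branches, assuming the `retrieve_latents` defined above and using hypothetical stand-in classes (`FakeDist`, `FakeOutput`) in place of a real `AutoencoderKL` encode output:

import torch

class FakeDist:
    # Stand-in for a latent distribution exposing sample() and mode().
    def __init__(self, mean):
        self.mean = mean

    def sample(self, generator=None):
        # Stochastic draw around the mean, reproducible via the generator.
        return self.mean + torch.randn(self.mean.shape, generator=generator)

    def mode(self):
        # Deterministic value returned for sample_mode="argmax".
        return self.mean

class FakeOutput:
    def __init__(self, mean):
        self.latent_dist = FakeDist(mean)

out = FakeOutput(torch.zeros(1, 16, 64, 64))
gen = torch.Generator().manual_seed(0)
sampled = retrieve_latents(out, generator=gen)               # "sample" branch
deterministic = retrieve_latents(out, sample_mode="argmax")  # "argmax" branch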
@@ -297,7 +407,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
             prompt_embeds=None,
             pooled_prompt_embeds=None,
             device=block_state.device,
-            num_images_per_prompt=1,  # hardcoded for now.
+            num_images_per_prompt=1,  # TODO: hardcoded for now.
             lora_scale=block_state.text_encoder_lora_scale,
         )

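For reference, the `(latents - shift_factor) * scaling_factor` step in `_encode_vae_image` above maps raw VAE posterior samples into the normalized latent space the transformer operates in, and decoding applies the inverse. A quick round-trip sketch; the constants are illustrative values in the style of the FLUX.1 VAE config, not read from a checkpoint:

import torch

shift_factor, scaling_factor = 0.1159, 0.3611  # assumed FLUX.1-style VAE config values

raw = torch.randn(1, 16, 64, 64)                       # stand-in for a vae.encode() sample
normalized = (raw - shift_factor) * scaling_factor     # what _encode_vae_image returns
restored = normalized / scaling_factor + shift_factor  # inverse applied before vae.decode()
assert torch.allclose(restored, raw, atol=1e-6)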