Skip to content

Commit eb67b2c

Browse files
committed
Move encoder_hid_proj to inside FluxTransformer2DModel
1 parent 5b0a88b commit eb67b2c

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,11 @@ def forward(
491491
ids = torch.cat((txt_ids, img_ids), dim=0)
492492
image_rotary_emb = self.pos_embed(ids)
493493

494+
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
495+
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
496+
ip_hidden_states = self.transformer.encoder_hid_proj(ip_adapter_image_embeds)
497+
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
498+
494499
for index_block, block in enumerate(self.transformer_blocks):
495500
if torch.is_grad_enabled() and self.gradient_checkpointing:
496501

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -471,10 +471,8 @@ def prepare_ip_adapter_image_embeds(
471471
single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
472472

473473
image_embeds.append(single_image_embeds[None, :])
474-
image_embeds = self.transformer.encoder_hid_proj(image_embeds)
475474
else:
476475
for single_image_embeds in ip_adapter_image_embeds:
477-
image_embeds = self.transformer.encoder_hid_proj(single_image_embeds)
478476
image_embeds.append(single_image_embeds)
479477

480478
ip_adapter_image_embeds = []
@@ -913,7 +911,6 @@ def __call__(
913911
device,
914912
batch_size * num_images_per_prompt,
915913
)
916-
self._joint_attention_kwargs["image_projection"] = image_embeds
917914
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
918915
negative_image_embeds = self.prepare_ip_adapter_image_embeds(
919916
negative_ip_adapter_image,
@@ -928,7 +925,8 @@ def __call__(
928925
if self.interrupt:
929926
continue
930927

931-
self._joint_attention_kwargs["image_projection"] = image_embeds
928+
if image_embeds is not None:
929+
self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
932930
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
933931
timestep = t.expand(latents.shape[0]).to(latents.dtype)
934932

@@ -945,7 +943,8 @@ def __call__(
945943
)[0]
946944

947945
if do_true_cfg:
948-
self._joint_attention_kwargs["image_projection"] = negative_image_embeds
946+
if negative_image_embeds is not None:
947+
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
949948
neg_noise_pred = self.transformer(
950949
hidden_states=latents,
951950
timestep=timestep / 1000,

0 commit comments

Comments
 (0)