Skip to content

Commit b8dfdf7

Browse files
committed
fix
1 parent 7d13a41 commit b8dfdf7

File tree

1 file changed

+8
-13
lines changed

1 file changed

+8
-13
lines changed

src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -454,21 +454,16 @@ def __call__(
454454
# pooled_prompt_embeds is 768, clip text encoder hidden size
455455
pooled_prompt_embeds = torch.zeros((batch_size, 768), device=device, dtype=image_embeds.dtype)
456456

457-
print("1 prompt_embeds.shape", prompt_embeds.shape)
458-
prompt_embeds_scale = torch.tensor(prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None, None]
459-
pooled_prompt_embeds_scale = torch.tensor(pooled_prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None]
457+
# scale & oncatenate image and text embeddings
458+
prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1)
460459

460+
prompt_embeds *= torch.tensor(prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None, None]
461+
pooled_prompt_embeds *= torch.tensor(pooled_prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None]
462+
463+
# weighted sum
464+
prompt_embeds = torch.sum(prompt_embeds, dim=0, keepdim=True)
465+
pooled_prompt_embeds = torch.sum(pooled_prompt_embeds, dim=0, keepdim=True)
461466

462-
# Concatenate image and text embeddings
463-
prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1)
464-
print("2 prompt_embeds.shape", prompt_embeds.shape)
465-
prompt_embeds *= prompt_embeds_scale
466-
pooled_prompt_embeds *= pooled_prompt_embeds_scale
467-
print("3 prompt_embeds.shape", prompt_embeds.shape)
468-
469-
prompt_embeds = torch.sum(prompt_embeds, dim=0)
470-
pooled_prompt_embeds = torch.sum(pooled_prompt_embeds, dim=0)
471-
print("4 prompt_embeds.shape", prompt_embeds.shape)
472467
# Offload all models
473468
self.maybe_free_model_hooks()
474469

0 commit comments

Comments
 (0)