Skip to content

Commit 7d13a41

Browse files
committed
weighted sum
1 parent ef9ec65 commit 7d13a41

File tree

1 file changed

+19
-2
lines changed

1 file changed

+19
-2
lines changed

src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,8 @@ def __call__(
369369
prompt_2: Optional[Union[str, List[str]]] = None,
370370
prompt_embeds: Optional[torch.FloatTensor] = None,
371371
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
372-
scales: Optional[Union[float, List[float]]] = None,
372+
prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.,
373+
pooled_prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.,
373374
return_dict: bool = True,
374375
):
375376
r"""
@@ -418,6 +419,10 @@ def __call__(
418419
batch_size = image.shape[0]
419420
if prompt is not None and isinstance(prompt, str):
420421
prompt = batch_size * [prompt]
422+
if isinstance(prompt_embeds_scale, float):
423+
prompt_embeds_scale = batch_size * [prompt_embeds_scale]
424+
if isinstance(pooled_prompt_embeds_scale, float):
425+
pooled_prompt_embeds_scale = batch_size * [pooled_prompt_embeds_scale]
421426

422427
device = self._execution_device
423428

@@ -449,9 +454,21 @@ def __call__(
449454
# pooled_prompt_embeds is 768, clip text encoder hidden size
450455
pooled_prompt_embeds = torch.zeros((batch_size, 768), device=device, dtype=image_embeds.dtype)
451456

457+
print("1 prompt_embeds.shape", prompt_embeds.shape)
458+
prompt_embeds_scale = torch.tensor(prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None, None]
459+
pooled_prompt_embeds_scale = torch.tensor(pooled_prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None]
460+
461+
452462
# Concatenate image and text embeddings
453463
prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1)
454-
464+
print("2 prompt_embeds.shape", prompt_embeds.shape)
465+
prompt_embeds *= prompt_embeds_scale
466+
pooled_prompt_embeds *= pooled_prompt_embeds_scale
467+
print("3 prompt_embeds.shape", prompt_embeds.shape)
468+
469+
prompt_embeds = torch.sum(prompt_embeds, dim=0)
470+
pooled_prompt_embeds = torch.sum(pooled_prompt_embeds, dim=0)
471+
print("4 prompt_embeds.shape", prompt_embeds.shape)
455472
# Offload all models
456473
self.maybe_free_model_hooks()
457474

0 commit comments

Comments
 (0)