From 25ecc818a73de97fb3f4a09aaf5cf58b14d21d64 Mon Sep 17 00:00:00 2001 From: 99991 <99991@users.noreply.github.com> Date: Thu, 29 Feb 2024 11:53:58 +0100 Subject: [PATCH] Fix repeat shape --- .../pipelines/stable_cascade/pipeline_stable_cascade_prior.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py index 7bf1a0f34067..d97ebe904564 100644 --- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py @@ -247,7 +247,7 @@ def encode_image(self, images, device, dtype, batch_size, num_images_per_prompt) image_embeds.append(image_embed) image_embeds = torch.cat(image_embeds, dim=1) - image_embeds = image_embeds.repeat(batch_size * num_images_per_prompt) + image_embeds = image_embeds.repeat(batch_size * num_images_per_prompt, 1, 1) negative_image_embeds = torch.zeros_like(image_embeds) return image_embeds, negative_image_embeds @@ -492,7 +492,7 @@ def __call__( num_images_per_prompt=num_images_per_prompt, ) elif image_embeds is not None: - image_embeds_pooled = image_embeds.repeat(batch_size * num_images_per_prompt) + image_embeds_pooled = image_embeds.repeat(batch_size * num_images_per_prompt, 1, 1) uncond_image_embeds_pooled = torch.zeros_like(image_embeds_pooled) else: image_embeds_pooled = torch.zeros(