diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py index 7bf1a0f34067..d97ebe904564 100644 --- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py @@ -247,7 +247,7 @@ def encode_image(self, images, device, dtype, batch_size, num_images_per_prompt) image_embeds.append(image_embed) image_embeds = torch.cat(image_embeds, dim=1) - image_embeds = image_embeds.repeat(batch_size * num_images_per_prompt) + image_embeds = image_embeds.repeat(batch_size * num_images_per_prompt, 1, 1) negative_image_embeds = torch.zeros_like(image_embeds) return image_embeds, negative_image_embeds @@ -492,7 +492,7 @@ def __call__( num_images_per_prompt=num_images_per_prompt, ) elif image_embeds is not None: - image_embeds_pooled = image_embeds.repeat(batch_size * num_images_per_prompt) + image_embeds_pooled = image_embeds.repeat(batch_size * num_images_per_prompt, 1, 1) uncond_image_embeds_pooled = torch.zeros_like(image_embeds_pooled) else: image_embeds_pooled = torch.zeros(