diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py index f2c047fb22c9..00bec1b620d5 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py @@ -362,10 +362,16 @@ def check_inputs( ) if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: + if prompt_embeds.shape[0] != negative_prompt_embeds.shape[0]: raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + "`prompt_embeds` and `negative_prompt_embeds` must have the same batch size when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} and `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_embeds.shape[-1] != negative_prompt_embeds.shape[-1]: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same dimension when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} and `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." )