@@ -207,6 +207,9 @@ def __init__(
         self.vae_scale_factor_temporal = (
             self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
         )
+        self.vae_scaling_factor_image = (
+            self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
+        )
 
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
 
@@ -348,6 +351,12 @@ def prepare_latents(
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.Tensor] = None,
     ):
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
         num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
         shape = (
             batch_size,
@@ -357,12 +366,6 @@ def prepare_latents(
             width // self.vae_scale_factor_spatial,
         )
 
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
-
         image = image.unsqueeze(2)  # [B, C, F, H, W]
 
         if isinstance(generator, list):
@@ -373,7 +376,7 @@ def prepare_latents(
             image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
 
         image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
-        image_latents = self.vae.config.scaling_factor * image_latents
+        image_latents = self.vae_scaling_factor_image * image_latents
 
         padding_shape = (
             batch_size,
@@ -397,7 +400,7 @@ def prepare_latents(
     # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
     def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
         latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
-        latents = 1 / self.vae.config.scaling_factor * latents
+        latents = 1 / self.vae_scaling_factor_image * latents
 
         frames = self.vae.decode(latents).sample
         return frames
@@ -438,7 +441,6 @@ def check_inputs(
         width,
         negative_prompt,
         callback_on_step_end_tensor_inputs,
-        video=None,
         latents=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
@@ -494,9 +496,6 @@ def check_inputs(
                     f" {negative_prompt_embeds.shape}."
                 )
 
-        if video is not None and latents is not None:
-            raise ValueError("Only one of `video` or `latents` should be provided")
-
     # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections
     def fuse_qkv_projections(self) -> None:
         r"""Enables fused QKV projections."""
@@ -584,18 +583,18 @@ def __call__(
 
         Args:
             image (`PipelineImageInput`):
-                The input video to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
+                The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                 instead.
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                 less than `1`).
-            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                The height in pixels of the generated image. This is set to 1024 by default for the best results.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
+                The height in pixels of the generated image. This is set to 480 by default for the best results.
+            width (`int`, *optional*, defaults to self.transformer.config.sample_width * self.vae_scale_factor_spatial):
+                The width in pixels of the generated image. This is set to 720 by default for the best results.
             num_frames (`int`, defaults to `48`):
                 Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
                 contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
@@ -665,20 +664,19 @@ def __call__(
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
 
-        height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
-        width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
         num_videos_per_prompt = 1
 
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
-            image,
-            prompt,
-            height,
-            width,
-            negative_prompt,
-            callback_on_step_end_tensor_inputs,
-            prompt_embeds,
-            negative_prompt_embeds,
+            image=image,
+            prompt=prompt,
+            height=height,
+            width=width,
+            negative_prompt=negative_prompt,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+            latents=latents,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
         )
         self._guidance_scale = guidance_scale
         self._interrupt = False
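As a quick illustration of the bookkeeping this patch touches, the short standalone sketch below (not part of the diff; the names are only illustrative and the values are the fallbacks visible above, i.e. a temporal compression ratio of 4 and an image scaling factor of 0.7) reproduces the latent frame-count arithmetic from prepare_latents and shows how the cached scaling factor is applied after encoding and undone before decoding.

import torch

# Assumed defaults, mirroring the fallbacks introduced in this patch.
vae_scale_factor_temporal = 4
vae_scaling_factor_image = 0.7

# Frame bookkeeping from prepare_latents: (num_frames - 1) // temporal_ratio + 1.
num_frames = 49  # e.g. 48 requested frames plus the 1 conditioning frame noted in the docstring
latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1  # -> 13

# Illustrative latent tensor; the [B, F, C, H, W] shape here is hypothetical.
image_latents = torch.randn(1, latent_frames, 16, 60, 90)
scaled = vae_scaling_factor_image * image_latents   # applied after VAE encoding
restored = 1 / vae_scaling_factor_image * scaled    # undone in decode_latents
assert torch.allclose(restored, image_latents)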