@@ -207,6 +207,9 @@ def __init__(
         self.vae_scale_factor_temporal = (
             self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
         )
+        self.vae_scaling_factor_image = (
+            self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
+        )

         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

@@ -348,6 +351,12 @@ def prepare_latents(
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.Tensor] = None,
     ):
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
         num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
         shape = (
             batch_size,
@@ -357,12 +366,6 @@ def prepare_latents(
             width // self.vae_scale_factor_spatial,
         )

-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
-
         image = image.unsqueeze(2)  # [B, C, F, H, W]

         if isinstance(generator, list):
@@ -373,7 +376,7 @@ def prepare_latents(
             image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]

         image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
-        image_latents = self.vae.config.scaling_factor * image_latents
+        image_latents = self.vae_scaling_factor_image * image_latents

         padding_shape = (
             batch_size,
@@ -397,7 +400,7 @@ def prepare_latents(
     # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
     def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
         latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
-        latents = 1 / self.vae.config.scaling_factor * latents
+        latents = 1 / self.vae_scaling_factor_image * latents

         frames = self.vae.decode(latents).sample
         return frames
@@ -438,7 +441,6 @@ def check_inputs(
         width,
         negative_prompt,
         callback_on_step_end_tensor_inputs,
-        video=None,
         latents=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
@@ -494,9 +496,6 @@ def check_inputs(
                     f" {negative_prompt_embeds.shape}."
                 )

-        if video is not None and latents is not None:
-            raise ValueError("Only one of `video` or `latents` should be provided")
-
     # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections
     def fuse_qkv_projections(self) -> None:
         r"""Enables fused QKV projections."""
@@ -584,18 +583,18 @@ def __call__(

         Args:
             image (`PipelineImageInput`):
-                The input video to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
+                The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                 less than `1`).
-            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                The height in pixels of the generated image. This is set to 1024 by default for the best results.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
+                The height in pixels of the generated image. This is set to 480 by default for the best results.
+            width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
+                The width in pixels of the generated image. This is set to 720 by default for the best results.
             num_frames (`int`, defaults to `48`):
                 Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
                 contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
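
The documented 480x720 default is the transformer's sample size multiplied by the spatial VAE scale factor, per the docstring above. A minimal sketch of that arithmetic, assuming the usual CogVideoX config values (`sample_height=60`, `sample_width=90`, spatial compression 8); the concrete values and the place where the defaulting happens are not part of this diff:

```python
# Sketch only: how the documented defaults work out numerically for CogVideoX.
# The config values below are assumptions for illustration, not taken from this diff.
sample_height, sample_width = 60, 90  # transformer.config.sample_height / sample_width
vae_scale_factor_spatial = 8          # VAE spatial compression ratio
default_height = sample_height * vae_scale_factor_spatial  # 480
default_width = sample_width * vae_scale_factor_spatial    # 720
```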
@@ -665,20 +664,19 @@ def __call__(
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

-        height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
-        width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
         num_videos_per_prompt = 1

         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
-            image,
-            prompt,
-            height,
-            width,
-            negative_prompt,
-            callback_on_step_end_tensor_inputs,
-            prompt_embeds,
-            negative_prompt_embeds,
+            image=image,
+            prompt=prompt,
+            height=height,
+            width=width,
+            negative_prompt=negative_prompt,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+            latents=latents,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
         )
         self._guidance_scale = guidance_scale
         self._interrupt = False
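
For reference, a minimal usage sketch of the image-to-video pipeline after these changes; the model id, prompt, and generation settings are illustrative and not taken from this diff:

```python
import torch
from diffusers import CogVideoXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

# Model id and settings below are assumptions for illustration.
pipe = CogVideoXImageToVideoPipeline.from_pretrained(
    "THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

image = load_image("input.png")
video = pipe(
    image=image,
    prompt="a panda strumming a guitar in a bamboo forest",
    # height/width are omitted: per the updated docstring they default to the
    # transformer's sample size times the spatial VAE scale factor (480x720).
    num_inference_steps=50,
    guidance_scale=6.0,
    generator=torch.Generator(device="cuda").manual_seed(42),
).frames[0]
export_to_video(video, "output.mp4", fps=8)
```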