@@ -197,10 +197,6 @@ def __init__(
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 128
         )

-        self.default_height = 512
-        self.default_width = 704
-        self.default_frames = 121
-
     # Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline._get_t5_prompt_embeds with 256->128
     def _get_t5_prompt_embeds(
         self,
@@ -389,6 +385,10 @@ def check_inputs(

     @staticmethod
     def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
+        # Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
+        # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
+        # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
+        # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features.
         batch_size, num_channels, num_frames, height, width = latents.shape
         post_patch_num_frames = num_frames // patch_size_t
         post_patch_height = height // patch_size
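A shape-level sketch of the packing these comments describe may help; the toy sizes below are hypothetical, and the reshape/permute/flatten order mirrors the method body (only partially visible in this hunk):

```python
import torch

# Hypothetical toy sizes; the defaults in the signature above are p = p_t = 1.
B, C, F, H, W = 1, 128, 16, 32, 44
p, p_t = 1, 1

latents = torch.randn(B, C, F, H, W)
# [B, C, F, H, W] -> [B, C, F // p_t, p_t, H // p, p, W // p, p]
patched = latents.reshape(B, C, F // p_t, p_t, H // p, p, W // p, p)
# Bring (C, p_t, p, p) together, then collapse them into the feature dimension
tokens = patched.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
assert tokens.shape == (B, (F // p_t) * (H // p) * (W // p), C * p_t * p * p)
```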
@@ -410,7 +410,10 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int
     def _unpack_latents(
         latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
     ) -> torch.Tensor:
-        batch_size, num_channels, video_sequence_length = latents.shape
+        # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimension)
+        # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse of the operation
+        # performed in the `_pack_latents` method.
+        batch_size = latents.size(0)
         latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
         latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
         return latents
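Continuing the toy example above, unpacking is the exact inverse rearrangement; note that `num_frames`, `height`, and `width` here are the post-patch latent-grid sizes, which is why the caller passes `F // p_t`, `H // p`, `W // p`:

```python
# Rebuild [B, C, F, H, W] from the token sequence (continues the sketch above).
video = tokens.reshape(B, F // p_t, H // p, W // p, -1, p_t, p, p)
video = video.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
assert torch.equal(video, latents)  # pure rearrangement, so the round trip is exact
```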
@@ -419,6 +422,7 @@ def _unpack_latents(
     def _normalize_latents(
         latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
     ) -> torch.Tensor:
+        # Normalize latents across the channel dimension [B, C, F, H, W]
         latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
         latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
         latents = (latents - latents_mean) * scaling_factor / latents_std
@@ -428,6 +432,7 @@ def _normalize_latents(
     def _denormalize_latents(
         latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
     ) -> torch.Tensor:
+        # Denormalize latents across the channel dimension [B, C, F, H, W]
         latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
         latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
         latents = latents * latents_std / scaling_factor + latents_mean
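The two helpers are mutual inverses over per-channel statistics. A small sketch, continuing the toy example and assuming mean/std are per-channel tensors of shape `[C]` (consistent with the `view(1, -1, 1, 1, 1)` broadcast):

```python
# Per-channel stats broadcast over [B, C, F, H, W]; values here are hypothetical.
latents_mean = torch.randn(C).view(1, -1, 1, 1, 1)
latents_std = (torch.rand(C) + 0.5).view(1, -1, 1, 1, 1)
scaling_factor = 1.0

normalized = (latents - latents_mean) * scaling_factor / latents_std
denormalized = normalized * latents_std / scaling_factor + latents_mean
assert torch.allclose(denormalized, latents, atol=1e-5)
```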
@@ -488,9 +493,9 @@ def __call__(
         self,
         prompt: Union[str, List[str]] = None,
         negative_prompt: Optional[Union[str, List[str]]] = None,
-        height: Optional[int] = None,
-        width: Optional[int] = None,
-        num_frames: int = 81,
+        height: int = 512,
+        width: int = 704,
+        num_frames: int = 161,
         frame_rate: int = 25,
         num_inference_steps: int = 50,
         timesteps: List[int] = None,
@@ -515,11 +520,11 @@ def __call__(
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                 instead.
-            height (`int`, *optional*, defaults to `self.default_height`):
+            height (`int`, defaults to `512`):
                 The height in pixels of the generated image. This is set to 512 by default for the best results.
-            width (`int`, *optional*, defaults to `self.default_width`):
+            width (`int`, defaults to `704`):
                 The width in pixels of the generated image. This is set to 704 by default for the best results.
-            num_frames (`int`, defaults to `81`):
+            num_frames (`int`, defaults to `161`):
                 The number of video frames to generate.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
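A usage sketch under the new signature; the class name `LTXPipeline` and the `Lightricks/LTX-Video` checkpoint are assumptions about the surrounding file rather than something this hunk shows:

```python
import torch
from diffusers import LTXPipeline  # assumed class for this pipeline file

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# height/width/num_frames can now be omitted; they default to 512/704/161.
video = pipe(
    prompt="A woman with long brown hair walks along a beach at sunset",
    num_inference_steps=50,
).frames[0]
```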
@@ -581,10 +586,6 @@ def __call__(
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

-        height = height or self.default_height
-        width = width or self.default_width
-        latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
-
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt=prompt,
@@ -671,6 +672,7 @@ def __call__(
         self._num_timesteps = len(timesteps)

         # 6. Prepare micro-conditions
+        latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
         rope_interpolation_scale = (
             1 / latent_frame_rate,
             self.vae_spatial_compression_ratio,
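For intuition, worked numbers for this micro-conditioning tuple; the 8x temporal and 32x spatial VAE compression ratios are assumptions about the LTX VAE config, not something visible in this hunk:

```python
frame_rate = 25
vae_temporal_compression_ratio = 8   # assumed LTX VAE temporal compression
vae_spatial_compression_ratio = 32   # assumed LTX VAE spatial compression

latent_frame_rate = frame_rate / vae_temporal_compression_ratio  # 3.125 latent frames/s
rope_interpolation_scale = (
    1 / latent_frame_rate,           # 0.32 seconds spanned by one latent frame
    vae_spatial_compression_ratio,   # 32 pixels per latent row
    vae_spatial_compression_ratio,   # 32 pixels per latent column
)
```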