 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import ftfy
-import numpy as np
 import PIL
 import regex as re
 import torch
@@ -165,7 +164,7 @@ def _get_t5_prompt_embeds(
         self,
         prompt: Union[str, List[str]] = None,
         num_videos_per_prompt: int = 1,
-        max_sequence_length: int = 226,
+        max_sequence_length: int = 512,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
     ):
@@ -292,15 +291,18 @@ def encode_prompt(
     def check_inputs(
         self,
         prompt,
+        negative_prompt,
         image,
-        max_area,
+        height,
+        width,
         prompt_embeds=None,
+        negative_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
     ):
         if not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
             raise ValueError("`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is" f" {type(image)}")
-        if max_area < 0:
-            raise ValueError(f"`max_area` has to be positive but are {max_area}")
+        if height % 16 != 0 or width % 16 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}")
 
         if callback_on_step_end_tensor_inputs is not None and not all(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -314,43 +316,43 @@ def check_inputs(
                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                 " only forward one of the two."
             )
+        elif negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
         elif prompt is None and prompt_embeds is None:
             raise ValueError(
                 "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
             )
         elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        elif negative_prompt is not None and (
+            not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+        ):
+            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
 
     def prepare_latents(
         self,
         image: PipelineImageInput,
         batch_size: int,
-        num_channels_latents: 32,
-        height: int = 720,
-        width: int = 1280,
-        max_area: int = 720 * 1280,
+        num_channels_latents: int = 16,
+        height: int = 480,
+        width: int = 832,
         num_frames: int = 81,
-        num_latent_frames: int = 21,
         dtype: Optional[torch.dtype] = None,
         device: Optional[torch.device] = None,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        aspect_ratio = height / width
-        mod_value = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
-        height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
-        width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
-
         if latents is not None:
             return latents.to(device=device, dtype=dtype)
 
-        shape = (
-            batch_size,
-            num_channels_latents,
-            num_latent_frames,
-            int(height) // self.vae_scale_factor_spatial,
-            int(width) // self.vae_scale_factor_spatial,
-        )
+        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+        latent_height = height // self.vae_scale_factor_spatial
+        latent_width = width // self.vae_scale_factor_spatial
+
+        shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
        if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
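
As a quick sanity check of the new shape math (not part of the diff itself): with the new defaults and the usual Wan VAE factors, which the removed `num_latent_frames: int = 21` default implies, the latent tensor works out as follows.

num_frames, height, width = 81, 480, 832        # new defaults from this diff
vae_scale_factor_temporal = 4                   # assumed Wan VAE factor
vae_scale_factor_spatial = 8                    # assumed Wan VAE factor

num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1  # 21
latent_height = height // vae_scale_factor_spatial                     # 60
latent_width = width // vae_scale_factor_spatial                       # 104
# shape == (batch_size, 16, 21, 60, 104) with the new num_channels_latents = 16 default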
@@ -359,35 +361,25 @@ def prepare_latents(
 
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
 
-        image = self.video_processor.preprocess(image, height=height, width=width)[:, :, None]
+        image = image.unsqueeze(2)
         video_condition = torch.cat(
             [image, torch.zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
         )
         video_condition = video_condition.to(device=device, dtype=dtype)
+
         if isinstance(generator, list):
             latent_condition = [retrieve_latents(self.vae.encode(video_condition), g) for g in generator]
             latents = latent_condition = torch.cat(latent_condition)
         else:
             latent_condition = retrieve_latents(self.vae.encode(video_condition), generator)
             latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
-        mask_lat_size = torch.ones(
-            batch_size,
-            1,
-            num_frames,
-            int(height) // self.vae_scale_factor_spatial,
-            int(width) // self.vae_scale_factor_spatial,
-        )
+
+        mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
         mask_lat_size[:, :, list(range(1, num_frames))] = 0
         first_frame_mask = mask_lat_size[:, :, 0:1]
         first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
         mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
-        mask_lat_size = mask_lat_size.view(
-            batch_size,
-            -1,
-            self.vae_scale_factor_temporal,
-            int(height) // self.vae_scale_factor_spatial,
-            int(width) // self.vae_scale_factor_spatial,
-        )
+        mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width)
         mask_lat_size = mask_lat_size.transpose(1, 2)
         mask_lat_size = mask_lat_size.to(latent_condition.device)
 
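A hedged illustration of the mask bookkeeping above (the temporal factor of 4 and the 60x104 latent size are assumptions for a 480x832 input, not part of the diff): repeating the first-frame mask makes the frame count fold evenly into the latent frame count.

import torch

batch_size, num_frames = 1, 81
latent_height, latent_width = 60, 104          # assumed 480x832 with spatial factor 8
vae_scale_factor_temporal = 4                  # assumed Wan VAE factor

mask = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
mask[:, :, 1:] = 0                                               # only the conditioning frame stays on
first = mask[:, :, 0:1].repeat_interleave(vae_scale_factor_temporal, dim=2)
mask = torch.cat([first, mask[:, :, 1:]], dim=2)                 # 4 + 80 = 84 entries
mask = mask.view(batch_size, -1, vae_scale_factor_temporal, latent_height, latent_width)
print(mask.transpose(1, 2).shape)  # torch.Size([1, 4, 21, 60, 104]), matching the 21 latent frames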
@@ -424,7 +416,8 @@ def __call__(
         image: PipelineImageInput,
         prompt: Union[str, List[str]] = None,
         negative_prompt: Union[str, List[str]] = None,
-        max_area: int = 720 * 1280,
+        height: int = 480,
+        width: int = 832,
         num_frames: int = 81,
         num_inference_steps: int = 50,
         guidance_scale: float = 5.0,
@@ -451,9 +444,15 @@ def __call__(
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                 instead.
-            max_area (`int`, defaults to `1280 * 720`):
-                The maximum area in pixels of the generated image.
-            num_frames (`int`, defaults to `129`):
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            height (`int`, defaults to `480`):
+                The height of the generated video.
+            width (`int`, defaults to `832`):
+                The width of the generated video.
+            num_frames (`int`, defaults to `81`):
                 The number of frames in the generated video.
             num_inference_steps (`int`, defaults to `50`):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -514,9 +513,12 @@ def __call__(
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt,
+            negative_prompt,
             image,
-            max_area,
+            height,
+            width,
             prompt_embeds,
+            negative_prompt_embeds,
             callback_on_step_end_tensor_inputs,
         )
 
@@ -548,36 +550,29 @@ def __call__(
         )
 
         # Encode image embedding
-        image_embeds = self.encode_image(image)
-        image_embeds = image_embeds.repeat(batch_size, 1, 1)
-
         transformer_dtype = self.transformer.dtype
         prompt_embeds = prompt_embeds.to(transformer_dtype)
         if negative_prompt_embeds is not None:
             negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+        image_embeds = self.encode_image(image)
+        image_embeds = image_embeds.repeat(batch_size, 1, 1)
         image_embeds = image_embeds.to(transformer_dtype)
 
         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps
 
-        if isinstance(image, torch.Tensor):
-            height, width = image.shape[-2:]
-        else:
-            width, height = image.size
-
         # 5. Prepare latent variables
-        num_channels_latents = self.vae.config.z_dim
-        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+        num_channels_latents = self.transformer.config.in_channels
+        image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
         latents, condition = self.prepare_latents(
             image,
             batch_size * num_videos_per_prompt,
             num_channels_latents,
             height,
             width,
-            max_area,
             num_frames,
-            num_latent_frames,
             torch.float32,
             device,
             generator,
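
Taken together, a minimal usage sketch of the updated `__call__` signature (the checkpoint id and helper imports are assumptions for illustration, not part of this diff):

import torch
from diffusers import WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

# Assumed checkpoint id; any Wan 2.1 image-to-video checkpoint in diffusers format should work.
pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")

image = load_image("input.png")
# `height`/`width` replace the old `max_area` argument and must be divisible by 16.
video = pipe(
    image=image,
    prompt="a cat walking through a garden",
    height=480,
    width=832,
    num_frames=81,
    num_inference_steps=50,
    guidance_scale=5.0,
).frames[0]
export_to_video(video, "output.mp4", fps=16)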