@@ -380,6 +380,7 @@ def prepare_latents(
380380 device : Optional [torch .device ] = None ,
381381 generator : Optional [Union [torch .Generator , List [torch .Generator ]]] = None ,
382382 latents : Optional [torch .Tensor ] = None ,
383+ last_image : Optional [torch .Tensor ] = None ,
383384 ) -> Tuple [torch .Tensor , torch .Tensor ]:
384385 num_latent_frames = (num_frames - 1 ) // self .vae_scale_factor_temporal + 1
385386 latent_height = height // self .vae_scale_factor_spatial
@@ -398,9 +399,16 @@ def prepare_latents(
398399 latents = latents .to (device = device , dtype = dtype )
399400
400401 image = image .unsqueeze (2 )
401- video_condition = torch .cat (
402- [image , image .new_zeros (image .shape [0 ], image .shape [1 ], num_frames - 1 , height , width )], dim = 2
403- )
402+ if last_image is None :
403+ video_condition = torch .cat (
404+ [image , image .new_zeros (image .shape [0 ], image .shape [1 ], num_frames - 1 , height , width )], dim = 2
405+ )
406+ else :
407+ last_image = last_image .unsqueeze (2 )
408+ video_condition = torch .cat (
409+ [image , image .new_zeros (image .shape [0 ], image .shape [1 ], num_frames - 2 , height , width ), last_image ],
410+ dim = 2 ,
411+ )
404412 video_condition = video_condition .to (device = device , dtype = dtype )
405413
406414 latents_mean = (
@@ -424,7 +432,11 @@ def prepare_latents(
424432 latent_condition = (latent_condition - latents_mean ) * latents_std
425433
426434 mask_lat_size = torch .ones (batch_size , 1 , num_frames , latent_height , latent_width )
427- mask_lat_size [:, :, list (range (1 , num_frames ))] = 0
435+
436+ if last_image is None :
437+ mask_lat_size [:, :, list (range (1 , num_frames ))] = 0
438+ else :
439+ mask_lat_size [:, :, list (range (1 , num_frames - 1 ))] = 0
428440 first_frame_mask = mask_lat_size [:, :, 0 :1 ]
429441 first_frame_mask = torch .repeat_interleave (first_frame_mask , dim = 2 , repeats = self .vae_scale_factor_temporal )
430442 mask_lat_size = torch .concat ([first_frame_mask , mask_lat_size [:, :, 1 :, :]], dim = 2 )
@@ -476,6 +488,7 @@ def __call__(
476488 prompt_embeds : Optional [torch .Tensor ] = None ,
477489 negative_prompt_embeds : Optional [torch .Tensor ] = None ,
478490 image_embeds : Optional [torch .Tensor ] = None ,
491+ last_image : Optional [torch .Tensor ] = None ,
479492 output_type : Optional [str ] = "np" ,
480493 return_dict : bool = True ,
481494 attention_kwargs : Optional [Dict [str , Any ]] = None ,
@@ -620,7 +633,10 @@ def __call__(
620633 negative_prompt_embeds = negative_prompt_embeds .to (transformer_dtype )
621634
622635 if image_embeds is None :
623- image_embeds = self .encode_image (image , device )
636+ if last_image is None :
637+ image_embeds = self .encode_image (image , device )
638+ else :
639+ image_embeds = self .encode_image ([image , last_image ], device )
624640 image_embeds = image_embeds .repeat (batch_size , 1 , 1 )
625641 image_embeds = image_embeds .to (transformer_dtype )
626642
@@ -631,6 +647,10 @@ def __call__(
631647 # 5. Prepare latent variables
632648 num_channels_latents = self .vae .config .z_dim
633649 image = self .video_processor .preprocess (image , height = height , width = width ).to (device , dtype = torch .float32 )
650+ if last_image is not None :
651+ last_image = self .video_processor .preprocess (last_image , height = height , width = width ).to (
652+ device , dtype = torch .float32
653+ )
634654 latents , condition = self .prepare_latents (
635655 image ,
636656 batch_size * num_videos_per_prompt ,
@@ -642,6 +662,7 @@ def __call__(
642662 device ,
643663 generator ,
644664 latents ,
665+ last_image ,
645666 )
646667
647668 # 6. Denoising loop
0 commit comments