@@ -435,10 +435,11 @@ def check_inputs(
435435 f" got: `prompt_attention_mask` { prompt_attention_mask .shape } != `negative_prompt_attention_mask`"
436436 f" { negative_prompt_attention_mask .shape } ."
437437 )
438-
438+
439+
439440 @staticmethod
440441 # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
441- def _pack_latents (latents : torch .Tensor , patch_size : int = 1 , patch_size_t : int = 1 ) -> torch .Tensor :
442+ def _pack_latents (latents : torch .Tensor , patch_size : int = 1 , patch_size_t : int = 1 , device : torch . device = None ) -> torch .Tensor :
         # Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
         # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
         # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
@@ -447,6 +448,16 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int
         post_patch_num_frames = num_frames // patch_size_t
         post_patch_height = height // patch_size
         post_patch_width = width // patch_size
+
+        latent_sample_coords = torch.meshgrid(
+            torch.arange(0, num_frames, patch_size_t, device=device),
+            torch.arange(0, height, patch_size, device=device),
+            torch.arange(0, width, patch_size, device=device),
+        )
+        latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
+        latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
+        latent_coords = latent_coords.reshape(batch_size, -1, num_frames * height * width)
+
         latents = latents.reshape(
             batch_size,
             -1,
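The block added above gives every packed token an explicit (frame, row, col) index into the latent grid, which later hunks convert into RoPE pixel coordinates. A minimal standalone sketch of what it produces, with toy shapes assumed and `indexing="ij"` spelled out (the diff relies on `torch.meshgrid`'s default, which is the same ordering). Note that the final `reshape(batch_size, -1, num_frames * height * width)` only matches the token count when `patch_size == patch_size_t == 1`, the LTX default:

```python
import torch

# Standalone sketch of the coordinate grid added to _pack_latents above.
# With patch sizes of 1, every latent voxel becomes one token and
# coords[b, :, i] is that token's (frame, row, col) position.
batch_size, num_frames, height, width = 2, 5, 16, 24
patch_size = patch_size_t = 1

grid = torch.meshgrid(
    torch.arange(0, num_frames, patch_size_t),
    torch.arange(0, height, patch_size),
    torch.arange(0, width, patch_size),
    indexing="ij",  # explicit here; the diff uses the equivalent default
)
coords = torch.stack(grid, dim=0)                            # [3, F, H, W]
coords = coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)  # [B, 3, F, H, W]
coords = coords.reshape(batch_size, 3, -1)                   # [B, 3, F*H*W]

print(coords.shape)      # torch.Size([2, 3, 1920])
print(coords[0, :, 0])   # tensor([0, 0, 0]): first token at frame 0, row 0, col 0
print(coords[0, :, 25])  # tensor([0, 1, 1]): 25 = 1*24 + 1, i.e. row 1, col 1 of frame 0
```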
@@ -458,7 +469,7 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int
             patch_size,
         )
         latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
-        return latents
+        return latents, latent_coords
 
     @staticmethod
     # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents
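For context, the packing math itself is untouched; only the extra return value is new. A toy-shape sketch of the reshape/permute/flatten path, confirming the [B, C, F, H, W] -> [B, F//p_t * H//p * W//p, C * p_t * p * p] layout described in the comment:

```python
import torch

# Toy-shape sketch of the (unchanged) packing path in _pack_latents.
B, C, F, H, W = 1, 4, 4, 8, 8
p_t, p = 2, 2
latents = torch.randn(B, C, F, H, W)

x = latents.reshape(B, -1, F // p_t, p_t, H // p, p, W // p, p)
x = x.permute(0, 2, 4, 6, 1, 3, 5, 7)  # [B, F//p_t, H//p, W//p, C, p_t, p, p]
x = x.flatten(4, 7).flatten(1, 3)      # patch dims -> channels, grid dims -> tokens
print(x.shape)  # torch.Size([1, 32, 32]) == [B, (F//p_t)*(H//p)*(W//p), C*p_t*p*p]
```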
@@ -588,6 +599,7 @@ def prepare_latents(
 
         shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        latents = torch.load("/raid/yiyi/LTX-Video/init_latents.pt").to(device, dtype=dtype)
 
         extra_conditioning_latents = []
         extra_conditioning_rope_interpolation_scales = []
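The hard-coded `torch.load` pins the initial noise to a tensor dumped from the original LTX-Video repo, so both implementations denoise from identical noise while parity is being debugged; the `print`/`assert False` hooks further down follow the same load-and-compare pattern. A hypothetical helper that factors out that pattern (`check_against_reference` is not part of the diff; `torch.testing.assert_close` is stock PyTorch):

```python
import torch

def check_against_reference(tensor: torch.Tensor, path: str, rtol: float = 1e-4, atol: float = 1e-4) -> None:
    # Hypothetical helper mirroring the ad-hoc debug hooks in this diff:
    # load a reference tensor dumped from the original LTX-Video repo and compare.
    reference = torch.load(path).to(tensor.device, tensor.dtype)
    print(f"shapes: {tensor.shape} vs {reference.shape}")
    print(f"sum abs diff: {torch.sum((reference - tensor).abs())}")
    torch.testing.assert_close(tensor, reference, rtol=rtol, atol=atol)
```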
@@ -605,14 +617,6 @@ def prepare_latents(
                 num_frames_input = data.size(2)
                 num_frames_output = self.trim_conditioning_sequence(condition.frame_index, num_frames_input, num_frames)
                 data = data[:, :, :num_frames_output]
-
-                print(data.shape)
-                print(data[0, 0, :3, :5, :5])
-                data_loaded = torch.load("/raid/yiyi/LTX-Video/media_item.pt")
-                print(data_loaded.shape)
-                print(data_loaded[0, 0, :3, :5, :5])
-                print(torch.sum((data_loaded - data).abs()))
-                print(f"dtype:{dtype}, device:{device}")
                 data = data.to(device, dtype=torch.bfloat16)
             else:
                 raise ValueError("Either `image` or `video` must be provided in the `LTXVideoCondition`.")
@@ -623,23 +627,11 @@ def prepare_latents(
623627 f"but got { data .size (2 )} frames."
624628 )
625629
626- print (f" before encode: { data .shape } , { data .dtype } , { data .device } " )
627-
628630 condition_latents = retrieve_latents (self .vae .encode (data ), generator = generator )
629- print (f" after encode: { condition_latents .shape } , { condition_latents .dtype } , { condition_latents .device } " )
630- print (condition_latents [0 ,0 ,:3 ,:5 ,:5 ])
631- condition_latents_before_normalize = torch .load ("/raid/yiyi/LTX-Video/latents_before_normalize.pt" )
632- print (torch .sum ((condition_latents_before_normalize - condition_latents ).abs ()))
633- assert False
634- condition_latents = self ._normalize_latents (condition_latents , self .vae .latents_mean , self .vae .latents_std )
635-
636- print (f" after normalize: { condition_latents .shape } " )
637- print (condition_latents [0 ,0 ,:3 ,:5 ,:5 ])
638- condition_latents_loaded = torch .load ("/raid/yiyi/LTX-Video/latents_normalized.pt" )
639- print (condition_latents_loaded .shape )
640- print (condition_latents_loaded [0 ,0 ,:3 ,:5 ,:5 ])
641- print (torch .sum ((condition_latents_loaded .to (condition_latents .device ) - condition_latents ).abs ()))
642- assert False
631+ condition_latents_loaded = torch .load ("/raid/yiyi/LTX-Video/latents_before_normalize.pt" ).to (condition_latents .device )
632+ print (f" condition_latents(loaded): { condition_latents_loaded .shape } , { condition_latents_loaded [0 ,0 ,:3 ,:3 ,:3 ]} " )
633+ print (f" condition_latents: { condition_latents .shape } , { condition_latents [0 ,0 ,:3 ,:3 ,:3 ]} " )
634+ condition_latents = self ._normalize_latents (condition_latents_loaded , self .vae .latents_mean , self .vae .latents_std )
643635
644636 num_data_frames = data .size (2 )
645637 num_cond_frames = condition_latents .size (2 )
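`_normalize_latents` is the existing LTX pipeline helper; for reference, a sketch of the per-channel normalization it performs, with the formula assumed to match the helper in `pipeline_ltx` (shift by the VAE's recorded channel means, scale by its stds):

```python
import torch

def normalize_latents(
    latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
    # Normalize [B, C, F, H, W] latents per channel with the VAE's recorded
    # statistics (formula assumed from the existing LTX helper).
    latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    return (latents - latents_mean) * scaling_factor / latents_std
```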
@@ -667,7 +659,7 @@ def prepare_latents(
                noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
                condition_latents = torch.lerp(noise, condition_latents, condition.strength)
                c_nlf = condition_latents.shape[2]
-                condition_latents = self._pack_latents(
+                condition_latents, latent_coords = self._pack_latents(
                    condition_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
                )
                conditioning_mask = torch.full(
@@ -692,30 +684,37 @@ def prepare_latents(
                extra_conditioning_rope_interpolation_scales.append(rope_interpolation_scale)
                extra_conditioning_mask.append(conditioning_mask)
 
-        latents = self._pack_latents(
-            latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+        latents, latent_coords = self._pack_latents(
+            latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device
         )
-        rope_interpolation_scale = [
-            self.vae_temporal_compression_ratio / frame_rate,
-            self.vae_spatial_compression_ratio,
-            self.vae_spatial_compression_ratio,
-        ]
-        rope_interpolation_scale = (
-            torch.tensor(rope_interpolation_scale, device=device, dtype=dtype)
-            .view(-1, 1, 1, 1, 1)
-            .repeat(1, 1, num_latent_frames, latent_height, latent_width)
+
+        pixel_coords = (
+            latent_coords
+            * torch.tensor([self.vae_temporal_compression_ratio, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio], device=latent_coords.device)[None, :, None]
         )
-        conditioning_mask = self._pack_latents(
-            conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+
+        pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - self.vae_temporal_compression_ratio).clamp(min=0)
+
+        rope_interpolation_scale = pixel_coords
+
+        conditioning_mask = condition_latent_frames_mask.gather(
+            1, latent_coords[:, 0]
         )
 
+        ## YiYi Todo: not looked into yet
        if len(extra_conditioning_latents) > 0:
            latents = torch.cat([*extra_conditioning_latents, latents], dim=1)
            rope_interpolation_scale = torch.cat(
                [*extra_conditioning_rope_interpolation_scales, rope_interpolation_scale], dim=2
            )
            conditioning_mask = torch.cat([*extra_conditioning_mask, conditioning_mask], dim=1)
 
+
+        print(f"latents (after pack): {latents.shape}, {latents[0, :3, :3]}")
+        print(f"conditioning_mask: {conditioning_mask.shape}, {conditioning_mask[0, :10]}")
+        print(f"rope_interpolation_scale: {rope_interpolation_scale.shape}, {rope_interpolation_scale[0, :3, :3]}")
+        print(f"extra_conditioning_num_latents: {extra_conditioning_num_latents}")
+        assert False
        return latents, conditioning_mask, rope_interpolation_scale, extra_conditioning_num_latents
 
    @property
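Net effect of the replaced block: instead of a dense per-voxel float scale, RoPE now receives integer pixel coordinates per token, and the per-frame conditioning mask becomes per-token by gathering with each token's frame index. A toy-value sketch (compression ratios, coordinates, and mask all assumed):

```python
import torch

# Toy sketch of the new coordinate path. latent_coords holds (frame, row, col)
# per token, as returned by the updated _pack_latents.
vae_temporal, vae_spatial = 8, 32
latent_coords = torch.tensor([[[0, 0, 1],      # frame index per token
                               [0, 1, 0],      # row index per token
                               [0, 0, 1]]])    # col index per token -> [B=1, 3, 3 tokens]

# Scale latent indices up to pixel space by the VAE compression ratios.
scale = torch.tensor([vae_temporal, vae_spatial, vae_spatial])[None, :, None]
pixel_coords = latent_coords * scale
# Shift the temporal axis so latent frame 0 maps to pixel frame 0
# (in the causal VAE the first latent frame covers a single pixel frame).
pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - vae_temporal).clamp(min=0)
print(pixel_coords)  # tensor([[[0, 0, 1], [0, 32, 0], [0, 0, 32]]])

# Per-frame conditioning mask -> per-token mask via each token's frame index.
condition_latent_frames_mask = torch.tensor([[1.0, 0.0]])  # latent frame 0 is conditioned
conditioning_mask = condition_latent_frames_mask.gather(1, latent_coords[:, 0])
print(conditioning_mask)  # tensor([[1., 1., 0.]])
```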
@@ -914,7 +913,7 @@ def __call__(
            frame_rate,
            generator,
            device,
-            torch.float32,
+            prompt_embeds.dtype,
        )
        init_latents = latents.clone()
 