[WIP] test prepare_latents for ltx0.95 #10976
```diff
@@ -437,8 +437,8 @@ def check_inputs(
         )

     @staticmethod
-    # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
-    def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
+    # adapted from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
+    def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1, device: torch.device = None) -> torch.Tensor:
         # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
         # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
         # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
```
```diff
@@ -447,6 +447,17 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int
         post_patch_num_frames = num_frames // patch_size_t
         post_patch_height = height // patch_size
         post_patch_width = width // patch_size
+
+        latent_sample_coords = torch.meshgrid(
+            torch.arange(0, num_frames, patch_size_t, device=device),
+            torch.arange(0, height, patch_size, device=device),
+            torch.arange(0, width, patch_size, device=device),
+            indexing="ij",
+        )
+        latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
+        latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
+        latent_coords = latent_coords.reshape(batch_size, -1, num_frames * height * width)
+
         latents = latents.reshape(
             batch_size,
             -1,
```
```diff
@@ -458,7 +469,7 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int
             patch_size,
         )
         latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
-        return latents
+        return latents, latent_coords

     @staticmethod
     # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents
```
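For intuition, here is a minimal standalone sketch of the coordinate grid the modified `_pack_latents` now builds alongside the packed tokens. It assumes `patch_size = patch_size_t = 1` (the LTX defaults), so the final reshape to `num_frames * height * width` tokens is exact; all names are local to the example:

```python
import torch

# Toy sizes for an unpacked latent tensor [B, C, F, H, W].
batch_size, num_frames, height, width = 2, 5, 4, 6

# One (t, y, x) coordinate triple per latent position, mirroring the diff above.
coords = torch.meshgrid(
    torch.arange(num_frames), torch.arange(height), torch.arange(width), indexing="ij"
)
coords = torch.stack(coords, dim=0)                          # [3, F, H, W]
coords = coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)  # [B, 3, F, H, W]
coords = coords.reshape(batch_size, -1, num_frames * height * width)

print(coords.shape)  # torch.Size([2, 3, 120]) — one (t, y, x) column per packed token
```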
```diff
@@ -503,10 +514,10 @@ def _prepare_non_first_frame_conditioning(
         frame_index: int,
         strength: float,
         num_prefix_latent_frames: int = 2,
-        prefix_latents_mode: str = "soft",
+        prefix_latents_mode: str = "concat",
         prefix_soft_conditioning_strength: float = 0.15,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        num_latent_frames = latents.size(2)
+        num_latent_frames = condition_latents.size(2)

         if num_latent_frames < num_prefix_latent_frames:
             raise ValueError(
```
```diff
@@ -544,6 +555,25 @@ def _prepare_non_first_frame_conditioning(

         return latents, condition_latents, condition_latent_frames_mask

+    def trim_conditioning_sequence(
+        self, start_frame: int, sequence_num_frames: int, target_num_frames: int
+    ):
+        """
+        Trim a conditioning sequence to the allowed number of frames.
+        Args:
+            start_frame (int): The target frame number of the first frame in the sequence.
+            sequence_num_frames (int): The number of frames in the sequence.
+            target_num_frames (int): The target number of frames in the generated video.
+        Returns:
+            int: updated sequence length
+        """
+        scale_factor = self.vae_temporal_compression_ratio
+        num_frames = min(sequence_num_frames, target_num_frames - start_frame)
+        # Trim down to a multiple of temporal_scale_factor frames plus 1
+        num_frames = (num_frames - 1) // scale_factor * scale_factor + 1
+        return num_frames
+
+
     def prepare_latents(
         self,
         conditions: Union[LTXVideoCondition, List[LTXVideoCondition]],
```
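To sanity-check the rounding rule above: the LTX VAE consumes clips of `8k + 1` frames, so the trimmed length is snapped down to the nearest such value. A standalone sketch (not pipeline code; the scale factor of 8 is the usual `vae_temporal_compression_ratio` and is assumed here):

```python
# Standalone re-implementation of the trimming rule for illustration,
# assuming vae_temporal_compression_ratio = 8 (the LTX default).
def trim(start_frame: int, sequence_num_frames: int, target_num_frames: int, scale_factor: int = 8) -> int:
    num_frames = min(sequence_num_frames, target_num_frames - start_frame)
    # Round down to a multiple of scale_factor frames, plus 1.
    return (num_frames - 1) // scale_factor * scale_factor + 1

print(trim(0, 49, 121))    # 49 — already of the form 8k + 1, kept as-is
print(trim(0, 50, 121))    # 49 — 50 rounds down to 49
print(trim(100, 49, 121))  # 17 — only 21 target frames remain after frame 100
```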
```diff
@@ -573,13 +603,17 @@ def prepare_latents(
         extra_conditioning_num_latents = (
             0  # Number of extra conditioning latents added (should be removed before decoding)
         )
-        condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=dtype)
+        condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=torch.float32)

         for condition in conditions:
             if condition.image is not None:
                 data = self.video_processor.preprocess(condition.image, height, width).unsqueeze(2)
             elif condition.video is not None:
-                data = self.video_processor.preprocess_video(condition.vide, height, width)
+                data = self.video_processor.preprocess_video(condition.video, height, width)
+                num_frames_input = data.size(2)
+                num_frames_output = self.trim_conditioning_sequence(condition.frame_index, num_frames_input, num_frames)
+                data = data[:, :, :num_frames_output]
+                data = data.to(device, dtype=dtype)
             else:
                 raise ValueError("Either `image` or `video` must be provided in the `LTXVideoCondition`.")
```
```diff
@@ -599,6 +633,7 @@ def prepare_latents(
                     latents[:, :, :num_cond_frames], condition_latents, condition.strength
                 )
                 condition_latent_frames_mask[:, :num_cond_frames] = condition.strength
+
             else:
                 if num_data_frames > 1:
                     (
```
```diff
@@ -617,47 +652,39 @@ def prepare_latents(
                     noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
                     condition_latents = torch.lerp(noise, condition_latents, condition.strength)
                     c_nlf = condition_latents.shape[2]
-                    condition_latents = self._pack_latents(
-                        condition_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+                    condition_latents, rope_interpolation_scale = self._pack_latents(
+                        condition_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device
                     )
+
+                    rope_interpolation_scale = (
+                        rope_interpolation_scale
+                        * torch.tensor([self.vae_temporal_compression_ratio, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio], device=rope_interpolation_scale.device)[None, :, None]
+                    )
+                    rope_interpolation_scale[:, 0] = (rope_interpolation_scale[:, 0] + 1 - self.vae_temporal_compression_ratio).clamp(min=0)
+                    rope_interpolation_scale[:, 0] += condition.frame_index
+
                     conditioning_mask = torch.full(
                         condition_latents.shape[:2], condition.strength, device=device, dtype=dtype
                     )
-
-                    rope_interpolation_scale = [
-                        # TODO!!! This is incorrect: the frame index needs to added AFTER multiplying the interpolation
-                        # scale with the grid.
-                        (self.vae_temporal_compression_ratio + condition.frame_index) / frame_rate,
-                        self.vae_spatial_compression_ratio,
-                        self.vae_spatial_compression_ratio,
-                    ]
```
> **Review comment on lines -627 to -633:**
>
> @yiyixuxu Pardon my stupidity, but I can't seem to find if we're handling this. In the original code, this is what I was meaning to handle: https://github.com/Lightricks/LTX-Video/blob/496dc5058f4408dcb777282f3fb6377fb2da08e6/ltx_video/pipelines/pipeline_ltx_video.py#L1285
```diff
-                    rope_interpolation_scale = (
-                        torch.tensor(rope_interpolation_scale, device=device, dtype=dtype)
-                        .view(-1, 1, 1, 1, 1)
-                        .repeat(1, 1, c_nlf, latent_height, latent_width)
-                    )
                     extra_conditioning_num_latents += condition_latents.size(1)

                     extra_conditioning_latents.append(condition_latents)
                     extra_conditioning_rope_interpolation_scales.append(rope_interpolation_scale)
                     extra_conditioning_mask.append(conditioning_mask)

-        latents = self._pack_latents(
-            latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+        latents, rope_interpolation_scale = self._pack_latents(
+            latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device
         )
-        rope_interpolation_scale = [
-            self.vae_temporal_compression_ratio / frame_rate,
-            self.vae_spatial_compression_ratio,
-            self.vae_spatial_compression_ratio,
-        ]
-        rope_interpolation_scale = (
-            torch.tensor(rope_interpolation_scale, device=device, dtype=dtype)
-            .view(-1, 1, 1, 1, 1)
-            .repeat(1, 1, num_latent_frames, latent_height, latent_width)
-        )
-        conditioning_mask = self._pack_latents(
-            conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
-        )
+        conditioning_mask = condition_latent_frames_mask.gather(
+            1, rope_interpolation_scale[:, 0]
+        )
+
+        rope_interpolation_scale = (
+            rope_interpolation_scale
+            * torch.tensor([self.vae_temporal_compression_ratio, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio], device=rope_interpolation_scale.device)[None, :, None]
+        )
+        rope_interpolation_scale[:, 0] = (rope_interpolation_scale[:, 0] + 1 - self.vae_temporal_compression_ratio).clamp(min=0)

         if len(extra_conditioning_latents) > 0:
             latents = torch.cat([*extra_conditioning_latents, latents], dim=1)
```
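To illustrate what the new per-token computation does, here is a minimal sketch of the latent-coordinate-to-pixel-coordinate mapping from the hunk above. It assumes the standard LTX compression ratios (8 temporal, 32 spatial); all names are local to the example. Note that the coordinates come from `torch.arange`, so they are integer-typed, which is why the `gather` over `condition_latent_frames_mask` works without a cast when done before scaling:

```python
import torch

# Latent-space (t, y, x) coordinates for a tiny grid, in the [B, 3, num_tokens]
# layout returned by the modified _pack_latents.
B, F, H, W = 1, 3, 2, 2
coords = torch.stack(
    torch.meshgrid(torch.arange(F), torch.arange(H), torch.arange(W), indexing="ij"), dim=0
).unsqueeze(0).reshape(B, 3, -1)

temporal_ratio, spatial_ratio = 8, 32  # assumed LTX VAE compression ratios
scales = torch.tensor([temporal_ratio, spatial_ratio, spatial_ratio])[None, :, None]

pixel_coords = coords * scales
# Shift the temporal axis so latent frame 0 maps to pixel frame 0: the causal
# VAE encodes the first frame alone, and each later latent frame covers 8 frames.
pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - temporal_ratio).clamp(min=0)

print(pixel_coords[:, 0])  # tensor([[0, 0, 0, 0, 1, 1, 1, 1, 9, 9, 9, 9]])
```

A `condition.frame_index` offset is then added to the temporal row for conditioning latents, matching the TODO in the removed code: the frame index is applied after multiplying the grid by the compression ratios, not before.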
```diff
@@ -864,7 +891,7 @@ def __call__(
             frame_rate,
             generator,
             device,
-            torch.float32,
+            prompt_embeds.dtype,
         )
         init_latents = latents.clone()
```
```diff
@@ -955,8 +982,8 @@ def __call__(
                 pred_latents = self.scheduler.step(noise_pred, t, noise_latents, return_dict=False)[0]

                 latents = torch.cat([latents[:, :, :1], pred_latents], dim=2)
-                latents = self._pack_latents(
-                    latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+                latents, _ = self._pack_latents(
+                    latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device
                 )

                 if callback_on_step_end is not None:
```