- 
                Notifications
    You must be signed in to change notification settings 
- Fork 6.5k
[WIP] test prepare_latents for ltx0.95 #10976
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
          
     Merged
      
      
    
  
     Merged
                    Changes from 1 commit
      Commits
    
    
            Show all changes
          
          
            18 commits
          
        
        Select commit
          Hold shift + click to select a range
      
      e202e46
              
                up
              
              
                yiyixuxu 267583a
              
                up
              
              
                yiyixuxu 16c1467
              
                up
              
              
                yiyixuxu a098d94
              
                Update src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
              
              
                yiyixuxu d8bd10e
              
                up
              
              
                yiyixuxu 6f7e837
              
                Merge branch 'ltx-95-latents-yiyi' of github.com:huggingface/diffuser…
              
              
                yiyixuxu cbc035d
              
                make it work
              
              
                yiyixuxu 1fdebea
              
                up
              
              
                yiyixuxu 0cc1905
              
                update conversion script
              
              
                yiyixuxu 7c2151f
              
                up
              
              
                yiyixuxu 353728a
              
                up
              
              
                yiyixuxu d85d21c
              
                up
              
              
                yiyixuxu 445cf58
              
                up
              
              
                yiyixuxu fb46d21
              
                up more
              
              
                yiyixuxu 64df9af
              
                up
              
              
                yiyixuxu 00e9670
              
                Apply suggestions from code review
              
              
                yiyixuxu ed2f7e3
              
                add docs tests + more refactor
              
              
                yiyixuxu b98d69c
              
                up
              
              
                yiyixuxu File filter
Filter by extension
Conversations
          Failed to load comments.   
        
        
          
      Loading
        
  Jump to
        
          Jump to file
        
      
      
          Failed to load files.   
        
        
          
      Loading
        
  Diff view
Diff view
There are no files selected for viewing
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -437,8 +437,8 @@ def check_inputs( | |
| ) | ||
|  | ||
| @staticmethod | ||
| # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents | ||
| def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: | ||
| # adapted from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents | ||
| def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1, device: torch.device = None) -> torch.Tensor: | ||
| # Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. | ||
| # The patch dimensions are then permuted and collapsed into the channel dimension of shape: | ||
| # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). | ||
|  | @@ -447,6 +447,16 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int | |
| post_patch_num_frames = num_frames // patch_size_t | ||
| post_patch_height = height // patch_size | ||
| post_patch_width = width // patch_size | ||
|  | ||
| latent_sample_coords = torch.meshgrid( | ||
| torch.arange(0, num_frames, patch_size_t, device=device), | ||
| torch.arange(0, height, patch_size, device=device), | ||
| torch.arange(0, width, patch_size, device=device), | ||
| ) | ||
|         
                  yiyixuxu marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
| latent_sample_coords = torch.stack(latent_sample_coords, dim=0) | ||
| latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) | ||
| latent_coords = latent_coords.reshape(batch_size, -1, num_frames * height * width) | ||
|  | ||
| latents = latents.reshape( | ||
| batch_size, | ||
| -1, | ||
|  | @@ -458,7 +468,7 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int | |
| patch_size, | ||
| ) | ||
| latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) | ||
| return latents | ||
| return latents, latent_coords | ||
|  | ||
| @staticmethod | ||
| # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents | ||
|  | @@ -544,6 +554,25 @@ def _prepare_non_first_frame_conditioning( | |
|  | ||
| return latents, condition_latents, condition_latent_frames_mask | ||
|  | ||
| def trim_conditioning_sequence( | ||
| self, start_frame: int, sequence_num_frames: int, target_num_frames: int | ||
| ): | ||
| """ | ||
| Trim a conditioning sequence to the allowed number of frames. | ||
| Args: | ||
| start_frame (int): The target frame number of the first frame in the sequence. | ||
| sequence_num_frames (int): The number of frames in the sequence. | ||
| target_num_frames (int): The target number of frames in the generated video. | ||
| Returns: | ||
| int: updated sequence length | ||
| """ | ||
| scale_factor = self.vae_temporal_compression_ratio | ||
| num_frames = min(sequence_num_frames, target_num_frames - start_frame) | ||
| # Trim down to a multiple of temporal_scale_factor frames plus 1 | ||
| num_frames = (num_frames - 1) // scale_factor * scale_factor + 1 | ||
| return num_frames | ||
|  | ||
|  | ||
| def prepare_latents( | ||
| self, | ||
| conditions: Union[LTXVideoCondition, List[LTXVideoCondition]], | ||
|  | @@ -579,7 +608,11 @@ def prepare_latents( | |
| if condition.image is not None: | ||
| data = self.video_processor.preprocess(condition.image, height, width).unsqueeze(2) | ||
| elif condition.video is not None: | ||
| data = self.video_processor.preprocess_video(condition.vide, height, width) | ||
| data = self.video_processor.preprocess_video(condition.video, height, width) | ||
| num_frames_input = data.size(2) | ||
| num_frames_output = self.trim_conditioning_sequence(condition.frame_index, num_frames_input, num_frames) | ||
| data = data[:, :, :num_frames_output] | ||
| data = data.to(device, dtype=dtype) | ||
| else: | ||
| raise ValueError("Either `image` or `video` must be provided in the `LTXVideoCondition`.") | ||
|  | ||
|  | @@ -599,6 +632,7 @@ def prepare_latents( | |
| latents[:, :, :num_cond_frames], condition_latents, condition.strength | ||
| ) | ||
| condition_latent_frames_mask[:, :num_cond_frames] = condition.strength | ||
| # YiYi TODO: code path not tested | ||
| else: | ||
| if num_data_frames > 1: | ||
| ( | ||
|  | @@ -617,8 +651,8 @@ def prepare_latents( | |
| noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype) | ||
| condition_latents = torch.lerp(noise, condition_latents, condition.strength) | ||
| c_nlf = condition_latents.shape[2] | ||
| condition_latents = self._pack_latents( | ||
| condition_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size | ||
| condition_latents, condition_latent_coords = self._pack_latents( | ||
| condition_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device | ||
| ) | ||
| conditioning_mask = torch.full( | ||
| condition_latents.shape[:2], condition.strength, device=device, dtype=dtype | ||
|  | @@ -642,23 +676,22 @@ def prepare_latents( | |
| extra_conditioning_rope_interpolation_scales.append(rope_interpolation_scale) | ||
| extra_conditioning_mask.append(conditioning_mask) | ||
|  | ||
| latents = self._pack_latents( | ||
| latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size | ||
| latents, latent_coords = self._pack_latents( | ||
| latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device | ||
| ) | ||
| rope_interpolation_scale = [ | ||
| self.vae_temporal_compression_ratio / frame_rate, | ||
| self.vae_spatial_compression_ratio, | ||
| self.vae_spatial_compression_ratio, | ||
| ] | ||
| rope_interpolation_scale = ( | ||
| torch.tensor(rope_interpolation_scale, device=device, dtype=dtype) | ||
| .view(-1, 1, 1, 1, 1) | ||
| .repeat(1, 1, num_latent_frames, latent_height, latent_width) | ||
| pixel_coords = ( | ||
| latent_coords | ||
| * torch.tensor([self.vae_temporal_compression_ratio, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio], device=latent_coords.device)[None, :, None] | ||
| ) | ||
| conditioning_mask = self._pack_latents( | ||
| conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size | ||
| pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - self.vae_temporal_compression_ratio).clamp(min=0) | ||
|  | ||
| rope_interpolation_scale = pixel_coords | ||
|  | ||
| conditioning_mask = condition_latent_frames_mask.gather( | ||
| 1, latent_coords[:, 0] | ||
| ) | ||
|  | ||
| # YiYi TODO: code path not tested | ||
| if len(extra_conditioning_latents) > 0: | ||
| latents = torch.cat([*extra_conditioning_latents, latents], dim=1) | ||
| rope_interpolation_scale = torch.cat( | ||
|  | @@ -864,7 +897,7 @@ def __call__( | |
| frame_rate, | ||
| generator, | ||
| device, | ||
| torch.float32, | ||
| prompt_embeds.dtype, | ||
|          | ||
| ) | ||
| init_latents = latents.clone() | ||
|  | ||
|  | @@ -955,8 +988,8 @@ def __call__( | |
| pred_latents = self.scheduler.step(noise_pred, t, noise_latents, return_dict=False)[0] | ||
|  | ||
| latents = torch.cat([latents[:, :, :1], pred_latents], dim=2) | ||
| latents = self._pack_latents( | ||
| latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size | ||
| latents, _ = self._pack_latents( | ||
| latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device | ||
| ) | ||
|  | ||
| if callback_on_step_end is not None: | ||
|  | ||
  Add this suggestion to a batch that can be applied as a single commit.
  This suggestion is invalid because no changes were made to the code.
  Suggestions cannot be applied while the pull request is closed.
  Suggestions cannot be applied while viewing a subset of changes.
  Only one suggestion per line can be applied in a batch.
  Add this suggestion to a batch that can be applied as a single commit.
  Applying suggestions on deleted lines is not supported.
  You must change the existing code in this line in order to create a valid suggestion.
  Outdated suggestions cannot be applied.
  This suggestion has been applied or marked resolved.
  Suggestions cannot be applied from pending reviews.
  Suggestions cannot be applied on multi-line comments.
  Suggestions cannot be applied while the pull request is queued to merge.
  Suggestion cannot be applied right now. Please check back later.
  
    
  
    
Uh oh!
There was an error while loading. Please reload this page.