@@ -96,8 +96,7 @@ def compute_mixed_rotation(
         num_heads: int
 
     Returns:
-        freqs_cos: [N, num_heads, num_freqs] - cosine components
-        freqs_sin: [N, num_heads, num_freqs] - sine components
+        freqs_cos: [N, num_heads, num_freqs] - cosine components freqs_sin: [N, num_heads, num_freqs] - sine components
     """
     with torch.autocast("cuda", enabled=False):
         assert freqs.ndim == 3
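
A minimal sketch of the mixed-rotation computation documented above, assuming freqs holds learned per-axis, per-head frequencies of shape [3, num_heads, num_freqs] and pos holds each token's (t, row, col) coordinates of shape [N, 3]; the function name and einsum layout are illustrative, not necessarily the repo's exact signature:

import torch

def compute_mixed_rotation_sketch(freqs, pos):
    # freqs: [3, num_heads, num_freqs] learned frequencies per position axis (assumed layout)
    # pos:   [N, 3] (t, row, col) coordinates of each token (assumed layout)
    with torch.autocast("cuda", enabled=False):  # keep the trig math in fp32
        assert freqs.ndim == 3
        # Mix the three positional axes into one angle per (token, head, frequency).
        angles = torch.einsum("Nd,dhf->Nhf", pos.to(freqs), freqs)
        freqs_cos = torch.cos(angles)  # [N, num_heads, num_freqs]
        freqs_sin = torch.sin(angles)  # [N, num_heads, num_freqs]
    return freqs_cos, freqs_sin
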
@@ -470,8 +469,7 @@ def forward(
             num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens
 
         Returns:
-            x: (B, N, dim) tensor of visual tokens after block
-            y: (B, L, dim) tensor of text tokens after block
+            x: (B, N, dim) tensor of visual tokens after block y: (B, L, dim) tensor of text tokens after block
         """
         breakpoint()
         N = x.size(1)
@@ -651,7 +649,7 @@ def run_attention(
         breakpoint()
         N = M
         local_heads = self.num_heads
-        local_dim = local_heads * self.head_dim
+        # local_dim = local_heads * self.head_dim
         # with torch.autocast("cuda", enabled=False):
         #     out: torch.Tensor = flash_attn_varlen_qkvpacked_func(
         #         qkv,
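
The commented-out block above is the packed (varlen) FlashAttention path. A hedged usage sketch, assuming qkv has been packed to (total_tokens, 3, num_heads, head_dim) in bf16 and that cu_seqlens and max_seqlen come from the precomputed packed indices:

import torch
from flash_attn import flash_attn_varlen_qkvpacked_func

def run_packed_attention(qkv: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: int) -> torch.Tensor:
    # qkv: (total_tokens, 3, num_heads, head_dim), all sequences concatenated
    # cu_seqlens: (batch_size + 1,) int32 cumulative sequence lengths
    with torch.autocast("cuda", enabled=False):  # flash-attn consumes fp16/bf16 inputs directly
        out = flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen)
    return out  # (total_tokens, num_heads, head_dim)
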
@@ -696,8 +694,8 @@ def forward(
             num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens
 
         Returns:
-            x: (B, N, dim_x) tensor of visual tokens after multi-modal attention y: (B, L, dim_y) tensor of text token
+            features after multi-modal attention
         """
         B, L, _ = y.shape
         _, M, _ = x.shape
@@ -725,6 +723,7 @@ def forward(
         )
         return x, y
 
+
 def apply_rotary_emb_qk_real(
     xqk: torch.Tensor,
     freqs_cos: torch.Tensor,
@@ -756,10 +755,10 @@ def apply_rotary_emb_qk_real(
     # assert out.dtype == torch.bfloat16
     return out
 
+
 class PadSplitXY(torch.autograd.Function):
     """
-    Merge heads, pad and extract visual and text tokens,
-    and split along the sequence length.
+    Merge heads, pad and extract visual and text tokens, and split along the sequence length.
     """
 
     @staticmethod
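
The hunk above also closes out apply_rotary_emb_qk_real. A minimal sketch of the real-valued rotary application that function performs, assuming the head dimension interleaves (even, odd) channel pairs and that freqs_cos/freqs_sin broadcast against the half-size channel axis:

import torch

def apply_rotary_emb_qk_real_sketch(xqk, freqs_cos, freqs_sin):
    # xqk: (..., num_heads, head_dim); freqs_cos/freqs_sin: broadcastable to (..., num_heads, head_dim // 2)
    x_even, x_odd = xqk[..., 0::2], xqk[..., 1::2]
    # Standard 2D rotation applied to each (even, odd) channel pair.
    out_even = x_even * freqs_cos - x_odd * freqs_sin
    out_odd = x_even * freqs_sin + x_odd * freqs_cos
    # Re-interleave the rotated pairs back into the original channel order.
    return torch.stack((out_even, out_odd), dim=-1).flatten(-2).type_as(xqk)
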
@@ -778,8 +777,7 @@ def forward(
             indices: Valid token indices out of unpacked tensor. Shape: (total,)
 
         Returns:
-            x: Visual tokens. Shape: (B, N, num_heads * head_dim).
-            y: Text tokens. Shape: (B, L, num_heads * head_dim).
+            x: Visual tokens. Shape: (B, N, num_heads * head_dim). y: Text tokens. Shape: (B, L, num_heads * head_dim).
         """
         ctx.save_for_backward(indices)
         ctx.B, ctx.N, ctx.L = B, N, L
@@ -788,9 +786,7 @@ def forward(
         # Pad sequences to (B, N + L, dim).
         assert indices.ndim == 1
         output = torch.zeros(B * (N + L), D, device=xy.device, dtype=dtype)
-        indices = indices.unsqueeze(1).expand(
-            -1, D
-        )  # (total,) -> (total, num_heads * head_dim)
+        indices = indices.unsqueeze(1).expand(-1, D)  # (total,) -> (total, num_heads * head_dim)
         output.scatter_(0, indices, xy)
         xy = output.view(B, N + L, D)
@@ -801,6 +797,7 @@ def forward(
 def pad_and_split_xy(xy, indices, B, N, L, dtype) -> Tuple[torch.Tensor, torch.Tensor]:
     return PadSplitXY.apply(xy, indices, B, N, L, dtype)
 
+
 class UnifyStreams(torch.autograd.Function):
     """Unify visual and text streams."""
 
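
UnifyStreams is presumably the inverse of PadSplitXY. A hedged sketch of that inverse, assuming it concatenates the two streams along the sequence and gathers back only the valid (non-padding) tokens; this is an assumption from the one-line docstring, not the repo's verified implementation:

import torch

def unify_streams_sketch(x: torch.Tensor, y: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # x: (B, N, D) visual tokens; y: (B, L, D) text tokens; indices: (total,) valid flattened slots
    D = x.shape[-1]
    xy = torch.cat([x, y], dim=1).reshape(-1, D)              # (B * (N + L), D)
    return xy.gather(0, indices.unsqueeze(1).expand(-1, D))  # (total, D) packed valid tokens
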
@@ -1034,7 +1031,9 @@ def forward(
         Args:
             x: (B, C, T, H, W) tensor of spatial inputs (images or latent representations of images)
             sigma: (B,) tensor of noise standard deviations
-            y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048)
+            y_feat:
+                List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77,
+                y_feat_dim=2048)
             y_mask: List((B, L) boolean tensor indicating which tokens are not padding)
             packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices.