@@ -136,13 +136,22 @@ def forward(
136136 core_attn_out = self ._checkpointed_attention_forward (
137137 query , key , value , attention_mask , packed_seq_params = packed_seq_params )
138138 else :
139+ extra_kwargs = {}
140+ if self .config .experimental_attention_variant == 'dsa' :
141+ # For dsa we need to pass in the original hidden states and the compressed
142+ # query representation.
143+ extra_kwargs ['x' ] = hidden_states
144+ extra_kwargs ['qr' ] = q_compressed
145+ # for easy injection of rotary_pos_emb (patch)
146+ packed_seq_params = (packed_seq_params , rotary_pos_emb )
139147 core_attn_out = self .core_attention (
140148 query ,
141149 key ,
142150 value ,
143151 attention_mask ,
144152 packed_seq_params = packed_seq_params ,
145153 attn_mask_type = attn_mask_type ,
154+ ** extra_kwargs ,
146155 )
147156 if thd_qkv_format :
148157 if core_attn_out .ndim == 2 :
@@ -789,6 +798,152 @@ def _new_load_inline(*args, **kwargs):
789798 cpp_extension .load_inline = load_inline
790799
791800
def _patch_dsa():
    """Monkey-patch Megatron's DSAIndexer so it can apply RoPE itself.

    The companion patch in ``Attention.forward`` packs ``rotary_pos_emb``
    into ``packed_seq_params`` as ``(packed_seq_params, rotary_pos_emb)``
    and forwards the original hidden states ``x`` and the compressed query
    ``qr`` to the core attention. This patched indexer unpacks that tuple
    and applies rotary embeddings to its own q/k projections.

    Raises:
        ImportError: propagated from the ``megatron.core`` imports below when
            the installed megatron-core lacks the experimental DSA modules;
            the caller (``init_megatron_env``) catches it and skips the patch.
    """
    from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb
    from megatron.core.models.gpt import experimental_attention_variant_module_specs
    from megatron.core.packed_seq_params import PackedSeqParams
    from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
    from megatron.core.transformer.experimental_attention_variant.dsa import rotate_activation
    DSAIndexer = experimental_attention_variant_module_specs.DSAIndexer

    class NewDSAIndexer(DSAIndexer):
        """DSAIndexer variant that receives ``rotary_pos_emb`` smuggled
        inside ``packed_seq_params`` (see module docstring) instead of
        computing positional embeddings internally."""

        def forward_before_topk(
            self,
            x: torch.Tensor,
            qr: torch.Tensor,
            packed_seq_params: Optional[PackedSeqParams] = None,
        ):
            """All computations before topk.

            Args:
                x: hidden states [seqlen, batch, hidden_size].
                qr: low-rank (compressed) query tensor [seqlen, batch, q_lora_rank].
                packed_seq_params: despite the annotation, the patch in
                    ``Attention.forward`` passes a 2-tuple
                    ``(packed_seq_params, rotary_pos_emb)`` here; the real
                    ``PackedSeqParams`` element must be ``None``.

            Returns:
                Tuple of (q, k, weights) used by the fused top-k kernel:
                q [seqlen, batch, index_n_heads, index_head_dim],
                k [seqlen, batch, index_head_dim],
                weights [seqlen, batch, index_n_heads].
            """
            # =========================================
            # Gather inputs if sp is enabled
            # =========================================
            # Unpack the (packed_seq_params, rotary_pos_emb) tuple injected
            # by the patched Attention.forward.
            packed_seq_params, rotary_pos_emb = packed_seq_params  # patch
            assert packed_seq_params is None, 'Packed sequence is not supported for DSAttention'

            # With sequence parallelism the indexer needs the full sequence,
            # so gather x and qr across the TP group first.
            if self.config.sequence_parallel and self.pg_collection.tp.size() > 1:
                x = gather_from_sequence_parallel_region(x, group=self.pg_collection.tp)
                qr = gather_from_sequence_parallel_region(qr, group=self.pg_collection.tp)

            # =========================================
            # Get sequence length and batch size
            # =========================================
            seqlen, bsz, _ = x.size()

            # =========================================
            # q linear and apply rope to q
            # =========================================
            # [seqlen, batch, q_lora_rank] -> [seqlen, batch, index_n_heads * index_head_dim]
            q, _ = self.linear_wq_b(qr)
            # [seqlen, batch, index_n_heads * index_head_dim]
            # -> [seqlen, batch, index_n_heads, index_head_dim]
            q = q.reshape(seqlen, bsz, self.index_n_heads, self.index_head_dim)
            q = self._apply_rope(q, rotary_pos_emb)  # mscale will be passed in by patch

            # =========================================
            # k linear and apply rope to k
            # =========================================
            # [seqlen, batch, hidden_size] -> [seqlen, batch, index_head_dim]
            k, _ = self.linear_wk(x)
            k = self.k_norm(k)
            # [seqlen, batch, index_head_dim] -> [seqlen, batch, 1, index_head_dim]
            # Single "head" view so _apply_rope sees the same rank as q.
            k = k.reshape(seqlen, bsz, 1, self.index_head_dim)
            k = self._apply_rope(k, rotary_pos_emb)
            # [seqlen, batch, 1, index_head_dim] -> [seqlen, batch, index_head_dim]
            k = k.reshape(seqlen, bsz, self.index_head_dim)

            # =========================================
            # Rotate activation
            # =========================================
            q = rotate_activation(q)
            k = rotate_activation(k)

            # =========================================
            # Prepare weights for index scores
            # =========================================
            # [seqlen, batch, hidden_size] -> [seqlen, batch, index_n_heads]
            weights, _ = self.linear_weights_proj(x)
            # Scale per-head weights by 1/sqrt(index_n_heads) and the
            # attention softmax scale before they weight the index scores.
            weights = weights * (self.index_n_heads ** -0.5) * self.softmax_scale

            return q, k, weights

        def _apply_rope(self, x: torch.Tensor, rotary_pos_emb: torch.Tensor):
            """Apply RoPE to the leading slice of the last dim of *x*.

            Splits the head dim into a rotated part and a pass-through part,
            applies ``apply_rotary_pos_emb`` to the first chunk only, then
            re-concatenates in the original order.

            NOTE(review): the names/comments look swapped relative to the
            split sizes — ``x_pe`` receives the first chunk of size
            ``index_head_dim - qk_pos_emb_head_dim``, while the shape
            comments below describe that size as ``x_nope``. With the common
            config (e.g. 128 - 64 == 64) both chunks are the same size so
            this runs either way; confirm against the upstream DSAIndexer
            which half is meant to be rotated.
            """
            # x_nope [seqlen, batch, *, index_head_dim - qk_pos_emb_head_dim]
            # x_pe [seqlen, batch, *, qk_pos_emb_head_dim]
            x_pe, x_nope = torch.split(
                x, [self.index_head_dim - self.qk_pos_emb_head_dim, self.qk_pos_emb_head_dim], dim=-1)
            x_pe = apply_rotary_pos_emb(
                x_pe,
                rotary_pos_emb,
                config=self.config,
                cu_seqlens=None,  # packed sequences are asserted out above
                cp_group=self.pg_collection.cp,
            )
            # [seqlen, batch, *, index_head_dim]
            x = torch.cat([x_pe, x_nope], dim=-1)
            return x

        def forward_with_scores(
            self,
            x: torch.Tensor,
            qr: torch.Tensor,
            mask: Optional[torch.Tensor] = None,
            packed_seq_params: Optional[PackedSeqParams] = None,
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            """
            Forward pass for DSA Indexer that returns both index scores and top-k indices.

            This is used when KL loss is enabled to compare indexer scores with true attention scores.

            Args:
                x: hidden states [seqlen, batch, hidden_size].
                qr: Low-rank query tensor [seqlen, batch, q_lora_rank].
                mask: Attention mask [batch, seqlen, seqlen].
                packed_seq_params: Packed sequence parameters for variable length sequences.

            Returns:
                index_scores: Index scores [batch, seqlen, seqlen].
                topk_indices: Top-k indices [batch, seqlen, index_topk].
            """
            # Imported lazily so older megatron-core installs fail here with
            # an actionable message rather than at module import time.
            try:
                from megatron.core.transformer.experimental_attention_variant.dsa import fused_qk_topk_naive
            except ImportError:
                raise ImportError('fused_qk_topk_naive is not available. Please install megatron-core from source. '
                                  '`pip install git+https://github.com/NVIDIA/Megatron-LM.git`')
            # [seqlen, batch, index_n_heads * index_head_dim]
            # [seqlen, batch, index_head_dim]
            # [seqlen, batch, index_n_heads]
            q, k, weights = self.forward_before_topk(x, qr, packed_seq_params)

            # [batch, seqlen, seqlen], [batch, seqlen, index_topk]
            index_scores, topk_indices = fused_qk_topk_naive(q, k, weights, self.index_topk, mask)

            return index_scores, topk_indices

        def forward(self,
                    x: torch.Tensor,
                    qr: torch.Tensor,
                    mask: Optional[torch.Tensor] = None,
                    packed_seq_params: Optional[PackedSeqParams] = None):
            """
            Forward pass for DSA Indexer.

            Args:
                x: hidden states [seqlen, batch, hidden_size].
                qr: Low-rank query tensor [seqlen, batch, q_lora_rank].
                mask: Attention mask [batch, seqlen, seqlen].
                packed_seq_params: Packed sequence parameters for variable length sequences.

            Returns:
                topk_indices: Top-k indices for sparse attention [batch, seqlen, index_topk].
            """
            # Scores are computed anyway; discard them when only indices are needed.
            _, topk_indices = self.forward_with_scores(x, qr, mask, packed_seq_params)
            return topk_indices

    # Install the patched class so subsequent spec construction picks it up.
    experimental_attention_variant_module_specs.DSAIndexer = NewDSAIndexer
945+
946+
792947def init_megatron_env ():
793948 os .environ .pop ('VLLM_USE_MODELSCOPE' , None )
794949 logging_level = logging .root .level
@@ -804,6 +959,10 @@ def init_megatron_env():
804959 _patch_mrope ()
805960 _patch__write_item ()
806961 _patch_mtp ()
962+ try :
963+ _patch_dsa ()
964+ except ImportError :
965+ pass
807966 logging .root .setLevel (logging_level ) # revert logger level
808967 from swift .megatron import tuners # patch lora
809968 try :
0 commit comments