import torch.nn as nn
import torch.nn.functional as F

-
-def _apply_logit_softcapping(attn_scores: torch.Tensor, logit_cap: Optional[float]) -> torch.Tensor:
-    """Apply logit softcapping using the formula: logit_cap * tanh(logits / logit_cap)"""
-    if logit_cap is not None and logit_cap > 0.0:
-        return logit_cap * torch.tanh(attn_scores / logit_cap)
-    return attn_scores
+# TODO (nvchenghaoz): Remove related kernels once we have a backend-specific implementation for attention.


@torch.library.custom_op("auto_deploy::torch_attention_repeat_kv", mutates_args=())
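Note: the softcapping removed above is a plain tanh squash of the attention scores, near the identity while the logits are small relative to logit_cap and saturating toward +/- logit_cap otherwise. A standalone sketch with illustrative values (not part of the file):

import torch

def softcap(scores: torch.Tensor, cap: float) -> torch.Tensor:
    # cap * tanh(scores / cap): roughly the identity for |scores| << cap,
    # saturating toward +/- cap for large magnitudes.
    return cap * torch.tanh(scores / cap)

scores = torch.tensor([-200.0, -5.0, 0.0, 5.0, 200.0])
print(softcap(scores, 50.0))  # all values squashed into (-50, 50)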
@@ -82,86 +77,19 @@ def grouped_sdpa(
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
-    sinks: Optional[torch.Tensor] = None,
-    sliding_window: Optional[int] = None,
-    logit_cap: Optional[float] = None,
) -> torch.Tensor:
89- """SDPA attention that can handle GQA. Expects bnsd format inputs."""
90- b , n_heads , s_q , head_dim = query .shape # bnsd format: [batch, num_heads, seq_len, head_dim]
91- _ , n_kv_heads , s_k , _ = key .shape # bnsd format: [batch, num_kv_heads, seq_len, head_dim]
92-
93- # Inputs are already in bnsd format, no need to transpose
94- query_t = query # [b, n_heads, s_q, head_dim]
95- key_t = key # [b, n_kv_heads, s_k, head_dim]
96- value_t = value # [b, n_kv_heads, s_k, v_head_dim]
97-
98- # Handle GQA by repeating KV if needed
99- if n_heads != n_kv_heads :
100- n_rep = n_heads // n_kv_heads
101- key_t = repeat_kv (key_t , n_rep ) # [b, n_heads, s_k, head_dim]
102- value_t = repeat_kv (value_t , n_rep ) # [b, n_heads, s_k, v_head_dim]
103-
104- # Set scale
105- if scale is None :
106- scale = 1.0 / math .sqrt (head_dim )
107-
108- # Compute attention scores: Q @ K^T
109- attn_scores = torch .matmul (query_t , key_t .transpose (- 2 , - 1 )) * scale # [b, n_heads, s_q, s_k]
110-
111- # Apply attention mask if provided
112- if attn_mask is not None :
113- attn_scores = attn_scores + attn_mask
114-
115- # Apply causal mask if specified
116- if is_causal :
117- causal_mask = torch .triu (
118- torch .ones (s_q , s_k , device = query .device , dtype = torch .bool ),
119- diagonal = s_k - s_q + 1 ,
120- )
121- attn_scores .masked_fill_ (causal_mask .unsqueeze (0 ).unsqueeze (0 ), float ("-inf" ))
122-
123- # Apply sliding window mask if specified
124- if sliding_window is not None and sliding_window > 0 :
125- # Create sliding window mask: each query position i can only attend to keys in [i-window_size+1, i]
126- query_positions = torch .arange (s_q , device = query .device ) # [s_q]
127- key_positions = torch .arange (s_k , device = query .device ) # [s_k]
128-
129- # Create position difference matrix: query_pos - key_pos
130- pos_diff = query_positions .unsqueeze (1 ) - key_positions .unsqueeze (0 ) # [s_q, s_k]
131-
132- # Sliding window mask: allow attention only if 0 <= pos_diff < sliding_window_size
133- sliding_window_mask = (pos_diff < 0 ) | (pos_diff >= sliding_window ) # [s_q, s_k]
134- attn_scores .masked_fill_ (sliding_window_mask .unsqueeze (0 ).unsqueeze (0 ), float ("-inf" ))
135-
136- # Apply logit softcapping if enabled
137- attn_scores = _apply_logit_softcapping (attn_scores , logit_cap )
138-
139- # Apply sinks if provided
140- if sinks is not None :
141- # Concatenate sinks to attention scores following the reference implementation
142- # sinks should have n_heads elements, each head gets its own sink value
143- # Expand sinks to [b, n_heads, s_q, 1] - one sink column per head
144- sinks_expanded = sinks .reshape (1 , - 1 , 1 , 1 ).expand (
145- b , n_heads , s_q , 1
146- ) # [b, n_heads, s_q, 1]
147-
148- # Concatenate along the key dimension (last dimension)
149- attn_weights = torch .cat ([attn_scores , sinks_expanded ], dim = - 1 )
150- attn_weights = torch .softmax (attn_weights , dim = - 1 , dtype = torch .float32 ).to (query .dtype )
151-
152- # Use only the non-sink portion for computing output
153- # We added exactly 1 column, so remove exactly 1 column
154- attn_out = torch .matmul (attn_weights [..., :- 1 ], value_t ) # [b, n_heads, s_q, v_head_dim]
155- else :
156- attn_weights = torch .softmax (attn_scores , dim = - 1 , dtype = torch .float32 ).to (query .dtype )
157- attn_out = torch .matmul (attn_weights , value_t ) # [b, n_heads, s_q, v_head_dim]
81+ """SDPA attention that can handle GQA."""
15882
159- # Apply dropout if specified
160- if dropout_p > 0.0 :
161- attn_out = F .dropout (attn_out , p = dropout_p , training = False )
162-
163- # Return in bnsd format (same as input format)
164- return attn_out
83+ return F .scaled_dot_product_attention (
84+ query .contiguous (),
85+ key .contiguous (),
86+ value .contiguous (),
87+ attn_mask = attn_mask ,
88+ dropout_p = dropout_p ,
89+ is_causal = is_causal ,
90+ scale = scale ,
91+ enable_gqa = True ,
92+ )
16593
16694
16795@grouped_sdpa .register_fake
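The removed path above handled GQA by materializing repeated KV heads before a manual softmax; the new call delegates that to SDPA via enable_gqa=True. A minimal sanity-check sketch of the equivalence for the plain causal case (no sinks, sliding window, or logit capping); it assumes a recent PyTorch where enable_gqa is available, and the shapes are illustrative:

import torch
import torch.nn.functional as F

b, n_heads, n_kv_heads, s, d = 2, 8, 2, 16, 64
q = torch.randn(b, n_heads, s, d)
k = torch.randn(b, n_kv_heads, s, d)
v = torch.randn(b, n_kv_heads, s, d)

# Manual GQA: repeat each KV head (interleaved), mirroring a repeat_kv-style expansion.
n_rep = n_heads // n_kv_heads
k_rep = k.repeat_interleave(n_rep, dim=1)
v_rep = v.repeat_interleave(n_rep, dim=1)

out_manual = F.scaled_dot_product_attention(q, k_rep, v_rep, is_causal=True)
out_gqa = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)

print(torch.allclose(out_manual, out_gqa, atol=1e-5))  # should print True (up to numerical tolerance)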
@@ -173,19 +101,16 @@ def grouped_sdpa_fake(
    dropout_p=0.0,
    is_causal=False,
    scale=None,
-    sinks=None,
-    sliding_window=None,
-    logit_cap=None,
):
    """Fake implementation of grouped SDPA."""
    return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()


@torch.library.custom_op("auto_deploy::torch_attention_bsnd_grouped_sdpa", mutates_args=())
def bsnd_grouped_sdpa(
-    query: torch.Tensor,  # layout: [b, s_q, n, d]
-    key: torch.Tensor,  # layout: [b, s_k, n, d]
-    value: torch.Tensor,  # layout: [b, s_k, n, d]
+    query: torch.Tensor,  # layout: [b, n, s_q, d]
+    key: torch.Tensor,  # layout: [b, n, s_k, d]
+    value: torch.Tensor,  # layout: [b, n, s_k, d]
    attn_mask: Optional[torch.Tensor] = None,  # layout: [b, n, s_q, s_k]
    dropout_p: float = 0.0,
    is_causal: bool = False,
@@ -199,16 +124,14 @@ def bsnd_grouped_sdpa(
    Note that attn_mask layout is still assumed to be [b, n, s_q, s_k] and is consistent with the
    original sdpa op!
    """
-    # Transpose inputs to bnsd format for grouped_sdpa
-    query = query.transpose(1, 2)  # [b, s_q, n, d] -> [b, n, s_q, d]
-    key = key.transpose(1, 2)  # [b, s_k, n, d] -> [b, n, s_k, d]
-    value = value.transpose(1, 2)  # [b, s_k, n, d] -> [b, n, s_k, d]
-
-    # Call grouped_sdpa with bnsd inputs
-    out = grouped_sdpa(
-        query, key, value, attn_mask, dropout_p, is_causal, scale, sinks, sliding_window, logit_cap
-    )
-    # Transpose back to bsnd format
+    # let's transpose to bnsd so we can use the grouped sdpa
+    query = query.transpose(1, 2).contiguous()
+    key = key.transpose(1, 2).contiguous()
+    value = value.transpose(1, 2).contiguous()
+
+    out = grouped_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale)
+
+    # let's transpose back to bsnd
    return out.transpose(1, 2).contiguous()
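
With the extra arguments gone, the bsnd op takes only the standard SDPA parameters. A usage sketch, assuming the module that registers these custom ops has been imported so the torch.ops.auto_deploy namespace is populated; the shapes are illustrative:

import torch
# (plus whatever import registers the auto_deploy custom ops)

b, s, n_heads, n_kv_heads, d = 1, 32, 8, 2, 64
q = torch.randn(b, s, n_heads, d)     # bsnd layout: [batch, seq, heads, head_dim]
k = torch.randn(b, s, n_kv_heads, d)
v = torch.randn(b, s, n_kv_heads, d)

# Ops registered via torch.library.custom_op are exposed under torch.ops.<namespace>.
out = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa(
    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True, scale=None
)
print(out.shape)  # torch.Size([1, 32, 8, 64]) -- output stays in bsnd layout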