import torch.nn as nn
import torch.nn.functional as F

-# TODO (nvchenghaoz): Remove related kernels once we have a backend-specific implementation for attention.
+
+def _apply_logit_softcapping(attn_scores: torch.Tensor, logit_cap: Optional[float]) -> torch.Tensor:
+    """Apply logit softcapping using the formula: logit_cap * tanh(logits / logit_cap)"""
+    if logit_cap is not None and logit_cap > 0.0:
+        return logit_cap * torch.tanh(attn_scores / logit_cap)
+    return attn_scores
+
+
+def _convert_boolean_mask_to_float(attn_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+    """Convert boolean attention mask to floating point mask.
+    Args:
+        attn_mask: Boolean tensor where True allows attention, False blocks it
+        dtype: Target dtype for the output mask
+    Returns:
+        Floating point mask where True -> 1.0, False -> -inf
+    """
+    if attn_mask.dtype == torch.bool:
+        float_mask = torch.zeros_like(attn_mask, dtype=dtype)
+        float_mask = float_mask.masked_fill(attn_mask, 1.0)  # True -> 1.0
+        float_mask = float_mask.masked_fill(~attn_mask, float("-inf"))  # False -> -inf
+        return float_mask
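+    # Non-boolean masks are assumed to already be additive float masks and are returned unchanged.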
+    return attn_mask


@torch.library.custom_op("auto_deploy::torch_attention_repeat_kv", mutates_args=())
@@ -77,19 +98,96 @@ def grouped_sdpa(
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
+    sinks: Optional[torch.Tensor] = None,
+    sliding_window: Optional[int] = None,
+    logit_cap: Optional[float] = None,
) -> torch.Tensor:
-    """SDPA attention that can handle GQA."""
+    """SDPA attention that can handle GQA. Expects bnsd format inputs."""
+    b, n_heads, s_q, head_dim = query.shape  # bnsd format: [batch, num_heads, seq_len, head_dim]
+    _, n_kv_heads, s_k, _ = key.shape  # bnsd format: [batch, num_kv_heads, seq_len, head_dim]
+
+    # Inputs are already in bnsd format, no need to transpose
+    query_t = query  # [b, n_heads, s_q, head_dim]
+    key_t = key  # [b, n_kv_heads, s_k, head_dim]
+    value_t = value  # [b, n_kv_heads, s_k, v_head_dim]
+
+    # Handle GQA by repeating KV if needed
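+    # e.g. 32 query heads with 8 KV heads gives n_rep = 4, so each KV head is repeated 4 times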
+    if n_heads != n_kv_heads:
+        n_rep = n_heads // n_kv_heads
+        key_t = repeat_kv(key_t, n_rep)
+        value_t = repeat_kv(value_t, n_rep)
+
+    # Set scale
+    if scale is None:
+        scale = 1.0 / math.sqrt(head_dim)
+
+    # Compute attention scores: Q @ K^T
+    attn_scores = torch.matmul(query_t, key_t.transpose(-2, -1)) * scale  # [b, n_heads, s_q, s_k]
+
+    # Apply attention mask if provided
+    if attn_mask is not None:
+        # Convert boolean mask to float if needed
+        attn_mask = _convert_boolean_mask_to_float(attn_mask, attn_scores.dtype)
+        attn_scores = attn_scores + attn_mask
+
+    # Apply causal mask if specified and only during the context phase
+    if is_causal and s_q == s_k:  # Only apply causal mask during context processing
+        causal_mask = torch.triu(
+            torch.ones(s_q, s_k, device=query.device, dtype=torch.bool),
+            diagonal=1,  # Use diagonal=1 for standard causal masking
+        )
+        attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf"))
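+    # During generation (s_q != s_k), the op assumes all cached keys precede the new queries,
+    # so no explicit causal mask is applied.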
+
+    # Apply sliding window mask if specified
+    if sliding_window is not None and sliding_window > 0:
+        # Handle position calculation for both context and generation phases
+        if s_q == s_k:
+            # Context phase: standard position calculation
+            query_positions = torch.arange(s_q, device=query.device)
+            key_positions = torch.arange(s_k, device=query.device)
+        else:
+            # Generation phase: queries sit after the cache, at positions s_k, ..., s_k + s_q - 1
+            query_positions = torch.arange(s_k, s_k + s_q, device=query.device)  # [s_k] for s_q = 1
+            key_positions = torch.arange(s_k, device=query.device)  # [0, 1, 2, ..., s_k - 1]
+
+        # Create position difference matrix: query_pos - key_pos
+        pos_diff = query_positions.unsqueeze(1) - key_positions.unsqueeze(0)  # [s_q, s_k]
+
+        # Sliding window mask: allow attention only if 0 <= pos_diff < sliding_window
+        sliding_window_mask = (pos_diff < 0) | (pos_diff >= sliding_window)  # [s_q, s_k]
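+        # e.g. with sliding_window = 4, a query at position 10 can attend to keys at positions 7-10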
+        attn_scores.masked_fill_(sliding_window_mask.unsqueeze(0).unsqueeze(0), float("-inf"))
+
+    # Apply logit softcapping if enabled
+    attn_scores = _apply_logit_softcapping(attn_scores, logit_cap)
+
+    # Apply sinks if provided
+    if sinks is not None:
+        # Following the reference implementation, each head contributes its own sink value as an
+        # extra term in the softmax normalizer; sinks should have n_heads elements.
+        # Expand sinks to [b, n_heads, s_q, 1] - one sink column per head
+        sinks_expanded = sinks.reshape(1, -1, 1, 1).expand(
+            b, n_heads, s_q, 1
+        )  # [b, n_heads, s_q, 1]
+
+        # Numerically stable softmax with the sink term added to the normalizer; this is
+        # equivalent to appending one sink logit per head and dropping that column afterwards
+        logits_max = torch.max(attn_scores, dim=-1, keepdim=True).values
+        sinks = torch.exp(sinks_expanded - logits_max)
+        unnormalized_scores = torch.exp(attn_scores - logits_max)
+        normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
+        scores = unnormalized_scores / normalizer
+        # The sink only enters the normalizer, so the output is computed from the regular scores
+        attn_out = torch.matmul(scores, value_t)  # [b, n_heads, s_q, v_head_dim]
+    else:
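+        # Softmax in float32 for numerical stability, then cast back to the query dtype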
+        attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(query.dtype)
+        attn_out = torch.matmul(attn_weights, value_t)  # [b, n_heads, s_q, v_head_dim]

-    return F.scaled_dot_product_attention(
-        query.contiguous(),
-        key.contiguous(),
-        value.contiguous(),
-        attn_mask=attn_mask,
-        dropout_p=dropout_p,
-        is_causal=is_causal,
-        scale=scale,
-        enable_gqa=True,
-    )
+    # Apply dropout if specified
+    if dropout_p > 0.0:
+        attn_out = F.dropout(attn_out, p=dropout_p, training=False)
+
+    # Return in bnsd format (same as input format)
+    return attn_out


@grouped_sdpa.register_fake
@@ -101,16 +199,19 @@ def grouped_sdpa_fake(
    dropout_p=0.0,
    is_causal=False,
    scale=None,
+    sinks=None,
+    sliding_window=None,
+    logit_cap=None,
):
    """Fake implementation of grouped SDPA."""
    return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()


@torch.library.custom_op("auto_deploy::torch_attention_bsnd_grouped_sdpa", mutates_args=())
def bsnd_grouped_sdpa(
-    query: torch.Tensor,  # layout: [b, n, s_q, d]
-    key: torch.Tensor,  # layout: [b, n, s_k, d]
-    value: torch.Tensor,  # layout: [b, n, s_k, d]
+    query: torch.Tensor,  # layout: [b, s_q, n, d]
+    key: torch.Tensor,  # layout: [b, s_k, n, d]
+    value: torch.Tensor,  # layout: [b, s_k, n, d]
    attn_mask: Optional[torch.Tensor] = None,  # layout: [b, n, s_q, s_k]
    dropout_p: float = 0.0,
    is_causal: bool = False,
@@ -124,14 +225,16 @@ def bsnd_grouped_sdpa(
    Note that attn_mask layout is still assumed to be [b, n, s_q, s_k] and is consistent with the
    original sdpa op!
    """
-    # let's transpose to bnsd so we can use the grouped sdpa
-    query = query.transpose(1, 2).contiguous()
-    key = key.transpose(1, 2).contiguous()
-    value = value.transpose(1, 2).contiguous()
-
-    out = grouped_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale)
-
-    # let's transpose back to bnsd
+    # Transpose inputs to bnsd format for grouped_sdpa
+    query = query.transpose(1, 2).contiguous()  # [b, s_q, n, d] -> [b, n, s_q, d]
+    key = key.transpose(1, 2).contiguous()  # [b, s_k, n, d] -> [b, n, s_k, d]
+    value = value.transpose(1, 2).contiguous()  # [b, s_k, n, d] -> [b, n, s_k, d]
+
+    # Call grouped_sdpa with bnsd inputs
+    out = grouped_sdpa(
+        query, key, value, attn_mask, dropout_p, is_causal, scale, sinks, sliding_window, logit_cap
+    )
+    # Transpose back to bsnd format
    return out.transpose(1, 2).contiguous()

