Skip to content

Commit 7810dd0

Browse files
[Update #89] Add torch ref attention (#107)
* attention matcher with torch._inductor pattern matcher,matching repeat kv, sdpa and group attention, update unit tests Signed-off-by: Frida Hou <[email protected]> * Update the torch ref op Signed-off-by: nvchenghaoz <[email protected]> * Revert "attention matcher with torch._inductor pattern matcher,matching repeat kv, sdpa and group attention, update unit tests" This reverts commit 5743fb3. --------- Signed-off-by: Frida Hou <[email protected]> Signed-off-by: nvchenghaoz <[email protected]> Co-authored-by: Frida Hou <[email protected]>
1 parent 4e10f76 commit 7810dd0

File tree

1 file changed

+126
-23
lines changed

1 file changed

+126
-23
lines changed

tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py

Lines changed: 126 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,28 @@
77
import torch.nn as nn
88
import torch.nn.functional as F
99

10-
# TODO (nvchenghaoz): Remove related kernels once we have a backend-specific implementation for attention.
10+
11+
def _apply_logit_softcapping(attn_scores: torch.Tensor, logit_cap: Optional[float]) -> torch.Tensor:
12+
"""Apply logit softcapping using the formula: logit_cap * tanh(logits / logit_cap)"""
13+
if logit_cap is not None and logit_cap > 0.0:
14+
return logit_cap * torch.tanh(attn_scores / logit_cap)
15+
return attn_scores
16+
17+
18+
def _convert_boolean_mask_to_float(attn_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
19+
"""Convert boolean attention mask to floating point mask.
20+
Args:
21+
attn_mask: Boolean tensor where True allows attention, False blocks it
22+
dtype: Target dtype for the output mask
23+
Returns:
24+
Floating point mask where True -> 1.0, False -> -inf
25+
"""
26+
if attn_mask.dtype == torch.bool:
27+
float_mask = torch.zeros_like(attn_mask, dtype=dtype)
28+
float_mask = float_mask.masked_fill(attn_mask, 1.0) # True -> 1.0
29+
float_mask = float_mask.masked_fill(~attn_mask, float("-inf")) # False -> -inf
30+
return float_mask
31+
return attn_mask
1132

1233

1334
@torch.library.custom_op("auto_deploy::torch_attention_repeat_kv", mutates_args=())
@@ -77,19 +98,96 @@ def grouped_sdpa(
7798
dropout_p: float = 0.0,
7899
is_causal: bool = False,
79100
scale: Optional[float] = None,
101+
sinks: Optional[torch.Tensor] = None,
102+
sliding_window: Optional[int] = None,
103+
logit_cap: Optional[float] = None,
80104
) -> torch.Tensor:
81-
"""SDPA attention that can handle GQA."""
105+
"""SDPA attention that can handle GQA. Expects bnsd format inputs."""
106+
b, n_heads, s_q, head_dim = query.shape # bnsd format: [batch, num_heads, seq_len, head_dim]
107+
_, n_kv_heads, s_k, _ = key.shape # bnsd format: [batch, num_kv_heads, seq_len, head_dim]
108+
109+
# Inputs are already in bnsd format, no need to transpose
110+
query_t = query # [b, n_heads, s_q, head_dim]
111+
key_t = key # [b, n_kv_heads, s_k, head_dim]
112+
value_t = value # [b, n_kv_heads, s_k, v_head_dim]
113+
114+
# Handle GQA by repeating KV if needed
115+
if n_heads != n_kv_heads:
116+
n_rep = n_heads // n_kv_heads
117+
key_t = repeat_kv(key_t, n_rep)
118+
value_t = repeat_kv(value_t, n_rep)
119+
120+
# Set scale
121+
if scale is None:
122+
scale = 1.0 / math.sqrt(head_dim)
123+
124+
# Compute attention scores: Q @ K^T
125+
attn_scores = torch.matmul(query_t, key_t.transpose(-2, -1)) * scale # [b, n_heads, s_q, s_k]
126+
127+
# Apply attention mask if provided
128+
if attn_mask is not None:
129+
# Convert boolean mask to float if needed
130+
attn_mask = _convert_boolean_mask_to_float(attn_mask, attn_scores.dtype)
131+
attn_scores = attn_scores + attn_mask
132+
133+
# Apply causal mask if specified and only during the context phase
134+
if is_causal and s_q == s_k: # Only apply causal mask during context processing
135+
causal_mask = torch.triu(
136+
torch.ones(s_q, s_k, device=query.device, dtype=torch.bool),
137+
diagonal=1, # Use diagonal=1 for standard causal masking
138+
)
139+
attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf"))
140+
141+
# Apply sliding window mask if specified
142+
if sliding_window is not None and sliding_window > 0:
143+
# Handle position calculation for both context and generation phases
144+
if s_q == s_k:
145+
# Context phase: standard position calculation
146+
query_positions = torch.arange(s_q, device=query.device)
147+
key_positions = torch.arange(s_k, device=query.device)
148+
else:
149+
# Generation phase: query is at position s_k (after the cache)
150+
query_positions = torch.arange(s_k, s_k + s_q, device=query.device) # [s_k] for s_q=1
151+
key_positions = torch.arange(s_k, device=query.device) # [0,1,2,...,s_k-1]
152+
153+
# Create position difference matrix: query_pos - key_pos
154+
pos_diff = query_positions.unsqueeze(1) - key_positions.unsqueeze(0) # [s_q, s_k]
155+
156+
# Sliding window mask: allow attention only if 0 <= pos_diff < sliding_window_size
157+
sliding_window_mask = (pos_diff < 0) | (pos_diff >= sliding_window) # [s_q, s_k]
158+
attn_scores.masked_fill_(sliding_window_mask.unsqueeze(0).unsqueeze(0), float("-inf"))
159+
160+
# Apply logit softcapping if enabled
161+
attn_scores = _apply_logit_softcapping(attn_scores, logit_cap)
162+
163+
# Apply sinks if provided
164+
if sinks is not None:
165+
# Concatenate sinks to attention scores following the reference implementation
166+
# sinks should have n_heads elements, each head gets its own sink value
167+
# Expand sinks to [b, n_heads, s_q, 1] - one sink column per head
168+
sinks_expanded = sinks.reshape(1, -1, 1, 1).expand(
169+
b, n_heads, s_q, 1
170+
) # [b, n_heads, s_q, 1]
171+
172+
# Concatenate along the key dimension (last dimension)
173+
logits_max = torch.max(attn_scores, dim=-1, keepdim=True).values
174+
sinks = torch.exp(sinks_expanded - logits_max)
175+
unnormalized_scores = torch.exp(attn_scores - logits_max)
176+
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
177+
scores = unnormalized_scores / normalizer
178+
# Use only the non-sink portion for computing output
179+
# We added exactly 1 column, so remove exactly 1 column
180+
attn_out = torch.matmul(scores, value_t) # [b, n_heads, s_q, v_head_dim]
181+
else:
182+
attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(query.dtype)
183+
attn_out = torch.matmul(attn_weights, value_t) # [b, n_heads, s_q, v_head_dim]
82184

83-
return F.scaled_dot_product_attention(
84-
query.contiguous(),
85-
key.contiguous(),
86-
value.contiguous(),
87-
attn_mask=attn_mask,
88-
dropout_p=dropout_p,
89-
is_causal=is_causal,
90-
scale=scale,
91-
enable_gqa=True,
92-
)
185+
# Apply dropout if specified
186+
if dropout_p > 0.0:
187+
attn_out = F.dropout(attn_out, p=dropout_p, training=False)
188+
189+
# Return in bnsd format (same as input format)
190+
return attn_out
93191

94192

95193
@grouped_sdpa.register_fake
@@ -101,16 +199,19 @@ def grouped_sdpa_fake(
101199
dropout_p=0.0,
102200
is_causal=False,
103201
scale=None,
202+
sinks=None,
203+
sliding_window=None,
204+
logit_cap=None,
104205
):
105206
"""Fake implementation of grouped SDPA."""
106207
return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()
107208

108209

109210
@torch.library.custom_op("auto_deploy::torch_attention_bsnd_grouped_sdpa", mutates_args=())
110211
def bsnd_grouped_sdpa(
111-
query: torch.Tensor, # layout: [b, n, s_q, d]
112-
key: torch.Tensor, # layout: [b, n, s_k, d]
113-
value: torch.Tensor, # layout: [b, n, s_k, d]
212+
query: torch.Tensor, # layout: [b, s_q, n, d]
213+
key: torch.Tensor, # layout: [b, s_k, n, d]
214+
value: torch.Tensor, # layout: [b, s_k, n, d]
114215
attn_mask: Optional[torch.Tensor] = None, # layout: [b, n, s_q, s_k]
115216
dropout_p: float = 0.0,
116217
is_causal: bool = False,
@@ -124,14 +225,16 @@ def bsnd_grouped_sdpa(
124225
Note that attn_mask layout is still assumed to be [b, n, s_q, s_k] and is consistent with the
125226
original sdpa op!
126227
"""
127-
# let's transpose to bnsd so we can use the grouped sdpa
128-
query = query.transpose(1, 2).contiguous()
129-
key = key.transpose(1, 2).contiguous()
130-
value = value.transpose(1, 2).contiguous()
131-
132-
out = grouped_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale)
133-
134-
# let's transpose back to bnsd
228+
# Transpose inputs to bnsd format for grouped_sdpa
229+
query = query.transpose(1, 2).contiguous() # [b, s_q, n, d] -> [b, n, s_q, d]
230+
key = key.transpose(1, 2).contiguous() # [b, s_k, n, d] -> [b, n, s_k, d]
231+
value = value.transpose(1, 2).contiguous() # [b, s_k, n, d] -> [b, n, s_k, d]
232+
233+
# Call grouped_sdpa with bnsd inputs
234+
out = grouped_sdpa(
235+
query, key, value, attn_mask, dropout_p, is_causal, scale, sinks, sliding_window, logit_cap
236+
)
237+
# Transpose back to bsnd format
135238
return out.transpose(1, 2).contiguous()
136239

137240

0 commit comments

Comments
 (0)