Commit 810df7d

Revert "[None] Add the torch source implementation for new params. (#89)"
This reverts commit c245cf3.
1 parent 29bb062

File tree

1 file changed: +23 -100 lines changed


tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py

Lines changed: 23 additions & 100 deletions
@@ -7,12 +7,7 @@
 import torch.nn as nn
 import torch.nn.functional as F

-
-def _apply_logit_softcapping(attn_scores: torch.Tensor, logit_cap: Optional[float]) -> torch.Tensor:
-    """Apply logit softcapping using the formula: logit_cap * tanh(logits / logit_cap)"""
-    if logit_cap is not None and logit_cap > 0.0:
-        return logit_cap * torch.tanh(attn_scores / logit_cap)
-    return attn_scores
+# TODO (nvchenghaoz): Remove related kernels once we have a backend-specific implementation for attention.


 @torch.library.custom_op("auto_deploy::torch_attention_repeat_kv", mutates_args=())
@@ -82,86 +77,19 @@ def grouped_sdpa(
     dropout_p: float = 0.0,
     is_causal: bool = False,
     scale: Optional[float] = None,
-    sinks: Optional[torch.Tensor] = None,
-    sliding_window: Optional[int] = None,
-    logit_cap: Optional[float] = None,
 ) -> torch.Tensor:
-    """SDPA attention that can handle GQA. Expects bnsd format inputs."""
-    b, n_heads, s_q, head_dim = query.shape  # bnsd format: [batch, num_heads, seq_len, head_dim]
-    _, n_kv_heads, s_k, _ = key.shape  # bnsd format: [batch, num_kv_heads, seq_len, head_dim]
-
-    # Inputs are already in bnsd format, no need to transpose
-    query_t = query  # [b, n_heads, s_q, head_dim]
-    key_t = key  # [b, n_kv_heads, s_k, head_dim]
-    value_t = value  # [b, n_kv_heads, s_k, v_head_dim]
-
-    # Handle GQA by repeating KV if needed
-    if n_heads != n_kv_heads:
-        n_rep = n_heads // n_kv_heads
-        key_t = repeat_kv(key_t, n_rep)  # [b, n_heads, s_k, head_dim]
-        value_t = repeat_kv(value_t, n_rep)  # [b, n_heads, s_k, v_head_dim]
-
-    # Set scale
-    if scale is None:
-        scale = 1.0 / math.sqrt(head_dim)
-
-    # Compute attention scores: Q @ K^T
-    attn_scores = torch.matmul(query_t, key_t.transpose(-2, -1)) * scale  # [b, n_heads, s_q, s_k]
-
-    # Apply attention mask if provided
-    if attn_mask is not None:
-        attn_scores = attn_scores + attn_mask
-
-    # Apply causal mask if specified
-    if is_causal:
-        causal_mask = torch.triu(
-            torch.ones(s_q, s_k, device=query.device, dtype=torch.bool),
-            diagonal=s_k - s_q + 1,
-        )
-        attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf"))
-
-    # Apply sliding window mask if specified
-    if sliding_window is not None and sliding_window > 0:
-        # Create sliding window mask: each query position i can only attend to keys in [i-window_size+1, i]
-        query_positions = torch.arange(s_q, device=query.device)  # [s_q]
-        key_positions = torch.arange(s_k, device=query.device)  # [s_k]
-
-        # Create position difference matrix: query_pos - key_pos
-        pos_diff = query_positions.unsqueeze(1) - key_positions.unsqueeze(0)  # [s_q, s_k]
-
-        # Sliding window mask: allow attention only if 0 <= pos_diff < sliding_window
-        sliding_window_mask = (pos_diff < 0) | (pos_diff >= sliding_window)  # [s_q, s_k]
-        attn_scores.masked_fill_(sliding_window_mask.unsqueeze(0).unsqueeze(0), float("-inf"))
-
-    # Apply logit softcapping if enabled
-    attn_scores = _apply_logit_softcapping(attn_scores, logit_cap)
-
-    # Apply sinks if provided
-    if sinks is not None:
-        # Concatenate sinks to attention scores following the reference implementation
-        # sinks should have n_heads elements, each head gets its own sink value
-        # Expand sinks to [b, n_heads, s_q, 1] - one sink column per head
-        sinks_expanded = sinks.reshape(1, -1, 1, 1).expand(
-            b, n_heads, s_q, 1
-        )  # [b, n_heads, s_q, 1]
-
-        # Concatenate along the key dimension (last dimension)
-        attn_weights = torch.cat([attn_scores, sinks_expanded], dim=-1)
-        attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-
-        # Use only the non-sink portion for computing output
-        # We added exactly 1 column, so remove exactly 1 column
-        attn_out = torch.matmul(attn_weights[..., :-1], value_t)  # [b, n_heads, s_q, v_head_dim]
-    else:
-        attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(query.dtype)
-        attn_out = torch.matmul(attn_weights, value_t)  # [b, n_heads, s_q, v_head_dim]
+    """SDPA attention that can handle GQA."""

-    # Apply dropout if specified
-    if dropout_p > 0.0:
-        attn_out = F.dropout(attn_out, p=dropout_p, training=False)
-
-    # Return in bnsd format (same as input format)
-    return attn_out
+    return F.scaled_dot_product_attention(
+        query.contiguous(),
+        key.contiguous(),
+        value.contiguous(),
+        attn_mask=attn_mask,
+        dropout_p=dropout_p,
+        is_causal=is_causal,
+        scale=scale,
+        enable_gqa=True,
+    )


 @grouped_sdpa.register_fake
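The revert replaces the hand-rolled reference (repeat_kv, explicit masking, sinks, softcapping) with a single call to F.scaled_dot_product_attention and enable_gqa=True. A parity sketch for the GQA part only, assuming PyTorch >= 2.5 (where the enable_gqa kwarg exists); shapes are illustrative and this is not the repo's test code:

```python
import torch
import torch.nn.functional as F

# Parity sketch: SDPA with enable_gqa=True vs. manually repeating KV heads,
# which is what the deleted torch-reference path did via repeat_kv.
b, n_heads, n_kv_heads, s, d = 2, 8, 2, 16, 32
q = torch.randn(b, n_heads, s, d)
k = torch.randn(b, n_kv_heads, s, d)
v = torch.randn(b, n_kv_heads, s, d)

out_gqa = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)

# Reference: repeat each KV head n_rep times so KV has n_heads heads, then plain SDPA.
n_rep = n_heads // n_kv_heads
k_rep = k.repeat_interleave(n_rep, dim=1)
v_rep = v.repeat_interleave(n_rep, dim=1)
out_ref = F.scaled_dot_product_attention(q, k_rep, v_rep, is_causal=True)

assert torch.allclose(out_gqa, out_ref, atol=1e-5)
```

Note that sinks, sliding_window, and logit_cap have no counterpart in this fast path, which is why the revert also drops them from the op signatures below.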
@@ -173,19 +101,16 @@ def grouped_sdpa_fake(
     dropout_p=0.0,
     is_causal=False,
     scale=None,
-    sinks=None,
-    sliding_window=None,
-    logit_cap=None,
 ):
     """Fake implementation of grouped SDPA."""
     return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()


 @torch.library.custom_op("auto_deploy::torch_attention_bsnd_grouped_sdpa", mutates_args=())
 def bsnd_grouped_sdpa(
-    query: torch.Tensor,  # layout: [b, s_q, n, d]
-    key: torch.Tensor,  # layout: [b, s_k, n, d]
-    value: torch.Tensor,  # layout: [b, s_k, n, d]
+    query: torch.Tensor,  # layout: [b, n, s_q, d]
+    key: torch.Tensor,  # layout: [b, n, s_k, d]
+    value: torch.Tensor,  # layout: [b, n, s_k, d]
     attn_mask: Optional[torch.Tensor] = None,  # layout: [b, n, s_q, s_k]
     dropout_p: float = 0.0,
     is_causal: bool = False,
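The register_fake kernel only has to propagate shapes for tracing/compilation: the output keeps the query's leading dims and takes its last dim from value. A standalone sketch of that shape rule on meta tensors (the helper name and shapes are illustrative, not the registered op):

```python
import torch

def fake_grouped_sdpa(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Mirrors grouped_sdpa_fake: [b, n, s_q, *] from query, v_head_dim from value.
    return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()


q = torch.empty(2, 8, 16, 64, device="meta")   # [b, n_heads, s_q, head_dim]
k = torch.empty(2, 2, 32, 64, device="meta")   # [b, n_kv_heads, s_k, head_dim]
v = torch.empty(2, 2, 32, 128, device="meta")  # [b, n_kv_heads, s_k, v_head_dim]

out = fake_grouped_sdpa(q, k, v)
assert out.shape == (2, 8, 16, 128)  # [b, n_heads, s_q, v_head_dim], no real math executed
```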
@@ -199,16 +124,14 @@ def bsnd_grouped_sdpa(
     Note that attn_mask layout is still assumed to be [b, n, s_q, s_k] and is consistent with the
     original sdpa op!
     """
-    # Transpose inputs to bnsd format for grouped_sdpa
-    query = query.transpose(1, 2)  # [b, s_q, n, d] -> [b, n, s_q, d]
-    key = key.transpose(1, 2)  # [b, s_k, n, d] -> [b, n, s_k, d]
-    value = value.transpose(1, 2)  # [b, s_k, n, d] -> [b, n, s_k, d]
-
-    # Call grouped_sdpa with bnsd inputs
-    out = grouped_sdpa(
-        query, key, value, attn_mask, dropout_p, is_causal, scale, sinks, sliding_window, logit_cap
-    )
-    # Transpose back to bsnd format
+    # let's transpose to bnsd so we can use the grouped sdpa
+    query = query.transpose(1, 2).contiguous()
+    key = key.transpose(1, 2).contiguous()
+    value = value.transpose(1, 2).contiguous()
+
+    out = grouped_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale)
+
+    # let's transpose back to bnsd
     return out.transpose(1, 2).contiguous()

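In both versions, the bsnd wrapper is just a layout shim: transpose [b, s, n, d] to [b, n, s, d], run the bnsd attention op, transpose back. A layout round-trip sketch with F.scaled_dot_product_attention standing in for the auto_deploy grouped SDPA op (shapes illustrative):

```python
import torch
import torch.nn.functional as F

b, s, n, d = 2, 16, 8, 64
q_bsnd = torch.randn(b, s, n, d)
k_bsnd = torch.randn(b, s, n, d)
v_bsnd = torch.randn(b, s, n, d)

# bsnd -> bnsd for the attention kernel
q = q_bsnd.transpose(1, 2).contiguous()
k = k_bsnd.transpose(1, 2).contiguous()
v = v_bsnd.transpose(1, 2).contiguous()

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # [b, n, s, d]

# bnsd -> bsnd on the way out
out_bsnd = out.transpose(1, 2).contiguous()
assert out_bsnd.shape == (b, s, n, d)
```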