Commit 7228d98

Add the softcap import and related transform
Signed-off-by: Chenghao Zhang <[email protected]>
1 parent f1dcc35 commit 7228d98

6 files changed: +122 -39 lines changed

tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py

Lines changed: 49 additions & 14 deletions
@@ -40,27 +40,58 @@ def scaled_dot_product_attention(
     dropout_p: float = 0.0,
     is_causal: bool = False,
     scale: Optional[float] = None,
+    logit_cap: Optional[float] = None,
 ) -> torch.Tensor:
     """A carbon copy of torch.nn.functional.scaled_dot_product_attention as custom op.

     Using this custom op instead of using the functional directly ensures consistent representation
     of the vanilla sdpa in a graph.
     """

-    return F.scaled_dot_product_attention(
-        query.contiguous(),
-        key.contiguous(),
-        value.contiguous(),
-        attn_mask=attn_mask,
-        dropout_p=dropout_p,
-        is_causal=is_causal,
-        scale=scale,
-    )
+    # Handle soft capping by applying it manually since F.scaled_dot_product_attention
+    # may not support soft_cap parameter
+    if logit_cap is not None:
+        # Apply manual soft capping to the attention scores
+        # First compute raw attention scores
+        d_k = query.size(-1)
+        if scale is None:
+            scale = 1.0 / (d_k**0.5)
+
+        # Compute attention scores
+        scores = torch.matmul(query, key.transpose(-2, -1)) * scale
+
+        # Apply soft capping: tanh(scores / logit_cap) * logit_cap
+        scores = torch.tanh(scores / logit_cap) * logit_cap
+
+        if attn_mask is not None:
+            scores += attn_mask
+
+        # Apply softmax
+        attn_weights = F.softmax(scores, dim=-1)
+
+        # Apply dropout if specified
+        if dropout_p > 0.0:
+            attn_weights = F.dropout(attn_weights, p=dropout_p, training=torch.is_grad_enabled())
+
+        # Apply attention to values
+        output = torch.matmul(attn_weights, value)
+        return output.contiguous()
+    else:
+        # Use standard SDPA when no soft capping
+        return F.scaled_dot_product_attention(
+            query.contiguous(),
+            key.contiguous(),
+            value.contiguous(),
+            attn_mask=attn_mask,
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+            scale=scale,
+        )


 @scaled_dot_product_attention.register_fake
 def scaled_dot_product_attention_fake(
-    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
+    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, logit_cap=None
 ):
     """Fake implementation of scaled_dot_product_attention."""
     return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()
@@ -75,18 +106,20 @@ def grouped_sdpa(
     dropout_p: float = 0.0,
     is_causal: bool = False,
     scale: Optional[float] = None,
+    logit_cap: Optional[float] = None,
 ) -> torch.Tensor:
     """SDPA attention that can handle GQA."""

-    return F.scaled_dot_product_attention(
+    # Use our custom scaled_dot_product_attention that supports soft capping
+    return scaled_dot_product_attention(
         query.contiguous(),
         key.contiguous(),
         value.contiguous(),
         attn_mask=attn_mask,
         dropout_p=dropout_p,
         is_causal=is_causal,
         scale=scale,
-        enable_gqa=True,
+        logit_cap=logit_cap,
     )


@@ -99,6 +132,7 @@ def grouped_sdpa_fake(
     dropout_p=0.0,
     is_causal=False,
     scale=None,
+    logit_cap=None,
 ):
     """Fake implementation of grouped SDPA."""
     return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()
@@ -113,6 +147,7 @@ def bsnd_grouped_sdpa(
     dropout_p: float = 0.0,
     is_causal: bool = False,
     scale: Optional[float] = None,
+    logit_cap: Optional[float] = None,
 ) -> torch.Tensor:
     """Attention that assumes the input layout is bsnd.
@@ -124,15 +159,15 @@ def bsnd_grouped_sdpa(
     key = key.transpose(1, 2).contiguous()
     value = value.transpose(1, 2).contiguous()

-    out = grouped_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale)
+    out = grouped_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale, logit_cap)

     # let's transpose back to bnsd
     return out.transpose(1, 2).contiguous()


 @bsnd_grouped_sdpa.register_fake
 def bsnd_grouped_sdpa_fake(
-    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
+    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, logit_cap=None
 ):
     """Fake implementation of bnsd grouped SDPA."""
     return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()

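For readers skimming the diff, the math in the new `logit_cap` branch is just a tanh squash of the pre-softmax scores. Below is a minimal, self-contained sketch (not part of the commit; shapes, names, and tolerances are illustrative) of that reference computation, plus a check that it reduces to plain SDPA when no cap is given:

```python
import torch
import torch.nn.functional as F

def ref_soft_capped_attention(q, k, v, logit_cap=None, scale=None):
    """Reference attention with optional soft capping (no mask/dropout for brevity)."""
    scale = scale if scale is not None else q.size(-1) ** -0.5
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    if logit_cap is not None:
        # Soft capping squashes logits into (-logit_cap, logit_cap) while keeping gradients smooth.
        scores = logit_cap * torch.tanh(scores / logit_cap)
    return torch.matmul(F.softmax(scores, dim=-1), v)

q, k, v = (torch.randn(1, 8, 16, 64) for _ in range(3))

# With logit_cap=None the reference matches the fused SDPA path used in the else-branch above.
torch.testing.assert_close(
    ref_soft_capped_attention(q, k, v),
    F.scaled_dot_product_attention(q, k, v),
    rtol=1e-4,
    atol=1e-4,
)
print(ref_soft_capped_attention(q, k, v, logit_cap=50.0).shape)  # torch.Size([1, 8, 16, 64])
```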
tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py

Lines changed: 16 additions & 11 deletions
@@ -40,6 +40,7 @@ def _generate_mha(
     cache_locs: torch.Tensor,
     input_pos: torch.Tensor,
     scale: float,
+    logit_cap: Optional[float],
     out: torch.Tensor,
 ):
     b, (n_heads, q_d_head) = q.shape[0], q.shape[-2:]
@@ -55,7 +56,6 @@ def _generate_mha(
     stage1_output_logsumexp = torch.empty(
         b, n_heads, num_blocks, device=device, dtype=torch.float32
     ) - float("inf")
-
     update_kv_cache[(b, n_kv_heads, 1)](
         k,
         v,
@@ -74,13 +74,7 @@ def _generate_mha(
     )

     HEAD_BLOCK_SIZE = max(16, triton.next_power_of_2(n_heads // n_kv_heads))
-    gqa_attention_kv_stage1[
-        (
-            b,
-            n_kv_heads,
-            num_blocks,
-        )
-    ](
+    gqa_attention_kv_stage1[(b, n_heads, num_blocks)](
         q,
         k_cache,
         v_cache,
@@ -97,6 +91,7 @@ def _generate_mha(
         v_d_head,
         SEQ_BLOCK_SIZE,
         HEAD_BLOCK_SIZE,
+        LOGIT_CAP=logit_cap,
     )
     attention_kv_stage2[(b, n_heads, 1)](
         stage1_output_values,
@@ -122,6 +117,7 @@ def _flattened_context_mha(
     seq_start: torch.Tensor,
     scale: float,
     out: torch.Tensor,
+    logit_cap: Optional[float],
 ) -> None:
     # NOTE: s_total == sum(seq_len)
     s_total, n_heads, q_d_head = q.shape
@@ -166,6 +162,7 @@ def _flattened_context_mha(
         SEQ_BLOCK,
         max_cache_seq_len,
         num_stages=2,
+        LOGIT_CAP=logit_cap,
     )


@@ -187,6 +184,7 @@ def flattened_mha_with_cache(
     # <none>
     # CONSTANTS
     scale: Optional[float],
+    logit_cap: Optional[float],
 ) -> torch.Tensor:
     """Flattened MHA with cache that takes q, k, v in BSND layout.
@@ -223,7 +221,7 @@ def flattened_mha_with_cache(
     y = q.new_empty(*bs_view, num_heads, v_head_dim).contiguous()
     if s == 1:
         # generate-only phase
-        _generate_mha(q, k, v, k_cache, v_cache, cache_loc, input_pos, scale, y)
+        _generate_mha(q, k, v, k_cache, v_cache, cache_loc, input_pos, scale, logit_cap, y)
     else:
         # mixed context + generate phase
         _flattened_context_mha(
@@ -237,7 +235,8 @@ def flattened_mha_with_cache(
             seq_len,
             seq_start,
             scale,
-            y,
+            out=y,
+            logit_cap=logit_cap,
         )

     return y.view(*output_shape)
@@ -255,6 +254,7 @@ def flattened_mha_fake(
     k_cache: torch.Tensor,
     v_cache: torch.Tensor,
     scale: Optional[float],
+    logit_cap: Optional[float],
 ):
     return q.new_empty(*q.shape[:-1], v.shape[-1]).contiguous()

@@ -382,13 +382,18 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]:
             scale = source_attn_node.args[6]
         else:
             scale = source_attn_node.kwargs.get("scale", None)
-
         # do a sanity check on the scale if it is not None, we only support the default scale
         # of 1/sqrt(head_dim) and so we should do an approximate check for that one
         if not isinstance(scale, float):
             ad_logger.warning("Provided scale is not a float, Using default scale instead.")
             scale = None

+        if len(source_attn_node.args) > 7:
+            logit_cap = source_attn_node.args[7]
+        else:
+            logit_cap = source_attn_node.kwargs.get("logit_cap", None)
+
         return [
             scale,  # softmax scale
+            logit_cap,  # soft capping scale
         ]

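The `get_constants` change reads `logit_cap` the same way `scale` is read: positionally when the source node carries enough args, otherwise from the kwargs. A tiny sketch of that lookup pattern (the `SimpleNamespace` node and the argument positions below are illustrative stand-ins for a real `torch.fx.Node`):

```python
from types import SimpleNamespace

def get_constant(node, idx, name, default=None):
    """Read a value that may have been passed positionally or as a keyword."""
    if len(node.args) > idx:
        return node.args[idx]
    return node.kwargs.get(name, default)

# Stand-in for a node recording sdpa(q, k, v, None, 0.0, True, 0.125, logit_cap=50.0)
node = SimpleNamespace(
    args=("q", "k", "v", None, 0.0, True, 0.125),
    kwargs={"logit_cap": 50.0},
)

print(get_constant(node, 6, "scale"))      # 0.125 -> found positionally at args[6]
print(get_constant(node, 7, "logit_cap"))  # 50.0  -> falls back to kwargs
```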
tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/attention_with_kv_cache.py

Lines changed: 7 additions & 2 deletions
@@ -128,8 +128,10 @@ def gqa_attention_kv_stage1(
     1. Fetch the K-cache from 0 to input_pos
     2. Fetch the V-cache from 0 to input_pos
     3. A = Q*K^T [1,D_HEAD] * [1,seq_len,D_HEAD] -> [1, seq_len]
-    4. S = softmax(A)
-    5. O = S*V [1, seq_len] * [1, seq_len, D_HEAD] -> [1, D_HEAD]
+    4. A = A * scale
+    5. A = logit_cap * tanh(A / logit_cap) if logit_cap is not None
+    6. S = softmax(A)
+    7. O = S*V [1, seq_len] * [1, seq_len, D_HEAD] -> [1, D_HEAD]
     """
     # Assume KV-cache layout: [Batch, Seq, Head, Dim]
     # A program is responsible for 1 batch, 1 head and a block of sequences.
@@ -577,6 +579,7 @@ def context_attention_kv_flattened(
     V_D_HEAD: tl.constexpr,  # Dimension of each value head.
     SEQ_BLOCK: tl.constexpr,
     MAX_SEQ_LENGTH: tl.constexpr,
+    LOGIT_CAP: tl.constexpr = None,
 ):
     """Kernel for context phase.
@@ -645,6 +648,8 @@ def context_attention_kv_flattened(
         (seq_offsets[:, None] + kv_position) >= kv_seq_offsets[None, :], qk, float("-inf")
     )
     qk *= SCALE
+    if LOGIT_CAP is not None:
+        qk = LOGIT_CAP * tanh(qk / LOGIT_CAP)
     # rowmax
     m_ij = tl.maximum(tl.max(qk, 1), lse_i)
     p = tl.exp(qk - m_ij[:, None])

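Because `LOGIT_CAP` is a `tl.constexpr`, the `if LOGIT_CAP is not None` check is resolved at compile time, so the capped and uncapped cases become separate kernel specializations. A standalone sketch of the same pattern (requires a CUDA GPU; the kernel and names below are illustrative, and tanh is computed via `tl.exp` here purely to keep the example dependency-free, whereas the commit uses a `tanh` helper):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def scale_and_cap_kernel(x_ptr, y_ptr, n, SCALE: tl.constexpr, LOGIT_CAP: tl.constexpr = None, BLOCK: tl.constexpr = 128):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask) * SCALE
    if LOGIT_CAP is not None:
        # Compile-time branch: when LOGIT_CAP is None this code is not emitted at all.
        t = 1.0 - 2.0 / (tl.exp(2.0 * x / LOGIT_CAP) + 1.0)  # tanh(x / LOGIT_CAP)
        x = LOGIT_CAP * t
    tl.store(y_ptr + offs, x, mask=mask)

x = torch.randn(1024, device="cuda")
y = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 128),)
scale_and_cap_kernel[grid](x, y, x.numel(), SCALE=0.125)                  # uncapped specialization
scale_and_cap_kernel[grid](x, y, x.numel(), SCALE=0.125, LOGIT_CAP=50.0)  # capped specialization
```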
tensorrt_llm/_torch/auto_deploy/models/decilm.py

Lines changed: 5 additions & 1 deletion
@@ -7,7 +7,11 @@


 def _from_pretrained_patched(pretrained_model_name_or_path, **kwargs):
     print(str(pretrained_model_name_or_path))
-    if re.search(r"Llama-3_(?:1|3)-Nemotron-(?:Ultra|Super)", str(pretrained_model_name_or_path)):
+
+    # Use the eager attention implementation for Gemma-2 models to import the soft logit capping ops.
+    if re.search(
+        r"Llama-3_(?:1|3)-Nemotron-(?:Ultra|Super)", str(pretrained_model_name_or_path)
+    ) or re.search(r"gemma-2", str(pretrained_model_name_or_path), re.IGNORECASE):
         kwargs["attn_implementation"] = "eager"
     return _orig_from_pretrained(pretrained_model_name_or_path, **kwargs)

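The widened check simply forces `attn_implementation="eager"` for any checkpoint name matching either pattern. A quick illustration of what the combined condition accepts (the model names below are examples, not an exhaustive or verified list):

```python
import re

def needs_eager_attention(name: str) -> bool:
    return bool(
        re.search(r"Llama-3_(?:1|3)-Nemotron-(?:Ultra|Super)", name)
        or re.search(r"gemma-2", name, re.IGNORECASE)
    )

print(needs_eager_attention("google/gemma-2-9b-it"))                     # True (case-insensitive "gemma-2")
print(needs_eager_attention("nvidia/Llama-3_1-Nemotron-Ultra-253B-v1"))  # True
print(needs_eager_attention("meta-llama/Llama-3.1-8B-Instruct"))         # False -> keep default implementation
```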
tensorrt_llm/_torch/auto_deploy/transformations/library/attention.py

Lines changed: 43 additions & 8 deletions
@@ -293,7 +293,8 @@ def _match_eager_attention_pattern(final_matmul_node: Node) -> Optional[Dict[str
     Match the eager attention pattern starting from the final matmul node.

     The pattern is:
-    transpose -> matmul -> mul/div -> (optional) add -> (optional) to -> softmax -> (optional) to -> dropout -> matmul
+    transpose -> matmul -> mul/div -> (optional) div -> tanh -> mul (soft capping)
+    -> (optional) add -> (optional) to -> softmax -> (optional) to -> dropout -> matmul

     Returns a dictionary with information about the match or None if no match.
     """
@@ -352,21 +353,51 @@ def _match_eager_attention_pattern(final_matmul_node: Node) -> Optional[Dict[str
         prev_node = prev_node.args[0]

     # Check for attention mask pattern (add node)
+    attn_mask = None
     if is_op(prev_node, torch.ops.aten.add):
         add_node = prev_node
         attn_mask = add_node.args[1]  # Second arg is the mask

-        # The add should have a mul or div node as its first argument
+        # The add should have input as its first argument
         if len(add_node.args) < 1:
             return None

-        scaling_node = add_node.args[0]
-        if not (is_op(scaling_node, torch.ops.aten.mul) or is_op(scaling_node, torch.ops.aten.div)):
-            return None
-    elif is_op(prev_node, torch.ops.aten.mul) or is_op(prev_node, torch.ops.aten.div):
-        # No mask case - the softmax input is directly the mul or div node
+        prev_node = add_node.args[0]
+
+    # Check for optional soft capping pattern: div -> tanh -> mul
+    logit_cap = None
+    if is_op(prev_node, torch.ops.aten.mul):
+        # Check if this mul is part of soft capping (mul after tanh)
+        if len(prev_node.args) >= 2:
+            mul_input = prev_node.args[0]
+            soft_cap_mul_factor = prev_node.args[1]
+
+            # Check if the input to mul is tanh
+            if is_op(mul_input, torch.ops.aten.tanh):
+                if len(mul_input.args) >= 1:
+                    tanh_input = mul_input.args[0]
+
+                    # Check if the input to tanh is div (completing the soft cap pattern)
+                    if is_op(tanh_input, torch.ops.aten.div):
+                        if len(tanh_input.args) >= 2:
+                            div_input = tanh_input.args[0]
+                            soft_cap_div_factor = tanh_input.args[1]
+
+                            # Verify that the div and mul factors are the same (soft cap scale)
+                            if isinstance(soft_cap_div_factor, (float, int)) and isinstance(
+                                soft_cap_mul_factor, (float, int)
+                            ):
+                                if abs(soft_cap_div_factor - soft_cap_mul_factor) < 1e-6:
+                                    logit_cap = soft_cap_div_factor
+                                    prev_node = div_input
+                            elif soft_cap_div_factor == soft_cap_mul_factor:
+                                # Same node/tensor used for both operations
+                                logit_cap = soft_cap_div_factor
+                                prev_node = div_input
+
+    # Now prev_node should be the scaling operation (mul or div)
+    if is_op(prev_node, torch.ops.aten.mul) or is_op(prev_node, torch.ops.aten.div):
         scaling_node = prev_node
-        attn_mask = None
     else:
         return None

@@ -422,6 +453,10 @@ def _match_eager_attention_pattern(final_matmul_node: Node) -> Optional[Dict[str
     if attn_mask is not None:
         match_info["attn_mask"] = attn_mask

+    # Add soft cap scale if it exists
+    if logit_cap is not None:
+        match_info["logit_cap"] = logit_cap
+
     return match_info


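For reference, the div -> tanh -> mul chain the matcher now recognizes is what an eager (non-SDPA) attention implementation with logit soft capping produces when traced. A sketch of such an implementation (illustrative, modeled on Gemma-2-style soft capping, not copied from any library):

```python
import torch
import torch.nn.functional as F

def eager_attention_with_soft_cap(q, k, v, attn_mask=None, scale=0.125, logit_cap=50.0, dropout_p=0.0):
    """Eager attention whose traced graph contains the soft-cap chain matched above."""
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale  # scaling mul/div
    scores = scores / logit_cap                            # aten.div   (soft cap, step 1)
    scores = torch.tanh(scores)                            # aten.tanh  (soft cap, step 2)
    scores = scores * logit_cap                            # aten.mul   (soft cap, step 3, same constant)
    if attn_mask is not None:
        scores = scores + attn_mask                        # optional aten.add (mask)
    attn = F.softmax(scores, dim=-1)
    attn = F.dropout(attn, p=dropout_p)
    return torch.matmul(attn, v)                           # final matmul the matcher starts from

q = k = v = torch.randn(1, 8, 16, 64)
print(eager_attention_with_soft_cap(q, k, v).shape)  # torch.Size([1, 8, 16, 64])
```

Walking back from the final matmul, the matcher finds mul(tanh(div(scores, c)), c), checks that the two constants agree, and records c as `logit_cap` in `match_info`.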
tensorrt_llm/_torch/auto_deploy/transformations/transform.py

Lines changed: 2 additions & 3 deletions
@@ -182,11 +182,10 @@ def __call__(self, cm: CachedSequenceInterface) -> GraphModule:
             from .library import visualize_namespace

             visualize_namespace(egm, args=cm.args, dynamic_shapes=cm.dynamic_shapes)
+        except ImportError:
             ad_logger.warning(
-                "Please run `pip install -r examples/auto_deploy/requirements.txt` to visualize"
-                " the graph."
+                "Please run `pip install -r examples/auto_deploy/requirements.txt` to visualize the graph."
             )
-        except ImportError:
             pass

        ############################################################################################
