
Commit f83e674

Fix a minor error during the kernel call
Signed-off-by: nvchenghaoz <[email protected]>
1 parent: cf9023a

2 files changed: +10 -2 lines changed


tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py

Lines changed: 1 addition & 1 deletion
@@ -119,7 +119,7 @@ def grouped_sdpa(
         dropout_p=dropout_p,
         is_causal=is_causal,
         scale=scale,
-        logit_cap=logit_cap,
+        enable_gqa=True,
     )
 
 
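For context: torch.nn.functional.scaled_dot_product_attention has no logit_cap parameter, so passing one fails at call time; enable_gqa=True (available since PyTorch 2.5) instead tells SDPA to broadcast the smaller set of KV heads across the query heads. A minimal sketch of that call pattern, with hypothetical shapes (the real grouped_sdpa presumably forwards to SDPA, given the matching keyword arguments):

```python
# Minimal sketch (hypothetical shapes; assumes the patched call targets
# torch.nn.functional.scaled_dot_product_attention). With enable_gqa=True,
# SDPA broadcasts the n_kv_heads K/V tensors across the n_heads query heads,
# so the KV heads do not need to be repeated by hand. Requires PyTorch >= 2.5
# and n_heads divisible by n_kv_heads.
import torch
import torch.nn.functional as F

b, n_heads, n_kv_heads, seq, head_dim = 2, 8, 2, 16, 64  # hypothetical sizes
q = torch.randn(b, n_heads, seq, head_dim)
k = torch.randn(b, n_kv_heads, seq, head_dim)
v = torch.randn(b, n_kv_heads, seq, head_dim)

out = F.scaled_dot_product_attention(
    q, k, v, dropout_p=0.0, is_causal=True, scale=None, enable_gqa=True
)
print(out.shape)  # torch.Size([2, 8, 16, 64])
```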

tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py

Lines changed: 9 additions & 1 deletion
@@ -56,6 +56,7 @@ def _generate_mha(
     stage1_output_logsumexp = torch.empty(
         b, n_heads, num_blocks, device=device, dtype=torch.float32
     ) - float("inf")
+
     update_kv_cache[(b, n_kv_heads, 1)](
         k,
         v,
@@ -74,7 +75,13 @@ def _generate_mha(
     )
 
     HEAD_BLOCK_SIZE = max(16, triton.next_power_of_2(n_heads // n_kv_heads))
-    gqa_attention_kv_stage1[(b, n_heads, num_blocks)](
+    gqa_attention_kv_stage1[
+        (
+            b,
+            n_kv_heads,
+            num_blocks,
+        )
+    ](
         q,
         k_cache,
         v_cache,
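The substantive fix here is the launch grid: stage 1 is now launched over n_kv_heads programs on axis 1 instead of n_heads. A hedged sketch with a toy kernel (not the real gqa_attention_kv_stage1; CUDA required) of how a 3-D Triton grid maps to tl.program_id, assuming axis 1 indexes the KV cache's head dimension:

```python
# Toy sketch (hypothetical kernel): each entry of the launch grid
# (b, n_kv_heads, num_blocks) becomes one program instance, identified by
# tl.program_id along each axis. If axis 1 indexes the KV cache's head
# dimension, launching n_heads programs there (n_heads > n_kv_heads in GQA)
# would index past the end of the cache.
import torch
import triton
import triton.language as tl


@triton.jit
def _grid_demo(out_ptr, N_KV_HEADS: tl.constexpr, NUM_BLOCKS: tl.constexpr):
    batch_id = tl.program_id(0)    # axis 0: batch
    kv_head_id = tl.program_id(1)  # axis 1: KV head, must stay < n_kv_heads
    block_id = tl.program_id(2)    # axis 2: KV-cache block
    idx = (batch_id * N_KV_HEADS + kv_head_id) * NUM_BLOCKS + block_id
    tl.store(out_ptr + idx, idx)


b, n_kv_heads, num_blocks = 2, 4, 3
out = torch.empty(b * n_kv_heads * num_blocks, device="cuda", dtype=torch.int32)
_grid_demo[(b, n_kv_heads, num_blocks)](out, N_KV_HEADS=n_kv_heads, NUM_BLOCKS=num_blocks)
assert out.tolist() == list(range(b * n_kv_heads * num_blocks))
```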
@@ -382,6 +389,7 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]:
             scale = source_attn_node.args[6]
         else:
             scale = source_attn_node.kwargs.get("scale", None)
+
         # do a sanity check on the scale if it is not None, we only support the default scale
         # of 1/sqrt(head_dim) and so we should do an approximate check for that one
         if not isinstance(scale, float):
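For reference, the approximate check the comment describes can be sketched as below; the helper name and tolerance are assumptions for illustration, not the repository's code:

```python
# Hypothetical helper illustrating the approximate check mentioned above:
# only the default SDPA scale of 1/sqrt(head_dim) is supported, and float
# round-off makes an exact equality test fragile.
import math


def _is_default_scale(scale, head_dim, rel_tol=1e-3):
    if scale is None:
        return True  # None means SDPA applies the default scale itself
    return math.isclose(scale, 1.0 / math.sqrt(head_dim), rel_tol=rel_tol)


assert _is_default_scale(None, 64)
assert _is_default_scale(0.125, 64)      # 1/sqrt(64) == 0.125
assert not _is_default_scale(1.0, 64)
```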
