Skip to content

Commit dba34d4

Browse files
894743926 and w30012745 authored
[v0.18.0][Triton][Qwen3.5] delete expr for kernels args (vllm-project#7646)
### What this PR does / why we need it? Some parameters of the Triton kernels are unnecessarily marked with the `tl.constexpr` modifier. Whenever such a parameter takes a new value, a kernel recompilation is triggered, which significantly degrades model performance. These parameters are therefore changed to regular (non-constexpr, non-specialized) arguments. backport: vllm-project#7482 Signed-off-by: w30012745 <wangxiaoshuai2@h-partners.com> Co-authored-by: w30012745 <wangxiaoshuai2@h-partners.com>
1 parent dd55736 commit dba34d4

File tree

4 files changed

+13
-13
lines changed

4 files changed

+13
-13
lines changed

vllm_ascend/ops/triton/fla/chunk_o.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
2323
}
2424
)
25-
@triton.jit(do_not_specialize=["T"])
25+
@triton.jit(do_not_specialize=["chunk_offsets", "scale", "T", "H", "Hg", "K", "V"])
2626
def chunk_fwd_kernel_o(
2727
q,
2828
k,
@@ -34,10 +34,10 @@ def chunk_fwd_kernel_o(
3434
chunk_offsets,
3535
scale,
3636
T,
37-
H: tl.constexpr,
38-
Hg: tl.constexpr,
39-
K: tl.constexpr,
40-
V: tl.constexpr,
37+
H,
38+
Hg,
39+
K,
40+
V,
4141
BT: tl.constexpr,
4242
BK: tl.constexpr,
4343
BV: tl.constexpr,

vllm_ascend/ops/triton/fla/l2norm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
1515

1616

17-
@triton.jit
18-
def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr, NUM_CHUNKS: tl.constexpr):
17+
@triton.jit(do_not_specialize=["eps", "M", "NUM_CHUNKS"])
18+
def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr, NUM_CHUNKS):
1919
base_row = tl.program_id(0) * (NUM_CHUNKS * MBLOCK)
2020
rindex = tl.arange(0, N)[None, :]
2121

vllm_ascend/ops/triton/fla/wy_fast.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818

1919
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
20-
@triton.jit(do_not_specialize=["T"])
20+
@triton.jit(do_not_specialize=["T", "H", "Hg", "K", "V"])
2121
def recompute_w_u_fwd_kernel(
2222
k,
2323
v,
@@ -29,10 +29,10 @@ def recompute_w_u_fwd_kernel(
2929
cu_seqlens,
3030
chunk_indices,
3131
T,
32-
H: tl.constexpr,
33-
Hg: tl.constexpr,
34-
K: tl.constexpr,
35-
V: tl.constexpr,
32+
H,
33+
Hg,
34+
K,
35+
V,
3636
BT: tl.constexpr,
3737
BK: tl.constexpr,
3838
BV: tl.constexpr,

vllm_ascend/ops/triton/layernorm_gated.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
1414
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
15-
@triton.jit
15+
@triton.jit(do_not_specialize=["stride_x_row", "stride_y_row", "stride_z_row", "M", "N", "eps"])
1616
def _layer_norm_fwd_1pass_kernel_npu(
1717
X, # pointer to the input
1818
Y, # pointer to the output

0 commit comments

Comments (0)