Skip to content

Commit af91bf1

Browse files
author
none
committed
fix
1 parent 1f2658f commit af91bf1

File tree

1 file changed

+20
-16
lines changed

1 file changed

+20
-16
lines changed

lightllm/models/deepseek2/triton_kernel/rotary_emb.py

Lines changed: 20 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -74,28 +74,32 @@ def rotary_emb_fwd(q, k, cos, sin):
74 74
assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
75 75
assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}"
76 76
assert triton.next_power_of_2(head_dim) == head_dim
77-
BLOCK_SEQ = 16
77+
78+
if total_len <= 512:
79+
BLOCK_SEQ = 1
80+
else:
81+
BLOCK_SEQ = 16
78 82

79 83
num_warps = 1
80 84
num_stages = 5
81 85

82 86
grid = (triton.cdiv(total_len, BLOCK_SEQ),)
83 87
_rotary_kernel[grid](
84-
q,
85-
k,
86-
cos,
87-
sin,
88-
q.stride(0),
89-
q.stride(1),
90-
q.stride(2),
91-
k.stride(0),
92-
k.stride(1),
93-
k.stride(2),
94-
cos.stride(0),
95-
cos.stride(1),
96-
sin.stride(0),
97-
sin.stride(1),
98-
total_len,
88+
Q=q,
89+
K=k,
90+
Cos=cos,
91+
Sin=sin,
92+
stride_qbs=q.stride(0),
93+
stride_qh=q.stride(1),
94+
stride_qd=q.stride(2),
95+
stride_kbs=k.stride(0),
96+
stride_kh=k.stride(1),
97+
stride_kd=k.stride(2),
98+
stride_cosbs=cos.stride(0),
99+
stride_cosd=cos.stride(1),
100+
stride_sinbs=sin.stride(0),
101+
stride_sind=sin.stride(1),
102+
max_total_len=total_len,
99 103
HEAD_Q=head_num_q,
100 104
HEAD_K=head_num_k,
101 105
BLOCK_SEQ=BLOCK_SEQ,

0 commit comments

Comments (0)