Skip to content

Commit ea8f30b

Browse files
committed
try fix
1 parent 117a23c commit ea8f30b

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -74,15 +74,15 @@ def chunk_scaled_dot_kkt_fwd_kernel(
74  74    (1, 0),
75  75    )
76  76    b_k = tl.load(p_k, boundary_check=(0, 1))
77      - b_kb = b_k * b_beta[:, None]
78      - b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))
    77  + b_A += tl.dot(b_k, tl.trans(b_k))
79  78
80  79    if USE_G:
81  80    p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
82  81    b_g = tl.load(p_g, boundary_check=(0,))
83  82    b_g_diff = b_g[:, None] - b_g[None, :]
84  83    b_A = b_A * exp(b_g_diff)
85  84
    85  + b_A *= b_beta[:, None]
86  86    m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t)
87  87    b_A = tl.where(m_A, b_A, 0)
88  88    p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0))

0 commit comments

Comments
 (0)