Skip to content

Commit 7216a2a

Browse files
committed
try fix
1 parent 0f433c0 commit 7216a2a

File tree

4 files changed

+20
-11
lines changed

4 files changed

+20
-11
lines changed

lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import triton.language as tl
1515

1616
from .index import prepare_chunk_indices, prepare_chunk_offsets
17-
from .op import exp
17+
from .op import exp, safe_exp
1818
from .utils import use_cuda_graph
1919
from lightllm.common.triton_utils.autotuner import autotune
2020

@@ -150,19 +150,18 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
150150

151151
last_idx = min((i_t + 1) * BT, T) - 1
152152
if USE_G:
153-
m_t = (i_t * BT + tl.arange(0, BT)) < T
154153
b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
155154
p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
156155
b_g = tl.load(p_g, boundary_check=(0,))
157-
b_v = b_v * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None]
156+
b_v = b_v * safe_exp(b_g_last - b_g)[:, None]
158157
b_g_last = exp(b_g_last)
159-
b_h1 *= b_g_last
158+
b_h1 = b_h1 * b_g_last
160159
if K > 64:
161-
b_h2 *= b_g_last
160+
b_h2 = b_h2 * b_g_last
162161
if K > 128:
163-
b_h3 *= b_g_last
162+
b_h3 = b_h3 * b_g_last
164163
if K > 192:
165-
b_h4 *= b_g_last
164+
b_h4 = b_h4 * b_g_last
166165

167166
if USE_GK:
168167
o_k1 = tl.arange(0, 64)

lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import triton.language as tl
1717

1818
from .index import prepare_chunk_indices
19-
from .op import exp
19+
from .op import exp, safe_exp
2020
from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper
2121
from lightllm.common.triton_utils.autotuner import autotune
2222

@@ -103,7 +103,7 @@ def chunk_fwd_kernel_o(
103103
p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,))
104104
b_g = tl.load(p_g, boundary_check=(0,))
105105
b_o = b_o * exp(b_g)[:, None]
106-
b_A = b_A * exp(b_g[:, None] - b_g[None, :])
106+
b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :])
107107

108108
o_t = i_t * BT + tl.arange(0, BT)
109109
m_t = o_t < T

lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import triton.language as tl
1515

1616
from .index import prepare_chunk_indices
17-
from .op import exp
17+
from .op import exp, safe_exp
1818
from lightllm.common.triton_utils.autotuner import autotune
1919

2020
triton.set_allocator
@@ -80,7 +80,7 @@ def chunk_scaled_dot_kkt_fwd_kernel(
8080
p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
8181
b_g = tl.load(p_g, boundary_check=(0,))
8282
b_g_diff = b_g[:, None] - b_g[None, :]
83-
b_A = b_A * exp(b_g_diff)
83+
b_A = b_A * safe_exp(b_g_diff)
8484

8585
b_A *= b_beta[:, None]
8686
m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t)

lightllm/models/qwen3next/triton_kernel/fla/ops/op.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,16 @@
1919
log2 = tl.log2
2020

2121

22+
@triton.jit
23+
def safe_exp(x):
24+
"""
25+
Numerically stable exponential function.
26+
Only applies exp to non-positive values, returns 0 for positive values.
27+
This prevents numerical overflow and improves stability.
28+
"""
29+
return exp(tl.where(x <= 0, x, float("-inf")))
30+
31+
2232
if not is_gather_supported:
2333

2434
@triton.jit

0 commit comments

Comments
 (0)