Skip to content

Commit 098e8f7

Browse files
authored
Add IS_OMNI5 to support dynamic rms_norm num_warps.
Add environment variable check for warp count in RMSNorm.
1 parent e13e2a1 commit 098e8f7

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

lightllm/models/llama/triton_kernel/rmsnorm.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
import triton
44
import triton.language as tl
5+
import os
56

7+
use_5o_num_warps = os.getenv("IS_OMNI5", "False").upper() in ["ON", "TRUE", "1"]
68

79
@triton.jit
810
def _rms_norm_fwd_fused(
@@ -56,8 +58,11 @@ def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None):
5658
if N > BLOCK_SIZE:
5759
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
5860
# heuristics for number of warps
59-
num_warps = min(max(BLOCK_SIZE // 256, 1), 4)
60-
num_warps = triton.next_power_of_2(num_warps)
61+
if use_5o_num_warps:
62+
num_warps = min(max(BLOCK_SIZE // 256, 1), 4)
63+
num_warps = triton.next_power_of_2(num_warps)
64+
else:
65+
num_warps = 8
6166
if BLOCK_SIZE > 16384:
6267
BLOCK_SIZE = 16384
6368
# enqueue kernel

0 commit comments

Comments
 (0)