Skip to content

Commit c8f5625

Browse files
committed
Merge branch 'main' into fast_start
2 parents 7100389 + 16c8c79 commit c8f5625

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

lightllm/models/llama/triton_kernel/rmsnorm.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,9 @@
2 2

3 3
import triton
4 4
import triton.language as tl
5+
import os
6+
7+
rmsnorm_num_warps = int(os.getenv("RMSNORM_WARPS", "8"))
5 8

6 9

7 10
@triton.jit
@@ -56,8 +59,6 @@ def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None):
56 59
if N > BLOCK_SIZE:
57 60
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
58 61
# heuristics for number of warps
59-
num_warps = min(max(BLOCK_SIZE // 256, 1), 4)
60-
num_warps = triton.next_power_of_2(num_warps)
61 62
if BLOCK_SIZE > 16384:
62 63
BLOCK_SIZE = 16384
63 64
# enqueue kernel
@@ -72,7 +73,7 @@ def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None):
72 73
N,
73 74
eps,
74 75
BLOCK_SIZE=BLOCK_SIZE,
75-
num_warps=num_warps,
76+
num_warps=rmsnorm_num_warps,
76 77
)
77 78
return y
78 79

0 commit comments

Comments
 (0)