lightllm/models/llama/triton_kernel/rmsnorm.py (1 addition, 1 deletion)
@@ -56,7 +56,7 @@ def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None):
     if N > BLOCK_SIZE:
         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
     # heuristics for number of warps
-    num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
+    num_warps = min(max(BLOCK_SIZE // 256, 1), 4)
Contributor comment (medium):

The PR title suggests this change is a fix. It would be very helpful for future maintenance to add a brief inline comment explaining why the maximum number of warps is capped at 4. This provides context for why this specific value was chosen, for example, whether it addresses a performance or stability issue on certain hardware.

Suggested change:
-    num_warps = min(max(BLOCK_SIZE // 256, 1), 4)
+    num_warps = min(max(BLOCK_SIZE // 256, 1), 4)  # Capped at 4 for performance/stability reasons

     num_warps = triton.next_power_of_2(num_warps)
     if BLOCK_SIZE > 16384:
         BLOCK_SIZE = 16384
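
For context on how the capped heuristic behaves, here is a minimal standalone sketch, not part of the PR. It assumes BLOCK_SIZE is derived from the feature dimension via triton.next_power_of_2 (that line falls outside the quoted hunk), and the helper name pick_num_warps is hypothetical.

# Minimal sketch, not part of the PR: shows how the warp heuristic resolves
# for a few feature dimensions. Assumes triton is installed and that
# BLOCK_SIZE is the next power of two of the feature dim (an assumption,
# since that line is outside the quoted hunk).
import triton

def pick_num_warps(feature_dim: int) -> int:
    BLOCK_SIZE = triton.next_power_of_2(feature_dim)
    num_warps = min(max(BLOCK_SIZE // 256, 1), 4)  # cap lowered from 8 to 4 in this PR
    return triton.next_power_of_2(num_warps)

for n in (512, 1024, 4096, 8192):
    print(n, pick_num_warps(n))  # 512 -> 2 warps, 1024 and above -> 4 warps

With the previous cap of 8, feature dims of 2048 and above would have launched with 8 warps; after this change they launch with at most 4.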