Skip to content

Commit 7f0da83

Browse files
committed
fix grouped topk bf16 sigmoid mode.
1 parent 9876ed1 commit 7f0da83

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

lightllm/common/fused_moe/grouped_topk.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def grouped_topk_kernel(
113113
gating_output_ptr + token_index * gating_output_stride_m + offs_n,
114114
mask=offs_n < total_expert_num,
115115
other=-10000000.0,
116-
)
116+
).to(tl.float32)
117117
if IS_SIGMOID:
118118
scores = tl.sigmoid(hidden_states)
119119
else:

unit_tests/common/fused_moe/test_grouped_topk.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ def test_grouped_topk(expert_num, topk_group, group_num, topk_num, scoring_func,
8989

9090
assert torch.equal(torch.sort(old_topk_ids, dim=1)[0], torch.sort(new_topk_ids, dim=1)[0])
9191
assert torch.allclose(
92-
torch.sort(old_topk_weights, dim=1)[0], torch.sort(new_topk_weights, dim=1)[0], atol=1e-4, rtol=0
93-
)
92+
torch.sort(old_topk_weights, dim=1)[0], torch.sort(new_topk_weights, dim=1)[0], atol=1e-3, rtol=1e-1
93+
), f"max delta {torch.max(torch.sort(old_topk_weights, dim=1)[0] - torch.sort(new_topk_weights, dim=1)[0])}"
9494
return
9595

9696

0 commit comments

Comments
 (0)