Skip to content

Commit 7fe6038

Browse files
committed
fix ep scatter overflow
1 parent 366f4d4 commit 7fe6038

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

lightllm/common/fused_moe/deepep_scatter_gather.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ def _fwd_kernel_ep_scatter_2(
7171

7272
offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
7373
mask_s = offset_in_s < SCALE_HIDDEN_SIZE
74-
7574
for token_id in range(start_token_id, total_token_num, grid_num):
7675
to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
7776
to_copy_s = tl.load(recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s)
@@ -80,6 +79,7 @@ def _fwd_kernel_ep_scatter_2(
8079
expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
8180
if expert_id >= 0:
8281
dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1)
82+
dest_token_index = dest_token_index.to(tl.int64)
8383
tl.store(output_index + token_id * output_index_stride0 + topk_index, dest_token_index)
8484
output_tensor_ptr = output_tensor + dest_token_index * output_tensor_stride0
8585
output_tensor_scale_ptr = output_tensor_scale + dest_token_index * output_tensor_scale_stride0

0 commit comments

Comments
 (0)