Skip to content

Commit 48ff875

Browse files
committed
fix
1 parent 631b6a8 commit 48ff875

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

lightllm/common/fused_moe/deepep_scatter_gather.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,10 +217,11 @@ def ep_gather(
217217
input_index: torch.Tensor,
218218
output_tensor: torch.Tensor,
219219
):
220-
BLOCK_D = 128 # block size of quantization
221-
num_warps = 4
220+
BLOCK_D = 1024 # block size of quantization
221+
num_warps = 2
222222
num_tokens = output_tensor.shape[0]
223223
hidden_size = input_tensor.shape[1]
224+
assert hidden_size % BLOCK_D == 0
224225
grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024))
225226
_fwd_kernel_ep_gather[grid](
226227
num_tokens,

0 commit comments

Comments
 (0)