Skip to content

Commit ca763c3

Browse files
committed
fix merge
Signed-off-by: Bill Nell <[email protected]>
1 parent 2bafbe0 commit ca763c3

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

tests/kernels/moe/test_pplx_moe.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,8 @@ def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts):
341341
torch.float32.itemsize)),
342342
)
343343

344+
topk_ids = topk_ids.to(dtype=torch.uint32)
345+
344346
dispatch_combine = PplxDispatchCombine(
345347
ata,
346348
max_num_tokens,
@@ -478,6 +480,8 @@ def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
478480
torch.float32.itemsize)),
479481
)
480482

483+
topk_ids = topk_ids.to(dtype=torch.uint32)
484+
481485
dispatch_combine = PplxDispatchCombine(
482486
ata,
483487
max_num_tokens,

0 commit comments

Comments
 (0)