@@ -432,7 +432,7 @@ def _pplx_dispatch_combine(
 @pytest.mark.parametrize("k", [128, 512, 1024])
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("world_dp_size", [[2, 1]])
 @requires_pplx
 def test_pplx_dispatch_combine(
@@ -584,13 +584,13 @@ def _pplx_moe(
     topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
     torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
     pplx_output = pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids)
-    batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids)
+    # batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids)

     torch_output = chunk_by_rank(torch_output, pgi.rank,
                                  pgi.world_size).to(pplx_output.device)

     torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
-    torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0)
+    # torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0)

     nvshmem_finalize()