@@ -28,7 +28,7 @@
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
-    BatchedExperts, BatchedTritonExperts)
+    BatchedExperts)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEModularKernel)
@@ -390,9 +390,11 @@ def _pplx_dispatch_combine(
     a_rep = torch.repeat_interleave(a, topk, dim=0).to(device)
 
     torch_output = (a_rep.view(-1, topk, k) * 1.5 *
-                    topk_weight.view(-1, topk, 1).to(device)).sum(dim=1).to(a.dtype)
+                    topk_weight.view(-1, topk, 1).to(device)).sum(dim=1).to(
+                        a.dtype)
 
-    pplx_output = pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts)
+    pplx_output = pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids,
+                                        num_experts)
 
     torch_output = chunk_by_rank(torch_output, pgi.rank,
                                  pgi.world_size).to(pplx_output.device)
@@ -426,7 +428,8 @@ def test_pplx_dispatch_combine(
     a = torch.randn((m, k), device=device, dtype=dtype) / 10
     score = torch.randn((m, e), device=device, dtype=dtype)
 
-    parallel_launch(world_size, _pplx_dispatch_combine, dp_size, a, score, topk, e)
+    parallel_launch(world_size, _pplx_dispatch_combine, dp_size, a, score,
+                    topk, e)
 
 
 def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
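
For context, the torch_output expression in the _pplx_dispatch_combine hunk above is the dense reference that the pplx dispatch/combine path is compared against. Below is a minimal runnable sketch of that reference math, assuming (as the test does) that multiplying by 1.5 stands in for actual expert work; the helper name and the demo shapes are illustrative, not part of the diff:

import torch


def reference_dispatch_combine(a: torch.Tensor, topk_weight: torch.Tensor,
                               topk: int) -> torch.Tensor:
    # Illustrative helper mirroring the reference computation in
    # _pplx_dispatch_combine; not an API from the vLLM test file.
    m, k = a.shape
    # One row per (token, selected expert) pair: (m * topk, k).
    a_rep = torch.repeat_interleave(a, topk, dim=0)
    # Placeholder "expert" (x * 1.5), scaled by each expert's routing
    # weight, then reduced back to one output row per token.
    return (a_rep.view(-1, topk, k) * 1.5 *
            topk_weight.view(-1, topk, 1)).sum(dim=1).to(a.dtype)


if __name__ == "__main__":
    m, k, topk = 4, 8, 2
    out = reference_dispatch_combine(torch.randn(m, k), torch.rand(m, topk),
                                     topk)
    print(out.shape)  # torch.Size([4, 8])

A correct dispatch (fanning each token out to its top-k experts) followed by combine (a weighted reduction of the expert outputs) should reproduce this dense result, up to the per-rank chunking applied by chunk_by_rank before the comparison.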