
Commit 2bafbe0

wip

Signed-off-by: Bill Nell <[email protected]>

1 parent a003bd8, commit 2bafbe0

5 files changed: +51 / -49 lines changed


pyproject.toml

Lines changed: 2 additions & 2 deletions

@@ -15,8 +15,8 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "vllm"
 authors = [{name = "vLLM Team"}]
-#license = "Apache-2.0"
-#license-files = ["LICENSE"]
+license = "Apache-2.0"
+license-files = ["LICENSE"]
 readme = "README.md"
 description = "A high-throughput and memory-efficient inference and serving engine for LLMs"
 classifiers = [

tests/kernels/moe/test_cutlass_moe.py

Lines changed: 12 additions & 12 deletions

@@ -241,10 +241,10 @@ def test_cutlass_moe_8_bit_no_graph(
                              per_out_ch)

     score = torch.randn((m, e), device="cuda", dtype=torch.half)
-    topk_weights, topk_ids = fused_topk(mt.a,
-                                        score,
-                                        topk,
-                                        renormalize=False)
+    topk_weights, topk_ids, _ = fused_topk(mt.a,
+                                           score,
+                                           topk,
+                                           renormalize=False)

     # Note that we are using the dequantized versions of the tensors.
     # Using a, w1 and w2 directly results in minor output differences.

@@ -285,10 +285,10 @@ def test_cutlass_moe_8_bit_cuda_graph(
                              per_out_ch)

     score = torch.randn((m, e), device="cuda", dtype=dtype)
-    topk_weights, topk_ids = fused_topk(mt.a,
-                                        score,
-                                        topk,
-                                        renormalize=False)
+    topk_weights, topk_ids, _ = fused_topk(mt.a,
+                                           score,
+                                           topk,
+                                           renormalize=False)

     # Note that we are using the dequantized versions of the tensors.
     # Using a, w1 and w2 directly results in minor output differences.

@@ -338,10 +338,10 @@ def test_cutlass_moe_8_bit_EP(
                              per_out_channel)

     score = torch.randn((m, e), device="cuda", dtype=torch.half)
-    topk_weights, topk_ids = fused_topk(mt.a,
-                                        score,
-                                        topk,
-                                        renormalize=False)
+    topk_weights, topk_ids, _ = fused_topk(mt.a,
+                                           score,
+                                           topk,
+                                           renormalize=False)

     # Note that we are using the dequantized versions of the tensors.
     # Using a, w1 and w2 directly results in minor output differences.
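
The updated tests unpack a third value from fused_topk, which now also returns per-token expert indices. A minimal sketch of the new call pattern, assuming fused_topk is imported from vllm.model_executor.layers.fused_moe.fused_moe (the module modified later in this commit) and using illustrative tensor sizes not taken from the tests:

    import torch
    from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk

    m, e, topk = 4, 8, 2  # illustrative sizes
    a = torch.randn((m, 16), device="cuda", dtype=torch.half)     # activations
    score = torch.randn((m, e), device="cuda", dtype=torch.half)  # router logits

    # fused_topk now yields three tensors; callers that only need the routing
    # weights and ids discard the third with `_`, as in the tests above.
    topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)
    assert topk_weights.shape == (m, topk) and topk_ids.shape == (m, topk)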

tests/kernels/moe/test_pplx_moe.py

Lines changed: 3 additions & 3 deletions

@@ -296,7 +296,7 @@ def test_fused_moe_batched_experts(
     score = torch.randn((m, e), device="cuda", dtype=dtype)

     with set_current_vllm_config(vllm_config):
-        topk_weight, topk_ids = fused_topk(a, score, topk, False)
+        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
         baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
         torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
         batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids)

@@ -404,7 +404,7 @@ def _pplx_dispatch_combine(
     nvshmem_init(uid, pgi.rank, pgi.world_size)
     device = pgi.device

-    topk_weight, topk_ids = fused_topk(a, score, topk, False)
+    topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
     k = a.shape[1]

     a_rep = torch.repeat_interleave(a, topk, dim=0).to(device)

@@ -577,7 +577,7 @@ def _pplx_moe(
     e, _, n = w2.shape

     with set_current_vllm_config(vllm_config):
-        topk_weight, topk_ids = fused_topk(a, score, topk, False)
+        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
         torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
         pplx_output = pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids)
         batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids)

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 27 additions & 26 deletions

@@ -887,10 +887,10 @@ def fused_topk(
         dtype=torch.int32 if indices_type is None else indices_type,
         device=hidden_states.device
     )
-    token_expert_indicies = torch.empty(M,
-                                        topk,
-                                        dtype=torch.int32,
-                                        device=hidden_states.device)
+    token_expert_indices = torch.empty(M,
+                                       topk,
+                                       dtype=torch.int32,
+                                       device=hidden_states.device)

     gating_output_float = gating_output.float()  # TODO(woosuk): Optimize this.

@@ -1211,28 +1211,29 @@ def fused_experts(hidden_states: torch.Tensor,


 def fused_experts_impl(
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        inplace: bool = False,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        use_fp8_w8a8: bool = False,
-        use_int8_w8a8: bool = False,
-        use_int8_w8a16: bool = False,
-        use_int4_w4a16: bool = False,
-        per_channel_quant: bool = False,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        w1_scale: Optional[torch.Tensor] = None,
-        w2_scale: Optional[torch.Tensor] = None,
-        w1_zp: Optional[torch.Tensor] = None,
-        w2_zp: Optional[torch.Tensor] = None,
-        a1_scale: Optional[torch.Tensor] = None,
-        a2_scale: Optional[torch.Tensor] = None,
-        block_shape: Optional[List[int]] = None) -> torch.Tensor:
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    inplace: bool = False,
+    activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    use_fp8_w8a8: bool = False,
+    use_int8_w8a8: bool = False,
+    use_int8_w8a16: bool = False,
+    use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    w1_zp: Optional[torch.Tensor] = None,
+    w2_zp: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
+) -> torch.Tensor:
     # Check constraints.
     if use_int4_w4a16:
         assert hidden_states.shape[1] // 2 == w1.shape[
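
The first hunk renames the misspelled token_expert_indicies buffer to token_expert_indices; together with the three-value unpacking in the tests and in select_experts, fused_topk presumably now returns this buffer as well. A sketch of the assumed tail of the function follows; the ops.topk_softmax call and the return statement sit outside the hunk shown above, so their exact form is an assumption based on the surrounding module:

    # Assumed continuation of fused_topk after the allocation in the hunk above.
    ops.topk_softmax(topk_weights, topk_ids, token_expert_indices,
                     gating_output_float)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    # Contract implied by the callers changed in this commit: three tensors,
    # each of shape (M, topk).
    return topk_weights, topk_ids, token_expert_indices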

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 7 additions & 6 deletions

@@ -1002,12 +1002,13 @@ def select_experts(hidden_states: torch.Tensor,
                 scoring_func=scoring_func,
                 e_score_correction_bias=e_score_correction_bias)
         elif custom_routing_function is None:
-            topk_weights, topk_ids, token_expert_indices = fused_topk(hidden_states=hidden_states,
-                                                                      gating_output=router_logits,
-                                                                      topk=top_k,
-                                                                      renormalize=renormalize,
-                                                                      indices_type=indices_type,
-                                                                      )
+            topk_weights, topk_ids, token_expert_indices = fused_topk(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                topk=top_k,
+                renormalize=renormalize,
+                indices_type=indices_type,
+            )
         else:
             assert indices_type is None or indices_type == torch.int32
             topk_weights, topk_ids = custom_routing_function(
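
Note the asymmetry this leaves in select_experts: the fused_topk branch now produces three tensors, while the custom_routing_function branch still unpacks two. A minimal sketch of a custom router satisfying that two-value contract; the function name, parameter names, and dtype choices are hypothetical, not part of this commit:

    import torch

    def naive_topk_router(hidden_states: torch.Tensor,
                          gating_output: torch.Tensor,
                          topk: int,
                          renormalize: bool):
        # Hypothetical stand-in for a custom_routing_function: plain softmax + topk.
        probs = torch.softmax(gating_output.float(), dim=-1)
        topk_weights, topk_ids = torch.topk(probs, topk, dim=-1)
        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        # Two-value contract: no token_expert_indices, unlike fused_topk above.
        return topk_weights, topk_ids.to(torch.int32)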
