
Commit c0fc027

cleanups + lint, layer.py wip
Signed-off-by: Bill Nell <[email protected]>
1 parent: 938c516

2 files changed (+13, -11 lines)

tests/kernels/moe/test_pplx_moe.py

Lines changed: 1 addition & 2 deletions

@@ -504,8 +504,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
         chunk_by_rank(w2, rank, world_size),
         chunk_topk_weight,
         chunk_topk_ids,
-        global_num_experts=num_experts
-    )
+        global_num_experts=num_experts)
 
     torch.cuda.synchronize()
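The change itself is pure lint (folding the closing paren into the last argument), but the surrounding call is the heart of the test: each rank runs the MoE against its own shard of the expert weights. As a point of reference, here is a minimal sketch of what a chunk_by_rank-style helper does, assuming an even split along the expert dimension; the actual vLLM helper may differ:

import torch

def chunk_by_rank(t: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Hypothetical sketch: split dim 0 (experts) into world_size even
    # chunks and return the slice owned by this rank.
    assert t.size(0) % world_size == 0, "experts must divide evenly across ranks"
    chunk = t.size(0) // world_size
    return t[rank * chunk:(rank + 1) * chunk]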

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 12 additions & 9 deletions

@@ -622,8 +622,7 @@ def __init__(
         assert quant_method is not None
         self.quant_method = quant_method
 
-        dispatch_combine = self._construct_dispatch_combine(
-            moe, quant_config)
+        dispatch_combine = self._construct_dispatch_combine(moe, quant_config)
 
         success = self.quant_method.set_dispatch_combine(dispatch_combine)
 
@@ -1029,13 +1028,12 @@ def forward(self, hidden_states: torch.Tensor,
         return torch.ops.vllm.moe_forward(hidden_states, router_logits,
                                           self.layer_name)
 
-
     def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
                              full_router_logits: torch.Tensor):
 
         full_final_hidden_states = torch.empty_like(full_hidden_states)
 
-        def process_chunk(chunk_start, chunk_end, skip_result_store = False):
+        def process_chunk(chunk_start, chunk_end, skip_result_store=False):
             hidden_states = full_hidden_states[chunk_start:chunk_end, :]
             router_logits = full_router_logits[chunk_start:chunk_end, :]
 
@@ -1088,18 +1086,23 @@ def process_chunk(chunk_start, chunk_end, skip_result_store = False):
             full_final_hidden_states[chunk_start:chunk_end, :].copy_(
                 final_hidden_states)
 
-        max_tokens_across_dp = get_forward_context().dp_metadata.max_tokens_across_dp
+        max_tokens_across_dp = get_forward_context(
+        ).dp_metadata.max_tokens_across_dp
         moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE // self.dp_size
 
         num_tokens = full_hidden_states.size(0)
-        for chunk_start_ in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank):
-            chunk_start = chunk_start_
-            chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank, max_tokens_across_dp)
+        for chunk_start_ in range(0, max_tokens_across_dp,
+                                  moe_dp_chunk_size_per_rank):
+            chunk_start = chunk_start_
+            chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank,
+                            max_tokens_across_dp)
             # clamp start and end
             chunk_start = min(chunk_start, num_tokens - 1)
             chunk_end = min(chunk_end, num_tokens)
 
-            process_chunk(chunk_start, chunk_end, skip_result_store = chunk_start_ >= num_tokens)
+            process_chunk(chunk_start,
+                          chunk_end,
+                          skip_result_store=chunk_start_ >= num_tokens)
 
         return full_final_hidden_states
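The rewrapped loop is the substance of this WIP: under data parallelism, every rank must execute the same number of chunk iterations (driven by max_tokens_across_dp, the largest local token count on any rank) so the collectives inside each chunk stay in lockstep across ranks. A rank that has run out of local tokens replays a clamped dummy slice and discards the result via skip_result_store. A standalone sketch of that pattern, with forward_chunk as a hypothetical stand-in for the per-chunk MoE forward and an assumed chunk size:

import torch

MOE_DP_CHUNK_SIZE = 256  # assumed value, for illustration only

def chunked_forward(full_hidden_states, max_tokens_across_dp, dp_size,
                    forward_chunk):
    # forward_chunk: hypothetical callable standing in for the real
    # per-chunk MoE forward (dispatch -> experts -> combine).
    out = torch.empty_like(full_hidden_states)
    chunk = MOE_DP_CHUNK_SIZE // dp_size
    num_tokens = full_hidden_states.size(0)
    # Every DP rank iterates to the same bound so that the collective
    # ops inside forward_chunk are issued the same number of times.
    for start_ in range(0, max_tokens_across_dp, chunk):
        end = min(start_ + chunk, max_tokens_across_dp)
        # Clamp into the local range; a rank already past its own
        # token count re-runs the last valid slice as a dummy chunk.
        start = min(start_, num_tokens - 1)
        end = min(end, num_tokens)
        result = forward_chunk(full_hidden_states[start:end, :])
        if start_ < num_tokens:  # i.e. skip_result_store is False
            out[start:end, :].copy_(result)
    return out

On a single rank with no padding this degenerates to plain chunking, e.g. chunked_forward(x, x.size(0), 1, fn).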
