Commit 9590b96

cleanups + lint, layer.py wip
Signed-off-by: Bill Nell <[email protected]>
1 parent: 6a3daba

2 files changed: 13 additions, 11 deletions
tests/kernels/moe/test_pplx_moe.py
Lines changed: 1 addition & 2 deletions

@@ -504,8 +504,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
         chunk_by_rank(w2, rank, world_size),
         chunk_topk_weight,
         chunk_topk_ids,
-        global_num_experts=num_experts
-    )
+        global_num_experts=num_experts)
 
     torch.cuda.synchronize()
vllm/model_executor/layers/fused_moe/layer.py
Lines changed: 12 additions & 9 deletions

@@ -622,8 +622,7 @@ def __init__(
         assert quant_method is not None
         self.quant_method = quant_method
 
-        dispatch_combine = self._construct_dispatch_combine(
-            moe, quant_config)
+        dispatch_combine = self._construct_dispatch_combine(moe, quant_config)
 
         success = self.quant_method.set_dispatch_combine(dispatch_combine)
 
@@ -1030,13 +1029,12 @@ def forward(self, hidden_states: torch.Tensor,
         return torch.ops.vllm.moe_forward(hidden_states, router_logits,
                                           self.layer_name)
 
-
     def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
                              full_router_logits: torch.Tensor):
 
         full_final_hidden_states = torch.empty_like(full_hidden_states)
 
-        def process_chunk(chunk_start, chunk_end, skip_result_store = False):
+        def process_chunk(chunk_start, chunk_end, skip_result_store=False):
            hidden_states = full_hidden_states[chunk_start:chunk_end, :]
            router_logits = full_router_logits[chunk_start:chunk_end, :]
 
@@ -1089,18 +1087,23 @@ def process_chunk(chunk_start, chunk_end, skip_result_store = False):
             full_final_hidden_states[chunk_start:chunk_end, :].copy_(
                 final_hidden_states)
 
-        max_tokens_across_dp = get_forward_context().dp_metadata.max_tokens_across_dp
+        max_tokens_across_dp = get_forward_context(
+        ).dp_metadata.max_tokens_across_dp
         moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE // self.dp_size
 
         num_tokens = full_hidden_states.size(0)
-        for chunk_start_ in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank):
-            chunk_start = chunk_start_
-            chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank, max_tokens_across_dp)
+        for chunk_start_ in range(0, max_tokens_across_dp,
+                                  moe_dp_chunk_size_per_rank):
+            chunk_start = chunk_start_
+            chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank,
+                            max_tokens_across_dp)
             # clamp start and end
             chunk_start = min(chunk_start, num_tokens - 1)
             chunk_end = min(chunk_end, num_tokens)
 
-            process_chunk(chunk_start, chunk_end, skip_result_store = chunk_start_ >= num_tokens)
+            process_chunk(chunk_start,
+                          chunk_end,
+                          skip_result_store=chunk_start_ >= num_tokens)
 
         return full_final_hidden_states

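The substantive change in this WIP is the chunked-forward loop in the last hunk: every DP rank iterates up to max_tokens_across_dp rather than its own local token count, so all ranks execute the same number of chunks (and therefore the same number of collective ops); a rank that has exhausted its local tokens processes a clamped dummy slice and passes skip_result_store=True so the throwaway result is never written back. Below is a minimal standalone sketch of that control flow, assuming a toy dp_size and chunk size and a stand-in for the real per-chunk MoE compute (chunked_forward and the doubling op are illustrative, not vLLM API):

import torch

MOE_DP_CHUNK_SIZE = 256  # illustrative; the real constant lives in layer.py
dp_size = 2              # illustrative DP world size

def chunked_forward(full_hidden_states: torch.Tensor,
                    max_tokens_across_dp: int) -> torch.Tensor:
    full_out = torch.empty_like(full_hidden_states)
    num_tokens = full_hidden_states.size(0)
    chunk = MOE_DP_CHUNK_SIZE // dp_size

    # Loop bound is the max token count across DP ranks, so every rank runs
    # the same number of iterations (and collectives) even after it has
    # exhausted its local tokens.
    for chunk_start_ in range(0, max_tokens_across_dp, chunk):
        chunk_start = chunk_start_
        chunk_end = min(chunk_start + chunk, max_tokens_across_dp)
        # Clamp to the local range: a rank past its last token re-runs a
        # valid one-token slice purely to keep the collectives in lockstep.
        chunk_start = min(chunk_start, num_tokens - 1)
        chunk_end = min(chunk_end, num_tokens)

        out = full_hidden_states[chunk_start:chunk_end, :] * 2.0  # stand-in MoE
        if chunk_start_ < num_tokens:  # i.e. not skip_result_store
            full_out[chunk_start:chunk_end, :].copy_(out)
    return full_out

# e.g. a rank holding 10 tokens while the largest DP rank holds 300:
# x = torch.randn(10, 16)
# y = chunked_forward(x, max_tokens_across_dp=300)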