
Commit 22b988a

Varun Sundar Rabindranath authored and bnellnm committed
fix forward_chunked
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Bill Nell <[email protected]>
1 parent c7ddca4 commit 22b988a

File tree

1 file changed: +15 -47 lines changed
  • vllm/model_executor/layers/fused_moe/layer.py


vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 15 additions & 47 deletions
@@ -1023,40 +1023,16 @@ def forward(self, hidden_states: torch.Tensor,
         return torch.ops.vllm.moe_forward(hidden_states, router_logits,
                                           self.layer_name)
 
+
     def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
                              full_router_logits: torch.Tensor):
 
-        ctx = get_forward_context()
-
-        max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp
-        #cu_tokens_across_dp_cpu = ctx.dp_metadata.cu_tokens_across_dp_cpu
-        num_tokens_across_dp = ctx.dp_metadata.num_tokens_across_dp
-
-        #In this function we define two ranges:
-        # 1. chunk_range - The current iteration of the loops's range over the
-        #    DP world tokens
-        # 2. my_tokens_in_chunk - The tokens within chunk_range that this DP
-        #    rank owns.
-
-        moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE // self.dp_size
-
-        num_tokens_remaining_across_dp = num_tokens_across_dp
-        chunk_start = 0
-        chunk_end = min(moe_dp_chunk_size_per_rank,
-                        full_hidden_states.shape[0])
         full_final_hidden_states = torch.empty_like(full_hidden_states)
 
-        assert full_hidden_states.shape[0] == full_router_logits.shape[0]
-
-        for iter in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank):
+        def process_chunk(chunk_start, chunk_end, skip_result_store = False):
             hidden_states = full_hidden_states[chunk_start:chunk_end, :]
             router_logits = full_router_logits[chunk_start:chunk_end, :]
 
-            cu_tokens_across_dp_this_iter = torch.cumsum(
-                num_tokens_remaining_across_dp.clamp(
-                    max=moe_dp_chunk_size_per_rank),
-                dim=0)
-
             # TODO: still may be needed for non-pplx, put into dispatcher class.
             if False:
                 hidden_states = self.naive_multicast(
@@ -1102,30 +1078,22 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
 
-            full_final_hidden_states[chunk_start:chunk_end, :].copy_(
-                final_hidden_states)
-
-            # Update bounds
-            num_tokens_remaining_across_dp = torch.clamp(
-                num_tokens_remaining_across_dp - moe_dp_chunk_size_per_rank,
-                min=0)
+            if not skip_result_store:
+                full_final_hidden_states[chunk_start:chunk_end, :].copy_(
+                    final_hidden_states)
 
-            # HACK FIX
-            if num_tokens_remaining_across_dp.sum() == 0:
-                break
+        max_tokens_across_dp = get_forward_context().dp_metadata.max_tokens_across_dp
+        moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE // self.dp_size
 
-            def update_chunk_bound(x: int):
-                return min(x + moe_dp_chunk_size_per_rank,
-                           full_hidden_states.shape[0])
+        num_tokens = full_hidden_states.size(0)
+        for chunk_start_ in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank):
+            chunk_start = chunk_start_
+            chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank, max_tokens_across_dp)
+            # clamp start and end
+            chunk_start = min(chunk_start, num_tokens - 1)
+            chunk_end = min(chunk_end, num_tokens)
 
-            #chunk_start = update_chunk_bound(chunk_start)
-            #chunk_end = update_chunk_bound(chunk_end)
-            if chunk_end == full_hidden_states.shape[0]:
-                # simply redo computation
-                pass
-            else:
-                chunk_start = update_chunk_bound(chunk_start)
-                chunk_end = update_chunk_bound(chunk_end)
+            process_chunk(chunk_start, chunk_end, skip_result_store = chunk_start_ >= num_tokens)
 
         return full_final_hidden_states
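The refactored loop above fixes the chunked MoE forward pass for data-parallel (DP) ranks with uneven token counts: every rank now iterates over max_tokens_across_dp in fixed steps of moe_dp_chunk_size_per_rank, so all ranks issue the same number of collective calls, and a rank that has exhausted its local tokens re-runs its last valid rows with skip_result_store=True instead of breaking out of the loop early. Below is a minimal, standalone sketch of that scheduling pattern, not code from the commit; the helper name chunk_schedule and its parameter names are illustrative assumptions.

def chunk_schedule(num_tokens: int, max_tokens_across_dp: int, chunk_size: int):
    """Yield (start, end, skip_result_store) chunks for one DP rank.

    Assumes num_tokens >= 1. Every rank performs the same number of
    iterations (driven by max_tokens_across_dp), so collective ops stay
    in lockstep; once a rank runs out of local tokens it keeps re-processing
    a clamped dummy slice and simply skips storing the result.
    """
    for chunk_start_ in range(0, max_tokens_across_dp, chunk_size):
        chunk_end = min(chunk_start_ + chunk_size, max_tokens_across_dp)
        # Clamp to this rank's own token count; an empty slice would be
        # invalid, so the start is pinned to the last valid row instead.
        start = min(chunk_start_, num_tokens - 1)
        end = min(chunk_end, num_tokens)
        yield start, end, chunk_start_ >= num_tokens


if __name__ == "__main__":
    # A rank with 5 local tokens while the largest rank in the DP group has
    # 16, with chunk size 4: the last two iterations are dummy chunks.
    for start, end, skip in chunk_schedule(5, 16, 4):
        print((start, end, skip))
    # Prints (0, 4, False), (4, 5, False), (4, 5, True), (4, 5, True),
    # one tuple per line.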

0 commit comments
