
Commit 764e646

Varun Sundar Rabindranath authored and bnellnm committed
fix forward_chunked
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Bill Nell <[email protected]>
1 parent d42186f commit 764e646

File tree (1 file changed: +15 -47 lines)
  • vllm/model_executor/layers/fused_moe


vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 15 additions & 47 deletions
@@ -1024,40 +1024,16 @@ def forward(self, hidden_states: torch.Tensor,
         return torch.ops.vllm.moe_forward(hidden_states, router_logits,
                                           self.layer_name)
 
+
     def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
                              full_router_logits: torch.Tensor):
 
-        ctx = get_forward_context()
-
-        max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp
-        #cu_tokens_across_dp_cpu = ctx.dp_metadata.cu_tokens_across_dp_cpu
-        num_tokens_across_dp = ctx.dp_metadata.num_tokens_across_dp
-
-        #In this function we define two ranges:
-        # 1. chunk_range - The current iteration of the loop's range over the
-        #    DP world tokens
-        # 2. my_tokens_in_chunk - The tokens within chunk_range that this DP
-        #    rank owns.
-
-        moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE // self.dp_size
-
-        num_tokens_remaining_across_dp = num_tokens_across_dp
-        chunk_start = 0
-        chunk_end = min(moe_dp_chunk_size_per_rank,
-                        full_hidden_states.shape[0])
         full_final_hidden_states = torch.empty_like(full_hidden_states)
 
-        assert full_hidden_states.shape[0] == full_router_logits.shape[0]
-
-        for iter in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank):
+        def process_chunk(chunk_start, chunk_end, skip_result_store = False):
             hidden_states = full_hidden_states[chunk_start:chunk_end, :]
             router_logits = full_router_logits[chunk_start:chunk_end, :]
 
-            cu_tokens_across_dp_this_iter = torch.cumsum(
-                num_tokens_remaining_across_dp.clamp(
-                    max=moe_dp_chunk_size_per_rank),
-                dim=0)
-
             # TODO: still may be needed for non-pplx, put into dispatcher class.
             if False:
                 hidden_states = self.naive_multicast(
@@ -1103,30 +1079,22 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
                 final_hidden_states = tensor_model_parallel_all_reduce(
                     final_hidden_states)
 
-            full_final_hidden_states[chunk_start:chunk_end, :].copy_(
-                final_hidden_states)
-
-            # Update bounds
-            num_tokens_remaining_across_dp = torch.clamp(
-                num_tokens_remaining_across_dp - moe_dp_chunk_size_per_rank,
-                min=0)
+            if not skip_result_store:
+                full_final_hidden_states[chunk_start:chunk_end, :].copy_(
+                    final_hidden_states)
 
-            # HACK FIX
-            if num_tokens_remaining_across_dp.sum() == 0:
-                break
+        max_tokens_across_dp = get_forward_context().dp_metadata.max_tokens_across_dp
+        moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE // self.dp_size
 
-            def update_chunk_bound(x: int):
-                return min(x + moe_dp_chunk_size_per_rank,
-                           full_hidden_states.shape[0])
+        num_tokens = full_hidden_states.size(0)
+        for chunk_start_ in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank):
+            chunk_start = chunk_start_
+            chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank, max_tokens_across_dp)
+            # clamp start and end
+            chunk_start = min(chunk_start, num_tokens - 1)
+            chunk_end = min(chunk_end, num_tokens)
 
-            #chunk_start = update_chunk_bound(chunk_start)
-            #chunk_end = update_chunk_bound(chunk_end)
-            if chunk_end == full_hidden_states.shape[0]:
-                # simply redo computation
-                pass
-            else:
-                chunk_start = update_chunk_bound(chunk_start)
-                chunk_end = update_chunk_bound(chunk_end)
+            process_chunk(chunk_start, chunk_end, skip_result_store = chunk_start_ >= num_tokens)
 
         return full_final_hidden_states
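For readers skimming the hunks above: the commit replaces the running num_tokens_remaining_across_dp bookkeeping and the early "HACK FIX" break with a loop whose iteration count is derived from max_tokens_across_dp, so every DP rank runs the same number of chunk iterations and the per-chunk collectives stay balanced. The code below is a minimal standalone sketch of that chunking logic, not the actual vLLM layer: MOE_DP_CHUNK_SIZE, dp_size, the function name, and the placeholder per-chunk computation are assumptions introduced for illustration, while the real process_chunk performs the fused-MoE dispatch/compute/combine and all-reduce shown in the diff.

# Minimal sketch (assumed names) of the chunk-bound logic added by this commit.
# MOE_DP_CHUNK_SIZE, dp_size, and the toy per-chunk computation are placeholders;
# the real code lives in FusedMoE.forward_impl_chunked.
import torch

MOE_DP_CHUNK_SIZE = 256   # assumed global chunk size across the DP group
dp_size = 2               # assumed data-parallel world size


def forward_impl_chunked_sketch(full_hidden_states: torch.Tensor,
                                full_router_logits: torch.Tensor,
                                max_tokens_across_dp: int) -> torch.Tensor:
    full_final_hidden_states = torch.empty_like(full_hidden_states)

    def process_chunk(chunk_start: int, chunk_end: int,
                      skip_result_store: bool = False) -> None:
        hidden_states = full_hidden_states[chunk_start:chunk_end, :]
        router_logits = full_router_logits[chunk_start:chunk_end, :]
        # Placeholder for the fused-MoE dispatch/compute/combine + all-reduce.
        final_hidden_states = hidden_states.clone()
        # Ranks that have run out of local tokens still execute the chunk so the
        # collective ops stay in lock-step across the DP group, but they drop
        # the redundant result.
        if not skip_result_store:
            full_final_hidden_states[chunk_start:chunk_end, :].copy_(
                final_hidden_states)

    moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE // dp_size
    num_tokens = full_hidden_states.size(0)

    # Every rank iterates over the largest token count in the DP group, so all
    # ranks issue the same number of chunk iterations.
    for chunk_start_ in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank):
        chunk_start = chunk_start_
        chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank,
                        max_tokens_across_dp)
        # Clamp both bounds to this rank's own token count; once chunk_start_
        # is past the end, the slice degenerates to the last local token and
        # the chunk is recomputed with its result discarded.
        chunk_start = min(chunk_start, num_tokens - 1)
        chunk_end = min(chunk_end, num_tokens)
        process_chunk(chunk_start, chunk_end,
                      skip_result_store=chunk_start_ >= num_tokens)

    return full_final_hidden_states


if __name__ == "__main__":
    # Example: this rank owns 10 tokens while the busiest DP rank owns 300,
    # so the last two iterations re-run a trailing slice and skip the store.
    h = torch.randn(10, 32)
    r = torch.randn(10, 8)
    out = forward_impl_chunked_sketch(h, r, max_tokens_across_dp=300)
    assert out.shape == h.shape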
