Skip to content

Commit 4971b43

Browse files
committed
hack fix for chunking loop
Signed-off-by: Bill Nell <[email protected]>
1 parent c69354d commit 4971b43

File tree

5 files changed

+48
-27
lines changed

5 files changed

+48
-27
lines changed

tests/kernels/moe/test_moe.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def test_fused_moe(
108108
rtol=0)
109109

110110

111-
def batch_by_experts(
111+
def torch_dispatch(
112112
a: torch.Tensor,
113113
topk_ids: torch.Tensor,
114114
num_experts: int
@@ -138,14 +138,14 @@ def batch_by_experts(
138138
return b_a, tokens_per_expert
139139

140140

141-
def unbatch_output(b_out, topk_weight, topk_ids, K):
141+
def torch_combine(b_out, topk_weight, topk_ids):
142142
num_tokens, topk = topk_ids.shape
143143

144144
num_experts = b_out.shape[0]
145145
topk = topk_ids.shape[1]
146+
K = b_out.shape[-1]
146147
out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device)
147148
expert_counts = torch.zeros(num_experts, dtype=torch.int, device=b_out.device)
148-
experts = torch.arange(0, num_experts, dtype=torch.int, device=b_out.device)
149149
for token in range(num_tokens):
150150
expert_ids = topk_ids[token]
151151
for i in range(expert_ids.numel()):
@@ -157,22 +157,25 @@ def unbatch_output(b_out, topk_weight, topk_ids, K):
157157
return out
158158

159159

160-
def torch_batched_moe(a, w1, w2, tokens_per_expert, topk_weight, topk_ids):
161-
assert a.dim() == 3
162-
num_tokens, topk = topk_ids.shape
163-
_, max_num_tokens, K = a.shape
160+
def torch_batched_moe(a, w1, w2, topk_weight, topk_ids):
164161
num_experts = w1.shape[0]
165-
out = torch.zeros((num_experts, max_num_tokens, w2.shape[1]), dtype=a.dtype, device=a.device)
162+
b_a, tokens_per_expert = torch_dispatch(a, topk_ids, num_experts)
163+
assert b_a.dim() == 3
164+
num_tokens, topk = topk_ids.shape
165+
_, max_num_tokens, K = b_a.shape
166+
assert num_experts == b_a.shape[0] and K == w2.shape[1]
167+
out = torch.zeros((num_experts, max_num_tokens, K), dtype=b_a.dtype, device=b_a.device)
168+
tmp = torch.empty((max_num_tokens, w1.shape[1] // 2), dtype=b_a.dtype, device=b_a.device)
166169
for expert in range(num_experts):
167170
num = tokens_per_expert[expert]
168171
if num > 0:
169-
out[expert, :num, :] = SiluAndMul()(a[expert,:num,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1)
172+
torch.ops._C.silu_and_mul(tmp[:num], b_a[expert,:num,:] @ w1[expert].transpose(0, 1))
173+
out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1)
170174

171-
out = unbatch_output(out, topk_weight, topk_ids, K)
172-
173-
return out
175+
return torch_combine(out, topk_weight, topk_ids)
174176

175177

178+
# TODO: same as torch_moe but with fused_topk factored out.
176179
def torch_moe2(a, w1, w2, topk_weight, topk_ids):
177180
M, K = a.shape
178181
topk = topk_ids.shape[1]
@@ -217,16 +220,14 @@ def test_fused_moe_batched_experts(
217220

218221
torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
219222

220-
b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e)
221-
222223
if True:
223-
triton_output = torch_batched_moe(b_a,
224+
triton_output = torch_batched_moe(a,
224225
w1,
225226
w2,
226-
tokens_per_expert,
227227
topk_weight,
228228
topk_ids)
229229
else:
230+
b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e)
230231
triton_output = fused_batched_experts(
231232
b_a,
232233
w1,

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1783,7 +1783,7 @@ def workspace_shapes(
17831783
) -> Tuple[int, int, torch.dtype]:
17841784
max_num_tokens = a.shape[1]
17851785
workspace13 = num_experts * max_num_tokens * K
1786-
workspace2 = M * topk * N * num_experts
1786+
workspace2 = max_num_tokens * (N // 2)
17871787
return (workspace13, workspace2, a_dtype)
17881788

17891789
def apply(
@@ -1810,12 +1810,14 @@ def apply(
18101810
_, max_num_tokens, K = hidden_states.shape
18111811
num_experts = w1.shape[0]
18121812
out = _resize_cache(workspace13, (num_experts, max_num_tokens, w2.shape[1]))
1813+
# causes deadlock
18131814
#tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts)
18141815
for expert in range(num_experts):
1815-
num = 1 #tokens_per_expert[expert]
1816+
num = max_num_tokens #tokens_per_expert[expert]
18161817
if num > 0:
1817-
#out[expert, :num, :] = SiluAndMul(hidden_states[expert,:num,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1)
1818-
out[expert, :, :] = SiluAndMul()(hidden_states[expert,:,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1)
1818+
tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2))
1819+
torch.ops._C.silu_and_mul(tmp, hidden_states[expert,:num,:] @ w1[expert].transpose(0, 1))
1820+
out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1)
18191821

18201822
return out
18211823

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,11 +1028,15 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
10281028
full_hidden_states.shape[0])
10291029
full_final_hidden_states = torch.empty_like(full_hidden_states)
10301030

1031-
for _ in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank):
1031+
#print(f"ORIGINAL SHAPE {full_hidden_states.shape}")
1032+
1033+
#print(f"moe_dp_chunk_size_per_rank = {moe_dp_chunk_size_per_rank}")
1034+
1035+
for iter in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank):
10321036
hidden_states = full_hidden_states[chunk_start:chunk_end, :]
10331037
router_logits = full_router_logits[chunk_start:chunk_end, :]
10341038

1035-
#print(f"loop {chunk_start}:{chunk_end}")
1039+
#print(f"loop {iter}: {chunk_start}:{chunk_end}, {hidden_states.shape}")
10361040

10371041
cu_tokens_across_dp_this_iter = torch.cumsum(
10381042
num_tokens_remaining_across_dp.clamp(
@@ -1062,6 +1066,8 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
10621066
activation=self.activation,
10631067
)
10641068

1069+
#print(f"final1 = {final_hidden_states.shape}")
1070+
10651071
if self.dp_size > 1:
10661072
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_this_iter[
10671073
self.dp_rank - 1]
@@ -1071,19 +1077,31 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
10711077
final_hidden_states)
10721078
final_hidden_states = all_hidden_states[start:end, :]
10731079

1080+
#print(f"final2 (AR) = {final_hidden_states.shape}")
1081+
10741082
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
10751083
# Default set to False. (May have to add shared expert outputs.)
10761084
final_hidden_states = tensor_model_parallel_all_reduce(
10771085
final_hidden_states)
10781086

1087+
#print(f"final3 (AR) = {final_hidden_states.shape}")
1088+
10791089
full_final_hidden_states[chunk_start:chunk_end, :].copy_(
10801090
final_hidden_states)
10811091

1092+
#print(f"full final = {full_final_hidden_states.shape}")
1093+
10821094
# Update bounds
10831095
num_tokens_remaining_across_dp = torch.clamp(
10841096
num_tokens_remaining_across_dp - moe_dp_chunk_size_per_rank,
10851097
min=0)
10861098

1099+
#print(f"num remaining = {num_tokens_remaining_across_dp}")
1100+
1101+
# HACK FIX
1102+
if num_tokens_remaining_across_dp.sum() == 0:
1103+
break
1104+
10871105
def update_chunk_bound(x: int):
10881106
return min(x + moe_dp_chunk_size_per_rank,
10891107
full_hidden_states.shape[0])

vllm/model_executor/layers/fused_moe/modular_kernel.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -312,8 +312,8 @@ def forward(
312312
Returns:
313313
- torch.Tensor: The output tensor after applying the MoE layer.
314314
"""
315-
from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank)
316-
print(f"START {hidden_states.shape} {topk_ids.shape} {get_tensor_model_parallel_rank()}/{get_dp_group().rank_in_group}")
315+
#from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank)
316+
#print(f"START {hidden_states.shape} {topk_ids.shape} {get_tensor_model_parallel_rank()}/{get_dp_group().rank_in_group}")
317317

318318
a1 = hidden_states
319319
E, M, N, K, top_k = _moe_problem_size(a1, w1, w2, topk_ids)
@@ -364,6 +364,6 @@ def forward(
364364
self.dispatch_combine.combine(output, fused_out, topk_weights,
365365
topk_ids, apply_router_weight_on_input)
366366

367-
print(f"DONE {hidden_states.shape} {topk_ids.shape} {get_tensor_model_parallel_rank()}/{get_dp_group().rank_in_group}")
367+
#print(f"DONE {hidden_states.shape} {topk_ids.shape} {get_tensor_model_parallel_rank()}/{get_dp_group().rank_in_group}")
368368

369369
return output

vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def combine(
133133
#device = get_dp_group().device
134134
#assert fused_expert_output.device == device
135135

136-
print(f"COMBINE START {self.rank}")
136+
#print(f"COMBINE START {self.rank}")
137137

138138
# This argument is optional
139139
#bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens
@@ -154,4 +154,4 @@ def combine(
154154
expert_y=fused_expert_output,
155155
bound_m=bound_m)
156156

157-
print(f"COMBINE END {self.rank}")
157+
#print(f"COMBINE END {self.rank}")

0 commit comments

Comments
 (0)