
Commit be24517

wip
Signed-off-by: Bill Nell <[email protected]>
1 parent 792d751 commit be24517

5 files changed: +37 −13 lines changed

vllm/distributed/parallel_state.py

Lines changed: 1 addition & 0 deletions
@@ -1098,6 +1098,7 @@ def get_tensor_model_parallel_rank():
 def destroy_model_parallel():
     """Set the groups to none and destroy them."""
     global _TP
+
     nvshmem_finalize()
 
     if _TP:

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 11 additions & 5 deletions
@@ -1686,12 +1686,18 @@ def apply(
                 global_num_experts, expert_map
             ))
         else:
-            stride = hidden_states.shape[1]
-            sorted_token_ids = torch.arange(0, hidden_states.shape[0], device=hidden_states.device, dtype=torch.int)
-            sorted_token_ids = sorted_token_ids * stride
-            expert_ids = torch.logical_not(torch.isnan(hidden_states)).sum(dim=(1,2)).nonzero()
-            num_tokens_post_padded = torch.zeros(1, device=hidden_states.device, dtype=torch.int)
+            #stride = hidden_states.shape[1]
+            sorted_token_ids = torch.arange(0, num_tokens*hidden_states.shape[1], device=hidden_states.device, dtype=torch.int)
+            sorted_token_ids = sorted_token_ids.flatten()
+            nans = torch.isnan(hidden_states).sum(dim=(1,2))
+            expert_ids = torch.where((nans > 0).flatten(), -1, torch.arange(0, nans.numel(), device=hidden_states.device, dtype=torch.int32))
+            #expert_ids = torch.repeat_interleave(expert_ids, hidden_states.shape[1], dim=0)
+            #print(f"EXPERT_IDS {nans.shape} {expert_ids}")
+            #num_tokens_post_padded = torch.tensor([num_tokens], device=hidden_states.device, dtype=torch.int32)
+            num_tokens_post_padded = torch.zeros(1, device=hidden_states.device, dtype=torch.int32)
+            num_tokens_post_padded.fill_(num_tokens)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+        #print(f"P = {sorted_token_ids}, {hidden_states.shape}")
 
         invoke_fused_moe_kernel(hidden_states,
                                 w1,
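
Note on the new fallback path above: the expert buffers arriving from dispatch are pre-filled with NaN, so an expert slot that still contains NaN received no tokens and gets expert id -1. A minimal standalone sketch of that NaN-sentinel idea, with made-up shapes and outside the vLLM code:

    import torch

    # Assumed shapes: E local experts, each with a padded slot of M tokens and K features,
    # pre-filled with NaN the way the pplx dispatch buffers are.
    E, M, K = 4, 8, 16
    expert_x = torch.full((E, M, K), float("nan"))
    expert_x[0] = 1.0   # pretend experts 0 and 2 actually received tokens
    expert_x[2] = 2.0

    # Experts whose buffers were never written still contain NaNs; mask them with -1.
    nans = torch.isnan(expert_x).sum(dim=(1, 2))
    expert_ids = torch.where(nans > 0,
                             torch.full((E,), -1, dtype=torch.int32),
                             torch.arange(E, dtype=torch.int32))
    print(expert_ids)   # tensor([ 0, -1,  2, -1], dtype=torch.int32)

In the diff itself, num_tokens_post_padded is then simply filled with num_tokens for this path.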

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 4 additions & 2 deletions
@@ -116,7 +116,7 @@ def get_or_create(self, **kwargs):
 
         with self._lock:
             instance = self._cache.get(key)
-            if instance is None:
+            if True or instance is None:
                 instance = pplx.AllToAll(**kwargs)
                 self._cache[key] = instance
             return instance
@@ -605,7 +605,7 @@ def __init__(
         self.quant_method = quant_method
 
         # TODO: move to method?
-        if False and self.dp_size > 1:
+        if self.dp_size > 1:
             max_num_tokens = MOE_DP_CHUNK_SIZE # // moe.dp_size
             world_size = moe.ep_size
             dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP.
@@ -1029,6 +1029,8 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
             hidden_states = full_hidden_states[chunk_start:chunk_end, :]
             router_logits = full_router_logits[chunk_start:chunk_end, :]
 
+            print(f"loop {chunk_start}:{chunk_end}")
+
             cu_tokens_across_dp_this_iter = torch.cumsum(
                 num_tokens_remaining_across_dp.clamp(
                     max=moe_dp_chunk_size_per_rank),
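
For context on the first hunk: get_or_create() normally caches one pplx.AllToAll handle per argument set, and the `if True or instance is None` change forces a fresh handle on every call while debugging. A minimal sketch of that kind of keyed, lock-protected cache; the class name, key derivation, and make_all_to_all stand-in are assumptions, not the vLLM implementation:

    import threading

    def make_all_to_all(**kwargs):
        # Stand-in for pplx.AllToAll(**kwargs); a dummy handle is enough for the sketch.
        return object()

    class AllToAllCache:
        def __init__(self):
            self._cache = {}
            self._lock = threading.Lock()

        def get_or_create(self, **kwargs):
            key = tuple(sorted(kwargs.items()))   # hypothetical key derivation
            with self._lock:
                instance = self._cache.get(key)
                if instance is None:              # the commit short-circuits this check
                    instance = make_all_to_all(**kwargs)
                    self._cache[key] = instance
                return instance

    cache = AllToAllCache()
    a = cache.get_or_create(world_size=4, dp_size=2)
    b = cache.get_or_create(world_size=4, dp_size=2)
    assert a is b   # cached; with `if True or ...` a new handle is built each call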

vllm/model_executor/layers/fused_moe/modular_kernel.py

Lines changed: 5 additions & 0 deletions
@@ -312,6 +312,9 @@ def forward(
         Returns:
         - torch.Tensor: The output tensor after applying the MoE layer.
         """
+        from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank)
+        print(f"START {hidden_states.shape} {topk_ids.shape} {get_tensor_model_parallel_rank()}/{get_dp_group().rank_in_group}")
+
         a1 = hidden_states
         E, M, N, K, top_k = _moe_problem_size(a1, w1, w2, topk_ids)
 
@@ -361,4 +364,6 @@ def forward(
         self.dispatch_combine.combine(output, fused_out, topk_weights,
                                       topk_ids, apply_router_weight_on_input)
 
+        print(f"DONE {hidden_states.shape} {topk_ids.shape} {get_tensor_model_parallel_rank()}/{get_dp_group().rank_in_group}")
+
         return output
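
The added START/DONE prints bracket the whole modular-kernel forward: dispatch, fused experts, combine. Right after the START print, forward() reads the problem size (E, M, N, K, top_k) off the tensor shapes. A simplified sketch of that bookkeeping, assuming 2-D activations and the usual fused-MoE weight layout; this is not vLLM's actual _moe_problem_size helper:

    import torch

    def moe_problem_size(a1, w1, w2, topk_ids):
        # w1: (E, N, K) expert weights, w2: (E, K, N), a1: (M, K) activations.
        E, N, K = w1.shape
        assert w2.shape == (E, K, N)
        M = a1.shape[0]              # tokens on this rank
        top_k = topk_ids.shape[1]    # experts selected per token
        return E, M, N, K, top_k

    a1 = torch.randn(6, 32)                  # 6 tokens, hidden size 32
    w1 = torch.randn(4, 64, 32)              # 4 experts
    w2 = torch.randn(4, 32, 64)
    topk_ids = torch.randint(0, 4, (6, 2))   # top-2 routing
    print(moe_problem_size(a1, w1, w2, topk_ids))   # (4, 6, 64, 32, 2)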

vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py

Lines changed: 16 additions & 6 deletions
@@ -46,6 +46,8 @@ def dispatch(
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
         # Is this always going to be a1.device?
         device = a1.device
+        num_tokens = a1.shape[0] # M
+        hidden_dim = a1.shape[-1] # K
 
         assert expert_map is None, "NYI"
 
@@ -71,15 +73,15 @@
         expert_num_tokens = torch.empty(
             num_local_experts,
             dtype=torch.int32,
-            device=a1.device,
+            device=device,
         )
         expert_num_tokens.fill_(-1) # debugging remove
 
         num_dp = self.world_size // self.dp_size
         expert_x = torch.empty(
             (num_local_experts, self.max_num_tokens * num_dp, a1q.shape[-1]),
             dtype=a1q.dtype,
-            device=a1.device,
+            device=device,
         )
         expert_x.fill_(torch.nan) # debugging remove
 
@@ -95,7 +97,7 @@
                 (expert_x.size(2) + block_size - 1) // block_size,
             ),
             dtype=torch.float32,
-            device=a1.device,
+            device=device,
         )
 
         # This argument is optional, defaults to indices.shape[0]
@@ -105,7 +107,7 @@
         bound_m = None
 
         # TODO: optimize this?
-        indices = rank_topk_ids.to(dtype=torch.uint32)
+        indices = rank_topk_ids.to(dtype=torch.uint32).to(device)
 
         self.a2a.dispatch(
             out_expert_num_tokens=expert_num_tokens,
@@ -126,8 +128,17 @@ def combine(
         topk_ids: torch.Tensor,
         apply_router_weight_on_input: bool,
     ) -> None:
+        device = fused_expert_output.device
+        #device = torch.device("cuda", self.rank)
+        #device = get_dp_group().device
+        #assert fused_expert_output.device == device
+
+        print(f"COMBINE START {self.rank}")
+
         # This argument is optional
         #bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens
+        #num_tokens = fused_expert_output.shape[0] # M
+        #bound_m = torch.tensor([num_tokens], dtype=torch.uint32, device=device)
         bound_m = None
 
         assert output.shape[0] <= self.max_num_tokens
@@ -143,5 +154,4 @@
             expert_y=fused_expert_output,
             bound_m=bound_m)
 
-        #print("END COMBINE")
-
+        print(f"COMBINE END {self.rank}")
