
Commit 792d751

seems to be working

Signed-off-by: Bill Nell <[email protected]>
1 parent: 3b319a1

File tree

5 files changed: +70 -59 lines


vllm/model_executor/layers/fused_moe/deep_gemm_moe.py

Lines changed: 3 additions & 1 deletion
@@ -134,7 +134,9 @@ def apply(
         dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
             (a1q, a1q_scale), (w1, w1_scale), workspace1, expert_ids)

-        self.activation(activation, workspace2, workspace1.view(-1, N))
+        self.activation(activation,
+                        workspace2,
+                        workspace1.view(-1, N))

         a2q_scale: Optional[torch.Tensor] = None

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 16 additions & 7 deletions
@@ -1678,12 +1678,20 @@ def apply(
         intermediate_cache3 = _resize_cache(workspace13,
                                             (num_tokens, top_k_num, K))

-        sorted_token_ids, expert_ids, num_tokens_post_padded = (
-            moe_align_block_size(
-                topk_ids,
-                config['BLOCK_SIZE_M'] if self.block_m is None else self.block_m,
-                global_num_experts, expert_map
-            ))
+        if hidden_states.dim() == 2: #block_m is None:
+            sorted_token_ids, expert_ids, num_tokens_post_padded = (
+                moe_align_block_size(
+                    topk_ids,
+                    config['BLOCK_SIZE_M'],
+                    global_num_experts, expert_map
+                ))
+        else:
+            stride = hidden_states.shape[1]
+            sorted_token_ids = torch.arange(0, hidden_states.shape[0], device=hidden_states.device, dtype=torch.int)
+            sorted_token_ids = sorted_token_ids * stride
+            expert_ids = torch.logical_not(torch.isnan(hidden_states)).sum(dim=(1,2)).nonzero()
+            num_tokens_post_padded = torch.zeros(1, device=hidden_states.device, dtype=torch.int)
+            hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

         invoke_fused_moe_kernel(hidden_states,
                                 w1,
@@ -1706,7 +1714,8 @@ def apply(
                                 per_channel_quant=self.per_channel_quant,
                                 block_shape=self.block_shape)

-        self.activation(activation, intermediate_cache2,
+        self.activation(activation,
+                        intermediate_cache2,
                         intermediate_cache1.view(-1, N))

         a2q_scale: Optional[torch.Tensor] = None
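A minimal sketch (my own illustration, not part of the commit) of what the new else branch above computes when hidden_states arrives in the batched 3-D layout (num_local_experts, max_tokens_per_expert, hidden_dim), with unused slots pre-filled with NaN, as the dispatch path in pplx_dispatch_combine.py below appears to do. All sizes here are hypothetical:

    # Illustration only: names and sizes are made up, not taken from vLLM.
    import torch

    E, max_tokens, hidden = 4, 8, 16
    hidden_states = torch.full((E, max_tokens, hidden), torch.nan)
    hidden_states[1, :3] = 1.0   # expert 1 received 3 tokens
    hidden_states[3, :1] = 2.0   # expert 3 received 1 token

    stride = hidden_states.shape[1]
    # One entry per expert: the row offset of that expert's slab in the
    # flattened (E * max_tokens, hidden) view.
    sorted_token_ids = torch.arange(0, hidden_states.shape[0], dtype=torch.int) * stride
    # Experts whose slab contains any real (non-NaN) data.
    expert_ids = torch.logical_not(torch.isnan(hidden_states)).sum(dim=(1, 2)).nonzero()

    print(sorted_token_ids.tolist())       # [0, 8, 16, 24]
    print(expert_ids.flatten().tolist())   # [1, 3]

    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])  # (32, 16)

This reading suggests the NaN fill added in pplx_dispatch_combine.py doubles as an occupancy marker, which would explain why that fill is tagged "debugging remove" below.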

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 43 additions & 42 deletions
@@ -241,7 +241,7 @@ def apply(
     # Maybe extra args
     def set_dispatch_combine(self, dispatch_combine: FusedMoEQuantizeDispatchCombine) -> bool:
         block_m = MOE_DP_CHUNK_SIZE * (self.moe.ep_size // self.moe.dp_size)
-        print(f"block_m = {block_m}")
+        #print(f"block_m = {block_m}")

         experts = TritonExperts(
             use_fp8_w8a8 = False,
@@ -550,8 +550,8 @@ def __init__(
             self.ep_size = 1
             self.local_num_experts = self.global_num_experts
             self.expert_map = None
+        #self.global_num_experts = num_experts redundant?
         self.top_k = top_k
-        self.global_num_experts = num_experts

         assert intermediate_size % self.tp_size == 0
         self.hidden_size = hidden_size
@@ -571,11 +571,12 @@ def __init__(
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
                              "non-grouped topk.")
+
         if current_platform.is_hpu():
             from vllm_hpu_extension.ops import DynamicFusedMOE
             self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)

-        print(f"params dtype= {params_dtype}")
+        #print(f"params dtype= {params_dtype}")

         moe = MoEConfig(
             num_experts=self.global_num_experts,
@@ -604,59 +605,59 @@ def __init__(
         self.quant_method = quant_method

         # TODO: move to method?
-        if self.dp_size > 1:
-            if True:
-                max_num_tokens = MOE_DP_CHUNK_SIZE # // moe.dp_size
-                world_size = moe.ep_size
-                dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP.
-                rank = moe.ep_rank
+        if False and self.dp_size > 1:
+            max_num_tokens = MOE_DP_CHUNK_SIZE # // moe.dp_size
+            world_size = moe.ep_size
+            dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP.
+            rank = moe.ep_rank

+            if False:
                 print(f"max num = {max_num_tokens}")
                 print(f"world size = {world_size}")
                 print(f"moe ep size = {moe.ep_size}")
                 print(f"moe dp size = {moe.dp_size}")
                 print(f"dp size = {dp_size}")
                 print(f"rank= {rank}")

-                all_to_all = get_all_to_all(
-                    max_num_tokens=max_num_tokens,
-                    num_experts=moe.num_experts,
-                    experts_per_token=moe.experts_per_token, # topk
-                    rank=rank,
-                    world_size=world_size,
-                    dp_size=dp_size,
-                    hidden_dim=moe.hidden_dim,
-                    hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
-                    # For blocked per token: set to ceil_div(hidden_dim, block_size) * sizeof(float32)
-                    # For per-token: set to sizeof(float32)
-                    hidden_dim_scale_bytes=(
-                        0
-                        if moe.in_dtype.itemsize != 1
-                        else (
-                            (moe.hidden_dim + moe.block_size - 1)
-                            // moe.block_size
-                            * torch.float32.itemsize
-                        )
+            all_to_all = get_all_to_all(
+                max_num_tokens=max_num_tokens,
+                num_experts=moe.num_experts,
+                experts_per_token=moe.experts_per_token, # topk
+                rank=rank,
+                world_size=world_size,
+                dp_size=dp_size,
+                hidden_dim=moe.hidden_dim,
+                hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
+                # For blocked per token: set to ceil_div(hidden_dim, block_size) * sizeof(float32)
+                # For per-token: set to sizeof(float32)
+                hidden_dim_scale_bytes=(
+                    0
+                    if moe.in_dtype.itemsize != 1
+                    else (
+                        (moe.hidden_dim + moe.block_size - 1)
+                        // moe.block_size
+                        * torch.float32.itemsize
                     )
                 )
+            )

-                dispatch_combine = PplxDispatchCombine(
-                    all_to_all,
-                    max_num_tokens,
-                    world_size,
-                    dp_size,
-                    rank, # just for debugging
-                    moe.in_dtype,
-                )
-        else:
-            dispatch_combine = StandardDispatchCombine(
-                moe.in_dtype,
-                quant_config.weight_block_size if quant_config is not None else None,
-            )
+            dispatch_combine = PplxDispatchCombine(
+                all_to_all,
+                max_num_tokens,
+                world_size,
+                dp_size,
+                rank, # just for debugging
+                moe.in_dtype,
+            )

             success = self.quant_method.set_dispatch_combine(dispatch_combine)
             if not success:
                 logger.warning("DP+EP not supported for %s.", type(self.quant_method))
+        else:
+            dispatch_combine = StandardDispatchCombine(
+                moe.in_dtype,
+                quant_config.weight_block_size if quant_config is not None else None,
+            )

         self.apply_router_weight_on_input = apply_router_weight_on_input
         moe_quant_params = {
@@ -1010,7 +1011,7 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
         num_tokens_across_dp = get_forward_context(
         ).dp_metadata.num_tokens_across_dp

-        print(f"max/num/rank_num = {max_tokens_across_dp}/{num_tokens_across_dp}/{get_forward_context().dp_metadata.dp_rank_num_tokens}")
+        #print(f"max/num/rank_num = {max_tokens_across_dp}/{num_tokens_across_dp}/{get_forward_context().dp_metadata.dp_rank_num_tokens}")

         #In this function we define two ranges:
         # 1. chunk_range - The current iteration of the loops's range over the DP world tokens
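The hidden_dim_scale_bytes expression in the get_all_to_all hunk above is easier to read as a worked example. The sizes below are hypothetical, not values taken from this commit:

    import torch

    # fp8 activations => in_dtype.itemsize == 1, so per-block scales are needed.
    hidden_dim = 7168   # hypothetical model hidden size
    block_size = 128    # hypothetical quantization block size

    num_scale_groups = (hidden_dim + block_size - 1) // block_size      # ceil_div -> 56
    hidden_dim_scale_bytes = num_scale_groups * torch.float32.itemsize  # 56 * 4 = 224

    print(hidden_dim_scale_bytes)  # 224 bytes of float32 scales per token
    # For non-fp8 dtypes (itemsize != 1) the expression evaluates to 0 instead.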

vllm/model_executor/layers/fused_moe/modular_kernel.py

Lines changed: 3 additions & 3 deletions
@@ -60,9 +60,6 @@ def _moe_problem_size(
     E, N, _ = w1.shape
     K = w2.shape[1]

-    assert topk_ids.dim() == 2
-    topk = topk_ids.shape[1]
-
     if a1.dim() == 2:
         # Make sure we are using the correct a1 (pre-permute).
         assert topk_ids.shape[0] == a1.shape[0], \
@@ -73,6 +70,9 @@
         assert E == a1.shape[0]
         M = a1.shape[1] # This is max_num_tokens

+    assert topk_ids.dim() == 2
+    topk = topk_ids.shape[1]
+
     return E, M, N, K, topk
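A tiny illustration (mine, not from the commit) of why the topk lines can move below the a1.dim() branch: topk_ids keeps its 2-D (num_tokens, topk) shape whether a1 is the flat (num_tokens, K) layout or the batched (E, max_num_tokens, K) layout, so topk can be read after either case. Sizes are made up:

    import torch

    E, max_num_tokens, K, topk = 4, 8, 16, 2

    a1_flat = torch.randn(5, K)                     # standard layout: 5 tokens, pre-permute
    a1_batched = torch.randn(E, max_num_tokens, K)  # batched layout: one slab per local expert
    topk_ids = torch.randint(0, E, (5, topk))       # (num_tokens, topk) in both cases

    assert topk_ids.dim() == 2
    print(topk_ids.shape[1])  # 2 == topk, independent of a1's layout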

vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py

Lines changed: 5 additions & 6 deletions
@@ -32,10 +32,6 @@ def __init__(self,
         self.dp_size = dp_size
         self.rank = rank
         self.quant_dtype = quant_dtype
-        print(f"max_num_tokens = {max_num_tokens}")
-        print(f"dp_num_tokens = {self.dp_num_tokens}")
-        print(f"world_size = {world_size}")
-        print(f"dp_size = {dp_size}")

     def dispatch(
         self,
@@ -77,15 +73,15 @@ def dispatch(
             dtype=torch.int32,
             device=a1.device,
         )
-        expert_num_tokens.fill_(-1)
+        expert_num_tokens.fill_(-1) # debugging remove

         num_dp = self.world_size // self.dp_size
         expert_x = torch.empty(
             (num_local_experts, self.max_num_tokens * num_dp, a1q.shape[-1]),
             dtype=a1q.dtype,
             device=a1.device,
         )
-        expert_x.fill_(torch.nan)
+        expert_x.fill_(torch.nan) # debugging remove

         expert_x_scale: Optional[torch.Tensor] = None
         if a1q.dtype.itemsize == 1:
@@ -146,3 +142,6 @@ def combine(
             weights=topk_weights,
             expert_y=fused_expert_output,
             bound_m=bound_m)
+
+        #print("END COMBINE")
+