
Commit a876454

wip
Signed-off-by: Bill Nell <[email protected]>
1 parent e0560d5 commit a876454

5 files changed, 32 insertions(+), 34 deletions(-)

tests/kernels/test_pplx_moe.py

Lines changed: 0 additions & 3 deletions
@@ -23,10 +23,7 @@
 import vllm.model_executor.layers.fused_moe # noqa
 from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev,
                                  torch_moe, torch_moe_single)
-#from vllm import _custom_ops as ops
 from vllm.config import VllmConfig, set_current_vllm_config
-#from vllm.model_executor.layers.fused_moe import fused_moe
-#from vllm.model_executor.layers.fused_moe.fused_batched_moe import fused_batched_experts
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk, moe_align_block_size)
 from vllm.platforms import current_platform

vllm/forward_context.py

Lines changed: 1 addition & 1 deletion
@@ -93,7 +93,7 @@ def set_forward_context(attn_metadata: Any,
         from vllm.distributed.parallel_state import get_dp_group
         dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
         #TODO device?
-        max_tokens_across_dp = torch.max(num_tokens_tensor).to(device="cuda")
+        max_tokens_across_dp = torch.max(num_tokens_tensor) #.to(device="cuda")
         cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
         dp_rank_num_tokens = torch.tensor(
             [num_tokens],
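For context on the one-line change above: torch.max over a CPU tensor returns a zero-dim CPU tensor, so dropping the .to(device="cuda") keeps the reduction result on the host, alongside cu_tokens_across_dp_cpu. A minimal standalone sketch, with made-up per-rank token counts:

import torch

# Illustrative per-rank token counts, as gathered on the DP CPU group.
num_tokens_tensor = torch.tensor([5, 3, 7, 2])

# Previously the scalar max was moved to the GPU:
#   max_tokens_across_dp = torch.max(num_tokens_tensor).to(device="cuda")
# After this change it stays on the CPU:
max_tokens_across_dp = torch.max(num_tokens_tensor)                # tensor(7)
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)   # tensor([ 5,  8, 15, 17])
print(max_tokens_across_dp.device)                                 # cpu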

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 11 additions & 12 deletions
@@ -1594,8 +1594,9 @@ def workspace_shapes(
         topk: int,
         num_experts: int,
     ) -> Tuple[int, int, torch.dtype]:
-        workspace1 = M * topk * max(N * 2, K)
-        workspace2 = M * topk * N
+        factor = num_experts if a.dim() == 3 else 1
+        workspace1 = M * topk * max(N * 2, K) * factor
+        workspace2 = M * topk * N * factor
         return (workspace1, workspace2, a.dtype)

     def apply(
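The new factor scales the Triton workspaces when the activations arrive in the batched 3-D (expert-major) layout rather than the flat 2-D one. A rough standalone sketch with toy sizes (the numbers are illustrative only):

import torch
from typing import Tuple

def sketch_workspace_shapes(a: torch.Tensor, M: int, N: int, K: int,
                            topk: int, num_experts: int) -> Tuple[int, int, torch.dtype]:
    # 3-D activations are [num_experts, max_num_tokens, K], so the flat
    # workspaces must hold one block per expert.
    factor = num_experts if a.dim() == 3 else 1
    workspace1 = M * topk * max(N * 2, K) * factor
    workspace2 = M * topk * N * factor
    return (workspace1, workspace2, a.dtype)

print(sketch_workspace_shapes(torch.empty(16, 128), 16, 256, 128, 2, 8))     # unscaled 2-D case
print(sketch_workspace_shapes(torch.empty(8, 16, 128), 16, 256, 128, 2, 8))  # scaled by num_experts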
@@ -1686,16 +1687,15 @@ def apply(
                     global_num_experts, expert_map
                 ))
             else:
-                #stride = hidden_states.shape[1]
-                sorted_token_ids = torch.arange(0, num_tokens*hidden_states.shape[1], device=hidden_states.device, dtype=torch.int)
+                max_num_tokens = hidden_states.shape[1]
+                sorted_token_ids = torch.arange(0, hidden_states.shape[0] * max_num_tokens, device=hidden_states.device, dtype=torch.int)
                 sorted_token_ids = sorted_token_ids.flatten()
-                nans = torch.isnan(hidden_states).sum(dim=(1,2))
-                expert_ids = torch.where((nans > 0).flatten(), -1, torch.arange(0, nans.numel(), device=hidden_states.device, dtype=torch.int32))
-                #expert_ids = torch.repeat_interleave(expert_ids, hidden_states.shape[1], dim=0)
-                #print(f"EXPERT_IDS {nans.shape} {expert_ids}")
+                expert_ids = torch.arange(0, global_num_experts, device=hidden_states.device, dtype=torch.int)
+                expert_ids = torch.repeat_interleave(expert_ids, max_num_tokens, dim=0)
+                print(f"EXPERT_IDS {expert_ids}")
                 #num_tokens_post_padded = torch.tensor([num_tokens], device=hidden_states.device, dtype=torch.int32)
                 num_tokens_post_padded = torch.zeros(1, device=hidden_states.device, dtype=torch.int32)
-                num_tokens_post_padded.fill_(num_tokens)
+                num_tokens_post_padded.fill_(max_num_tokens)
                 hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
                 #print(f"P = {sorted_token_ids}, {hidden_states.shape}")
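A small standalone sketch of the routing metadata built in the else branch above, using toy sizes (num_experts=4, max_num_tokens=3, hidden=8; the sizes are assumptions for illustration only):

import torch

num_experts = 4      # global_num_experts
max_num_tokens = 3   # hidden_states.shape[1] in the batched layout
hidden = 8

# Batched activations: one block of max_num_tokens (padded) rows per expert.
hidden_states = torch.randn(num_experts, max_num_tokens, hidden)

# Every padded slot gets its own id, in order.
sorted_token_ids = torch.arange(0, hidden_states.shape[0] * max_num_tokens, dtype=torch.int)

# Each contiguous block of max_num_tokens slots is assigned to one expert.
expert_ids = torch.arange(0, num_experts, dtype=torch.int)
expert_ids = torch.repeat_interleave(expert_ids, max_num_tokens, dim=0)

print(sorted_token_ids)  # tensor([ 0,  1, ..., 11], dtype=torch.int32)
print(expert_ids)        # tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], dtype=torch.int32)

# The 3-D activations are then flattened to 2-D for the kernel.
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])  # shape [12, 8]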
@@ -1857,19 +1857,18 @@ def __init__(

     def workspace_shapes(
         self,
-        a_dtype: torch.dtype,
+        a: torch.Tensor,
         M: int,
         N: int,
         K: int,
         topk: int,
         num_experts: int,
-        a: torch.Tensor,
     ) -> Tuple[int, int, torch.dtype]:
         #assert self.max_num_tokens >= a.shape[1]
         max_num_tokens = a.shape[1] if self.max_num_tokens is None else self.max_num_tokens
         workspace13 = num_experts * max_num_tokens * K * topk * 2 # TODO: *2 is a hack
         workspace2 = max_num_tokens * N
-        return (workspace13, workspace2, a_dtype)
+        return (workspace13, workspace2, a.dtype)

     def apply(
         self,
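workspace_shapes now receives the activation tensor itself (here and in the Triton/DeepGemm variants below), so the dtype and, for the batched path, the padded token count can be read off the tensor instead of being passed separately. A toy evaluation of the sizing above, with made-up dimensions and self.max_num_tokens unset:

import torch

num_experts, max_num_tokens, K, N, topk = 8, 16, 128, 256, 2
a = torch.empty(num_experts, max_num_tokens, K)

max_tokens = a.shape[1]  # fallback when no max_num_tokens cap is configured
workspace13 = num_experts * max_tokens * K * topk * 2  # the diff flags the *2 as a hack
workspace2 = max_tokens * N

print(workspace13, workspace2, a.dtype)  # 65536 4096 torch.float32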

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 17 additions & 14 deletions
@@ -249,6 +249,7 @@ def set_dispatch_combine(self, dispatch_combine: FusedMoEQuantizeDispatchCombine
             logger.info(f"BatchedExperts {self.moe}")
             experts = BatchedExperts() #rank=self.moe.ep_rank, world_size=self.moe.ep_size)
         else:
+            logger.info(f"TritonExperts {self.moe}")
             experts = TritonExperts(
                 use_fp8_w8a8 = False,
                 use_int8_w8a16 = False,
@@ -1011,21 +1012,20 @@ def forward(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
         if self.use_direct_call:
             return self.forward_impl(hidden_states, router_logits)
-        else:
+        elif True:
             return torch.ops.vllm.moe_forward(hidden_states, router_logits,
                                               self.layer_name)

     def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
                              full_router_logits: torch.Tensor):

-        max_tokens_across_dp = get_forward_context(
-        ).dp_metadata.max_tokens_across_dp
-        cu_tokens_across_dp_cpu = get_forward_context(
-        ).dp_metadata.cu_tokens_across_dp_cpu
-        num_tokens_across_dp = get_forward_context(
-        ).dp_metadata.num_tokens_across_dp
+        ctx = get_forward_context()
+
+        max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp
+        #cu_tokens_across_dp_cpu = ctx.dp_metadata.cu_tokens_across_dp_cpu
+        num_tokens_across_dp = ctx.dp_metadata.num_tokens_across_dp

-        #print(f"max/num/rank_num = {max_tokens_across_dp}/{num_tokens_across_dp}/{get_forward_context().dp_metadata.dp_rank_num_tokens}")
+        #print(f"max/num/rank_num = {max_tokens_across_dp}/{num_tokens_across_dp}/{ctx.dp_metadata.dp_rank_num_tokens}")

         #In this function we define two ranges:
         # 1. chunk_range - The current iteration of the loops's range over the DP world tokens
@@ -1042,17 +1042,19 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
         #print(f"ORIGINAL SHAPE {full_hidden_states.shape}")
         #print(f"moe_dp_chunk_size_per_rank = {moe_dp_chunk_size_per_rank}")

+        assert full_hidden_states.shape[0] == full_router_logits.shape[0]
+
         for iter in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank):
             hidden_states = full_hidden_states[chunk_start:chunk_end, :]
             router_logits = full_router_logits[chunk_start:chunk_end, :]

-            #print(f"loop {iter}: {chunk_start}:{chunk_end}, {hidden_states.shape}")
-
             cu_tokens_across_dp_this_iter = torch.cumsum(
                 num_tokens_remaining_across_dp.clamp(
                     max=moe_dp_chunk_size_per_rank),
                 dim=0)

+            print(f"loop {iter}: {chunk_start}:{chunk_end}, {hidden_states.shape} {cu_tokens_across_dp_this_iter}")
+
             hidden_states = self.naive_multicast(
                 hidden_states, cu_tokens_across_dp_this_iter)
             router_logits = self.naive_multicast(
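For readers following the chunking arithmetic: the loop walks the DP-wide token range in steps of moe_dp_chunk_size_per_rank, and each iteration's per-rank contribution is the remaining count clamped to the chunk size, whose prefix sum feeds the multicast. A hedged standalone sketch; the chunk_start/chunk_end and remaining-count updates are not visible in this hunk, so the bookkeeping below is an assumption:

import torch

moe_dp_chunk_size_per_rank = 4
max_tokens_across_dp = 10
num_tokens_remaining_across_dp = torch.tensor([10, 7, 3])  # illustrative per-rank counts

chunk_start, chunk_end = 0, moe_dp_chunk_size_per_rank
for it in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank):
    this_iter = num_tokens_remaining_across_dp.clamp(max=moe_dp_chunk_size_per_rank)
    cu_tokens_across_dp_this_iter = torch.cumsum(this_iter, dim=0)
    print(it, chunk_start, chunk_end, cu_tokens_across_dp_this_iter.tolist())

    # Assumed bookkeeping: advance the local window and decrement the remaining counts.
    num_tokens_remaining_across_dp = (num_tokens_remaining_across_dp - this_iter).clamp(min=0)
    chunk_start += moe_dp_chunk_size_per_rank
    chunk_end += moe_dp_chunk_size_per_rank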
@@ -1087,14 +1089,14 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
                     final_hidden_states)
                 final_hidden_states = all_hidden_states[start:end, :]

-            #print(f"final2 (AR) = {final_hidden_states.shape}")
+            print(f"final2 (AR) = {final_hidden_states.shape}")

             if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
                 # Default set to False. (May have to add shared expert outputs.)
                 final_hidden_states = tensor_model_parallel_all_reduce(
                     final_hidden_states)

-            #print(f"final3 (AR) = {final_hidden_states.shape}")
+            print(f"final3 (AR) = {final_hidden_states.shape}")

             full_final_hidden_states[chunk_start:chunk_end, :].copy_(
                 final_hidden_states)
@@ -1128,8 +1130,9 @@ def forward_impl(self, hidden_states: torch.Tensor,
         assert self.quant_method is not None

         if self.dp_size > 1:
-            cu_tokens_across_dp_cpu = get_forward_context(
-            ).dp_metadata.cu_tokens_across_dp_cpu
+            print("FORWARD_IMPL")
+            ctx = get_forward_context()
+            cu_tokens_across_dp_cpu = ctx.dp_metadata.cu_tokens_across_dp_cpu

             hidden_states = self.naive_multicast(hidden_states,
                                                  cu_tokens_across_dp_cpu)
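naive_multicast itself is untouched by this commit; the sketch below only illustrates how a cumulative-token tensor like cu_tokens_across_dp_cpu locates one rank's rows inside a buffer concatenated across DP ranks (the rank index and buffer contents are made up):

import torch

cu_tokens_across_dp_cpu = torch.tensor([5, 8, 15])  # cumsum of per-rank token counts
all_tokens = torch.randn(15, 4)                     # tokens concatenated across 3 DP ranks

dp_rank = 1
start = 0 if dp_rank == 0 else int(cu_tokens_across_dp_cpu[dp_rank - 1])
end = int(cu_tokens_across_dp_cpu[dp_rank])
this_rank_tokens = all_tokens[start:end]            # rows 5..7 belong to rank 1
print(this_rank_tokens.shape)                       # torch.Size([3, 4])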

vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py

Lines changed: 3 additions & 4 deletions
@@ -37,21 +37,20 @@ def __init__(

     def workspace_shapes(
         self,
-        a_dtype: torch.dtype,
+        a: torch.Tensor,
         M: int,
         N: int,
         K: int,
         topk: int,
         num_experts: int,
-        a: torch.Tensor,
     ) -> Tuple[int, int, torch.dtype]:
         # Note: the deep gemm workspaces are strictly larger than the triton
         # workspaces so we can be pessimistic here and allocate for DeepGemm
         # even if we fall back to triton later, e.g. if expert maps are set.
         if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K):
-            return self.deep_gemm_expert.workspace_shapes(a_dtype, M, N, K, topk, num_experts, a)
+            return self.deep_gemm_expert.workspace_shapes(a, M, N, K, topk, num_experts)
         else:
-            return self.triton_expert.workspace_shapes(a_dtype, M, N, K, topk, num_experts, a)
+            return self.triton_expert.workspace_shapes(a, M, N, K, topk, num_experts)

     def apply(
         self,
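A minimal sketch of the delegation pattern after the signature change: the wrapper forwards the activation tensor and each backend reads what it needs (dtype, shape) from it. The backend classes and the shape check here are stand-ins, not the vllm implementations:

import torch
from typing import Tuple

def _valid_shape(M: int, N: int, K: int) -> bool:
    # Stand-in for _valid_deep_gemm_shape; the real check lives in vllm.
    return N % 128 == 0 and K % 128 == 0

class _Backend:
    def __init__(self, scale: int):
        self.scale = scale

    def workspace_shapes(self, a: torch.Tensor, M: int, N: int, K: int,
                         topk: int, num_experts: int) -> Tuple[int, int, torch.dtype]:
        # Toy sizing; only the signature mirrors the diff.
        return (M * topk * max(N * 2, K) * self.scale, M * topk * N * self.scale, a.dtype)

deep_gemm_expert, triton_expert = _Backend(2), _Backend(1)

def workspace_shapes(a, M, N, K, topk, num_experts, allow_deep_gemm=True):
    expert = deep_gemm_expert if (allow_deep_gemm and _valid_shape(M, N, K)) else triton_expert
    return expert.workspace_shapes(a, M, N, K, topk, num_experts)

print(workspace_shapes(torch.empty(16, 128), 16, 256, 128, 2, 8))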

0 commit comments
