Commit 0851b31

Varun Sundar Rabindranath authored and bnellnm committed
wip
Signed-off-by: Bill Nell <[email protected]>
1 parent 054c10a commit 0851b31

File tree

12 files changed: +360 -360 lines changed


examples/offline_inference/data_parallel.py

Lines changed: 7 additions & 3 deletions
@@ -69,11 +69,14 @@ def parse_args():
     parser.add_argument("--enforce-eager",
                         action='store_true',
                         help="Enforce eager mode execution.")
+    parser.add_argument("--trust-remote-code",
+                        action='store_true',
+                        help="Trust remote code.")
     return parser.parse_args()
 
 
 def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
-         dp_master_port, GPUs_per_dp_rank, enforce_eager):
+         dp_master_port, GPUs_per_dp_rank, enforce_eager, trust_remote_code):
     os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
     os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
     os.environ["VLLM_DP_SIZE"] = str(dp_size)
@@ -125,6 +128,7 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
         enforce_eager=enforce_eager,
         enable_expert_parallel=True,
         compilation_config=cconfig,
+        trust_remote_code=trust_remote_code,
     )
     outputs = llm.generate(prompts, sampling_params)
     # Print the outputs.
@@ -168,12 +172,12 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
         proc = Process(target=main,
                        args=(args.model, dp_size, local_dp_rank,
                              global_dp_rank, dp_master_ip, dp_master_port,
-                             tp_size, args.enforce_eager))
+                             tp_size, args.enforce_eager, args.trust_remote_code))
         proc.start()
         procs.append(proc)
     exit_code = 0
     for proc in procs:
-        proc.join(timeout=3000)
+        proc.join(timeout=300)
         if proc.exitcode is None:
             print(f"Killing process {proc.pid} that "
                   f"didn't stop within 5 minutes.")

tests/kernels/moe/test_pplx_moe.py

Lines changed: 5 additions & 4 deletions
@@ -347,8 +347,9 @@ def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts):
         ata,
         max_num_tokens,
         world_size,
-        dp_size,
         rank,
+        dp_size,
+        a.dtype,
     )
 
     a_chunk = chunk_by_rank(a, rank, world_size).to(device)
@@ -486,8 +487,8 @@ def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
         ata,
         max_num_tokens,
         world_size,
-        dp_size,
         rank,
+        dp_size,
     )
 
     experts = BatchedExperts(a.shape[0])
@@ -584,13 +585,13 @@ def _pplx_moe(
     topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
     torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
     pplx_output = pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids)
-    #batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids)
+    batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids)
 
     torch_output = chunk_by_rank(torch_output, pgi.rank,
                                  pgi.world_size).to(pplx_output.device)
 
     torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
-    #torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0)
+    torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0)
 
     nvshmem_finalize()
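
The assertions compare each rank's output against the rank-local slice of the single-process reference. A rough sketch of what a chunk_by_rank-style helper is assumed to do here (the real helper lives in this test file and may differ in detail):

    # Assumed behavior: split dim 0 evenly across ranks, return this rank's slice.
    import torch

    def chunk_by_rank(t: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
        chunk = t.shape[0] // world_size
        return t[rank * chunk:(rank + 1) * chunk]

    # Example: rank 1 of 4 sees rows 2 and 3 of an 8-row tensor.
    x = torch.arange(8).unsqueeze(1)
    print(chunk_by_rank(x, rank=1, world_size=4))  # tensor([[2], [3]])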

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 29 additions & 197 deletions
@@ -587,6 +587,8 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
 
     def __init__(
         self,
+        world_size: int,
+        dp_size: int,
         max_num_tokens: Optional[int] = None,
         use_fp8_w8a8: bool = False,
         use_int8_w8a8: bool = False,
@@ -603,6 +605,8 @@ def __init__(
         assert not use_int8_w8a16, "NYI"
         assert not use_int4_w4a16, "NYI"
         self.max_num_tokens = max_num_tokens
+        self.world_size = world_size
+        self.dp_size = dp_size
 
     def workspace_shapes(
         self,
@@ -614,10 +618,12 @@
         num_experts: int,
     ) -> Tuple[int, int, torch.dtype]:
         assert a.dim() == 2
+        num_dp = self.world_size // self.dp_size
         max_num_tokens = a.shape[
             0] if self.max_num_tokens is None else self.max_num_tokens
-        workspace13 = num_experts * max_num_tokens * K
-        workspace2 = max_num_tokens * N
+        #print(f"WORKSPACE {max_num_tokens} {num_dp}")
+        workspace13 = num_experts * max_num_tokens * num_dp * K
+        workspace2 = max_num_tokens * num_dp * N
         return (workspace13, workspace2, a.dtype)
 
     def apply(
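
The recurring change in this file: the workspace sizes (and, in the hunk below, the per-expert output buffer) gain an extra num_dp = world_size // dp_size factor so the batched buffers have room for the larger per-expert token count. An arithmetic sketch with illustrative numbers only (none of these values come from this commit):

    # Illustrative only: how the workspace element counts scale with num_dp.
    num_experts = 8          # local experts
    max_num_tokens = 128     # per-rank token budget
    K, N = 4096, 14336       # example hidden / intermediate sizes
    world_size, dp_size = 4, 2

    num_dp = world_size // dp_size                              # 2
    workspace13 = num_experts * max_num_tokens * num_dp * K     # was ... * K
    workspace2 = max_num_tokens * num_dp * N                    # was ... * N
    print(num_dp, workspace13, workspace2)                      # 2 8388608 3670016
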
@@ -648,23 +654,24 @@ def apply(
         else:
             max_num_tokens = self.max_num_tokens
 
+        num_dp = self.world_size // self.dp_size
         num_experts = global_num_experts
         out = _resize_cache(workspace13,
-                            (num_experts, max_num_tokens, hidden_dim))
+                            (num_experts, max_num_tokens * num_dp, hidden_dim))
         num_local_experts = w1.shape[0] #expert_num_tokens.numel()
         assert num_local_experts == w1.shape[0], f"{num_local_experts} == {w1.shape[0]}"
 
         N = w1.shape[1] // 2
 
         # Not cudagraph friendly
-        assert (torch.cuda.is_current_stream_capturing() or
-                torch.all(expert_num_tokens <= max_num_tokens)), (
-            f"{expert_num_tokens} <= {max_num_tokens}")
+        # assert (torch.cuda.is_current_stream_capturing() or
+        #         torch.all(expert_num_tokens <= max_num_tokens)), (
+        #     f"{expert_num_tokens} <= {max_num_tokens}")
 
         for expert in range(num_local_experts):
             # Indexing expert_num_tokens doesn't work w/cudagraphs
-            if torch.cuda.is_current_stream_capturing():
-                num = max_num_tokens
+            if True or torch.cuda.is_current_stream_capturing():
+                num = max_num_tokens * num_dp
             else:
                 num = int(expert_num_tokens[expert].item())
             tmp = _resize_cache(workspace2, (num, N))
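
For reference, each iteration of the loop above pushes one local expert's padded token block through the usual gate/up projection, activation, and down projection. A hedged plain-PyTorch sketch of what one iteration computes, ignoring quantization, workspace reuse, and cudagraph concerns (shapes follow the w1.shape[1] // 2 convention visible above):

    import torch
    import torch.nn.functional as F

    def one_expert(tokens: torch.Tensor, w1_e: torch.Tensor,
                   w2_e: torch.Tensor) -> torch.Tensor:
        # tokens: (num, K); w1_e: (2 * N, K); w2_e: (K, N)
        gate_up = tokens @ w1_e.t()                        # (num, 2 * N)
        N = gate_up.shape[-1] // 2
        act = F.silu(gate_up[..., :N]) * gate_up[..., N:]  # (num, N)
        return act @ w2_e.t()                              # (num, K)

    # Example shapes: 4 tokens, K=8 hidden, N=16 intermediate.
    out = one_expert(torch.randn(4, 8), torch.randn(32, 8), torch.randn(8, 16))
    print(out.shape)  # torch.Size([4, 8])
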
@@ -675,166 +682,6 @@ def apply(
         return out
 
 
-def _apply(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    topk_ids: torch.Tensor,
-    activation: str,
-    global_num_experts: int,
-    expert_map: Optional[torch.Tensor],
-    w1_scale: Optional[torch.Tensor],
-    w2_scale: Optional[torch.Tensor],
-    w1_zp: Optional[torch.Tensor],
-    w2_zp: Optional[torch.Tensor],
-    a1q_scale: Optional[torch.Tensor],
-    a2_scale: Optional[torch.Tensor],
-    workspace13: torch.Tensor,
-    workspace2: torch.Tensor,
-    expert_num_tokens: Optional[torch.Tensor],
-    use_fp8_w8a8: bool,
-    use_int8_w8a16: bool,
-    use_int4_w4a16: bool,
-    block_shape: Optional[List[int]],
-) -> torch.Tensor:
-    # Check constraints.
-    if use_int4_w4a16:
-        assert hidden_states.shape[-1] // 2 == w1.shape[
-            2], "Hidden size mismatch"
-    else:
-        assert hidden_states.shape[-1] == w1.shape[2], \
-            (f"Hidden size mismatch {hidden_states.shape[-1]} "
-             f"!= {w1.shape[2]}")
-
-    assert hidden_states.is_contiguous(
-    ), "Hidden_states must be contiguous"
-    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
-    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
-    assert hidden_states.dtype in [
-        torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
-    ]
-
-    # TODO: num_tokens -> max_num_tokens?
-    E, num_tokens, N, K, top_k_num = mk._moe_problem_size(
-        hidden_states, w1, w2, topk_ids)
-
-    assert w1.shape[0] == E
-    assert w2.shape[0] == E
-
-    config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
-                                        use_int8_w8a16=use_int8_w8a16,
-                                        use_int4_w4a16=use_int4_w4a16,
-                                        dtype=hidden_states.dtype)
-
-    config = try_get_optimal_moe_config(
-        w1.shape,
-        w2.shape,
-        top_k_num,
-        config_dtype,
-        num_tokens,
-        block_shape=block_shape,
-    )
-
-    if hidden_states.dtype == torch.bfloat16:
-        compute_type = tl.bfloat16
-    elif hidden_states.dtype == torch.float16:
-        compute_type = tl.float16
-    elif hidden_states.dtype == torch.float32:
-        compute_type = tl.float32
-    elif hidden_states.dtype == torch.float8_e4m3fn:
-        compute_type = tl.bfloat16
-    else:
-        raise ValueError(
-            f"Unsupported compute_type: {hidden_states.dtype}")
-
-    #print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}")
-    # We can reuse the memory between these because by the time we need
-    # cache3, we're done with cache1
-    intermediate_cache1 = _resize_cache(workspace13, (E, num_tokens, N))
-    intermediate_cache2 = _resize_cache(workspace2,
-                                        (E, num_tokens, N // 2))
-    intermediate_cache3 = _resize_cache(workspace13, (E, num_tokens, K))
-
-    # MM1
-    invoke_moe_batched_triton_kernel(A=hidden_states,
-                                     B=w1,
-                                     C=intermediate_cache1,
-                                     expert_num_tokens=expert_num_tokens,
-                                     compute_type=compute_type,
-                                     A_scale=a1q_scale,
-                                     B_scale=w1_scale,
-                                     B_zp=w1_zp,
-                                     use_fp8_w8a8=use_fp8_w8a8,
-                                     use_int8_w8a16=use_int8_w8a16,
-                                     use_int4_w4a16=use_int4_w4a16,
-                                     config=config,
-                                     block_shape=block_shape)
-
-    # Fix activations
-    assert activation == "silu"
-    invoke_batched_silu_and_mul(output=intermediate_cache2,
-                                input=intermediate_cache1,
-                                expert_num_tokens=expert_num_tokens)
-
-    #qintermediate_cache2 = intermediate_cache2
-    a2q_scale = a2_scale
-    # TODO (varun) : support w8a8
-    assert not use_fp8_w8a8
-    #if self.use_fp8_w8a8:
-    #    qintermediate_cache2, a2q_scale = _fp8_quantize(
-    #        intermediate_cache2, a2_scale, self.block_shape)
-
-    invoke_moe_batched_triton_kernel(A=intermediate_cache2,
-                                     B=w2,
-                                     C=intermediate_cache3,
-                                     expert_num_tokens=expert_num_tokens,
-                                     compute_type=compute_type,
-                                     A_scale=a2q_scale,
-                                     B_scale=w2_scale,
-                                     B_zp=w2_zp,
-                                     use_fp8_w8a8=use_fp8_w8a8,
-                                     use_int8_w8a16=use_int8_w8a16,
-                                     use_int4_w4a16=use_int4_w4a16,
-                                     config=config,
-                                     block_shape=block_shape)
-
-    return intermediate_cache3
-
-
-def _apply_fake(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    topk_ids: torch.Tensor,
-    activation: str,
-    global_num_experts: int,
-    expert_map: Optional[torch.Tensor],
-    w1_scale: Optional[torch.Tensor],
-    w2_scale: Optional[torch.Tensor],
-    w1_zp: Optional[torch.Tensor],
-    w2_zp: Optional[torch.Tensor],
-    a1q_scale: Optional[torch.Tensor],
-    a2_scale: Optional[torch.Tensor],
-    workspace13: torch.Tensor,
-    workspace2: torch.Tensor,
-    expert_num_tokens: Optional[torch.Tensor],
-    use_fp8_w8a8: bool,
-    use_int8_w8a16: bool,
-    use_int4_w4a16: bool,
-    block_shape: Optional[List[int]],
-) -> torch.Tensor:
-    return torch.empty_like(hidden_states)
-
-
-direct_register_custom_op(
-    op_name="_apply",
-    op_func=_apply,
-    mutates_args=[],
-    fake_impl=_apply_fake,
-    tags=(torch.Tag.needs_fixed_stride_order, ),
-)
-
-
 class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
 
     def __init__(
@@ -845,6 +692,8 @@ def __init__(
         use_int8_w8a16: bool = False,
         use_int4_w4a16: bool = False,
         block_shape: Optional[List[int]] = None,
+        world_size: int = 1,
+        dp_size: int = 1,
     ):
         super().__init__()
         self.use_fp8_w8a8 = use_fp8_w8a8
@@ -855,6 +704,8 @@ def __init__(
         self.max_num_tokens = max_num_tokens
         assert not use_int8_w8a8, "NYI"
         assert not use_int4_w4a16, "NYI"
+        self.world_size = world_size
+        self.dp_size = dp_size
 
     def workspace_shapes(
         self,
@@ -866,10 +717,11 @@
         num_experts: int,
     ) -> Tuple[int, int, torch.dtype]:
         assert a.dim() == 2
+        num_dp = self.world_size // self.dp_size
         max_num_tokens = a.shape[
             0] if self.max_num_tokens is None else self.max_num_tokens
-        workspace13 = num_experts * max_num_tokens * max(K, N)
-        workspace2 = num_experts * max_num_tokens * (N // 2)
+        workspace13 = num_experts * max_num_tokens * num_dp * max(K, N)
+        workspace2 = num_experts * max_num_tokens * num_dp * (N // 2)
         return (workspace13, workspace2, a.dtype)
 
     def apply(
@@ -891,29 +743,6 @@ def apply(
         workspace2: torch.Tensor,
         expert_num_tokens: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        return torch.ops.vllm._apply(
-            hidden_states,
-            w1,
-            w2,
-            topk_ids,
-            activation,
-            global_num_experts,
-            expert_map,
-            w1_scale,
-            w2_scale,
-            w1_zp,
-            w2_zp,
-            a1q_scale,
-            a2_scale,
-            workspace13,
-            workspace2,
-            expert_num_tokens,
-            self.use_fp8_w8a8,
-            self.use_int8_w8a16,
-            self.use_int4_w4a16,
-            self.block_shape,
-        )
-
         # Check constraints.
         if self.use_int4_w4a16:
             assert hidden_states.shape[-1] // 2 == w1.shape[
@@ -988,10 +817,13 @@ def apply(
                                          block_shape=self.block_shape)
 
         # Fix activations
-        assert activation == "silu"
-        invoke_batched_silu_and_mul(output=intermediate_cache2,
-                                    input=intermediate_cache1,
-                                    expert_num_tokens=expert_num_tokens)
+        # assert activation == "silu"
+        # invoke_batched_silu_and_mul(output=intermediate_cache2,
+        #                             input=intermediate_cache1,
+        #                             expert_num_tokens=expert_num_tokens)
+        self.activation(activation,
+                        intermediate_cache2.view(-1, N//2),
+                        intermediate_cache1.view(-1, N))
 
         #qintermediate_cache2 = intermediate_cache2
         a2q_scale = a2_scale
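
For reference, the self.activation(...) call above reduces the (-1, N) intermediate to (-1, N // 2): the input is split in half along the last dimension, the first half is gated by the named activation and multiplied with the second half. A plain-PyTorch sketch of the "silu" case (this mirrors vLLM's SiluAndMul semantics; it is a reference sketch, not the kernel used here):

    import torch
    import torch.nn.functional as F

    def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    x = torch.randn(6, 32)
    print(silu_and_mul(x).shape)  # torch.Size([6, 16])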
