Commit 9b97c83

review comments, only initialize pplx if EP is enabled
Signed-off-by: Bill Nell <[email protected]>
1 parent b5be324 commit 9b97c83

13 files changed: +48 -83 lines changed

vllm/distributed/parallel_state.py

Lines changed: 6 additions & 2 deletions
@@ -979,6 +979,7 @@ def pplx_finalize():
 def initialize_model_parallel(
     tensor_model_parallel_size: int = 1,
     pipeline_model_parallel_size: int = 1,
+    enable_expert_parallel: bool = False,
     backend: Optional[str] = None,
 ) -> None:
     """
@@ -1081,12 +1082,14 @@ def initialize_model_parallel
         _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group,
         _EP.rank_in_group)
 
-    pplx_init(rank, world_size)
+    if enable_expert_parallel:
+        pplx_init(rank, world_size)
 
 
 def ensure_model_parallel_initialized(
     tensor_model_parallel_size: int,
     pipeline_model_parallel_size: int,
+    enable_expert_parallel: bool = False,
     backend: Optional[str] = None,
 ) -> None:
     """Helper to initialize model parallel groups if they are not initialized,
@@ -1097,7 +1100,8 @@ def ensure_model_parallel_initialized
         get_world_group().device_group)
     if not model_parallel_is_initialized():
         initialize_model_parallel(tensor_model_parallel_size,
-                                  pipeline_model_parallel_size, backend)
+                                  pipeline_model_parallel_size,
+                                  enable_expert_parallel, backend)
         return
 
     assert (
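
For reference, a minimal sketch of how a caller opts into the pplx initialization through the new flag; the argument values are illustrative, and the surrounding distributed setup is elided:

    from vllm.distributed.parallel_state import ensure_model_parallel_initialized

    # Illustrative values; real workers pass the sizes from parallel_config
    # (see the worker diffs below). With enable_expert_parallel=False (the
    # default), pplx_init() is skipped entirely.
    ensure_model_parallel_initialized(
        tensor_model_parallel_size=2,
        pipeline_model_parallel_size=1,
        enable_expert_parallel=True,
    )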

vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 12 additions & 30 deletions
@@ -175,29 +175,6 @@ def apply(
         return c3
 
 
-def modular_cutlass_moe_fp8(
-    per_act_token: bool,
-    ab_strides1: torch.Tensor,
-    c_strides1: torch.Tensor,
-    ab_strides2: torch.Tensor,
-    c_strides2: torch.Tensor,
-    out_dtype: torch.dtype = torch.half,
-) -> mk.FusedMoEModularKernel:
-    return mk.FusedMoEModularKernel(
-        StandardPrepareAndFinalize(
-            per_channel_quant=per_act_token,
-            quant_dtype=torch.float8_e4m3fn,
-        ),
-        CutlassExpertsFp8(
-            ab_strides1,
-            c_strides1,
-            ab_strides2,
-            c_strides2,
-            out_dtype,
-        ),
-    )
-
-
 #TODO make the grouped gemm kernel consistent with scaled gemm kernel
 def cutlass_moe_fp8(
     a: torch.Tensor,
@@ -263,13 +240,18 @@ def cutlass_moe_fp8(
     per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
         a2_scale.numel() != 1 if a2_scale is not None else False)
 
-    fn = modular_cutlass_moe_fp8(
-        per_act_token,
-        ab_strides1,
-        c_strides1,
-        ab_strides2,
-        c_strides2,
-        out_dtype,
+    fn = mk.FusedMoEModularKernel(
+        StandardPrepareAndFinalize(
+            per_channel_quant=per_act_token,
+            quant_dtype=torch.float8_e4m3fn,
+        ),
+        CutlassExpertsFp8(
+            ab_strides1,
+            c_strides1,
+            ab_strides2,
+            c_strides2,
+            out_dtype,
+        ),
     )
 
     return fn(
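
As a small illustration of the per_act_token detection above: a per-tensor scale has a single element, while a per-activation-token scale has one element per row. The shapes here are illustrative only:

    import torch

    a2_scale = None

    a1_scale = torch.tensor(0.5)  # per-tensor scale: numel() == 1
    per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
        a2_scale.numel() != 1 if a2_scale is not None else False)
    assert not per_act_token

    a1_scale = torch.rand(128, 1)  # one scale per activation token
    per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
        a2_scale.numel() != 1 if a2_scale is not None else False)
    assert per_act_token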

vllm/model_executor/layers/fused_moe/deep_gemm_moe.py

Lines changed: 5 additions & 9 deletions
@@ -151,14 +151,6 @@ def apply(
         return workspace3
 
 
-def modular_deep_gemm_fused_moe_fp8() -> mk.FusedMoEModularKernel:
-    return mk.FusedMoEModularKernel(
-        StandardPrepareAndFinalize(quant_dtype=torch.float8_e4m3fn,
-                                    block_shape=deep_gemm_block_shape()),
-        DeepGemmExperts(),
-    )
-
-
 def deep_gemm_moe_fp8(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
@@ -212,7 +204,11 @@ def deep_gemm_moe_fp8(
     Returns:
     - torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
     """
-    fn = modular_deep_gemm_fused_moe_fp8()
+    fn = mk.FusedMoEModularKernel(
+        StandardPrepareAndFinalize(quant_dtype=torch.float8_e4m3fn,
+                                    block_shape=deep_gemm_block_shape()),
+        DeepGemmExperts(),
+    )
     return fn(
         hidden_states,
         w1,

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 3 additions & 8 deletions
@@ -381,11 +381,6 @@ def invoke_moe_batched_triton_kernel(
                                      BLOCK_K=BLOCK_K)
 
 
-def rank_chunk(num, r, w):
-    rem = num % w
-    return (num // w) + (1 if r < rem else 0)
-
-
 class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
     """
     A reference prepare/finalize class that reorganizes the tokens into
@@ -475,12 +470,12 @@ def finalize(
         last_expert = first_expert + num_local_experts
 
        for expert_id in range(first_expert, last_expert):
-            topkws = topk_ids == expert_id
-            topks = torch.any(topkws, dim=1).flatten()
+            matching_tokens = topk_ids == expert_id
+            topks = torch.any(matching_tokens, dim=1).flatten()
             rows = torch.count_nonzero(topks)
             rhs = fused_expert_output[expert_id - first_expert, :rows, :]
             if not apply_router_weight_on_input:
-                rhs.mul_(topk_weights[topkws].view(rhs.size(0), 1))
+                rhs.mul_(topk_weights[matching_tokens].view(rhs.size(0), 1))
             output[topks] = output[topks] + rhs
 
 
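A tiny worked example of the masking used in finalize() above, with four tokens routed top-2 across four experts (values are illustrative):

    import torch

    topk_ids = torch.tensor([[0, 2],
                             [1, 3],
                             [0, 1],
                             [2, 3]])
    expert_id = 0
    matching_tokens = topk_ids == expert_id              # bool mask, shape (4, 2)
    topks = torch.any(matching_tokens, dim=1).flatten()  # tensor([True, False, True, False])
    rows = torch.count_nonzero(topks)                    # 2 rows of this expert's output are used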

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -979,7 +979,7 @@ def get_config_dtype_str(
979979
return None
980980

981981

982-
# TODO: use scalar_type instead of bools?
982+
# TODO (bnell): use scalar_type instead of bools?
983983
def get_config_qtype(
984984
use_fp8_w8a8: bool,
985985
use_int8_w8a8: bool,
@@ -1585,6 +1585,7 @@ def apply(
15851585

15861586
assert hidden_states.is_contiguous(
15871587
), "Hidden_states must be contiguous"
1588+
assert hidden_states.dim() == 2
15881589
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
15891590
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
15901591
assert hidden_states.dtype in [
@@ -1632,30 +1633,9 @@ def apply(
16321633
intermediate_cache3 = _resize_cache(workspace13,
16331634
(num_tokens, top_k_num, K))
16341635

1635-
if hidden_states.dim() == 2: #block_m is None:
1636-
sorted_token_ids, expert_ids, num_tokens_post_padded = (
1637-
moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'],
1638-
global_num_experts, expert_map))
1639-
else:
1640-
max_num_tokens = hidden_states.size(1)
1641-
sorted_token_ids = torch.arange(0,
1642-
hidden_states.size(0) *
1643-
max_num_tokens,
1644-
device=hidden_states.device,
1645-
dtype=torch.int)
1646-
sorted_token_ids = sorted_token_ids.flatten()
1647-
expert_ids = torch.arange(0,
1648-
global_num_experts,
1649-
device=hidden_states.device,
1650-
dtype=torch.int)
1651-
expert_ids = torch.repeat_interleave(expert_ids,
1652-
max_num_tokens,
1653-
dim=0)
1654-
num_tokens_post_padded = torch.zeros(1,
1655-
device=hidden_states.device,
1656-
dtype=torch.int32)
1657-
num_tokens_post_padded.fill_(max_num_tokens)
1658-
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
1636+
sorted_token_ids, expert_ids, num_tokens_post_padded = (
1637+
moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'],
1638+
global_num_experts, expert_map))
16591639

16601640
invoke_fused_moe_kernel(hidden_states,
16611641
w1,
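
A minimal sketch of the activation shape the assert above enforces; the variable names here are hypothetical:

    # hidden_states must be token-major 2-D: (num_tokens, hidden_size).
    # A batched (num_experts_or_ranks, max_num_tokens, hidden_size) layout
    # would need to be flattened before reaching apply(), e.g.:
    hidden_states_2d = hidden_states_3d.view(-1, hidden_states_3d.size(-1))
    assert hidden_states_2d.dim() == 2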

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 1 addition & 1 deletion
@@ -687,7 +687,7 @@ def _construct_prepare_finalize(
     rank = moe.ep_rank
 
     if moe.use_pplx_kernels:
-        logger.debug("using pplx dispatch")
+        logger.debug("using PplxPrepareAndFinalize")
 
         all_to_all = get_all_to_all(
             max_num_tokens=max_num_tokens,

vllm/v1/worker/gpu_worker.py

Lines changed: 2 additions & 1 deletion
@@ -341,7 +341,8 @@ def init_worker_distributed_environment(
                                  distributed_init_method, local_rank)
 
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
-                                      parallel_config.pipeline_parallel_size)
+                                      parallel_config.pipeline_parallel_size,
+                                      parallel_config.enable_expert_parallel)
 
     ensure_kv_transfer_initialized(vllm_config)
 
vllm/v1/worker/tpu_worker.py

Lines changed: 2 additions & 1 deletion
@@ -265,4 +265,5 @@ def init_tpu_worker_distributed_environment(
         backend="gloo",
     )
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
-                                      parallel_config.pipeline_parallel_size)
+                                      parallel_config.pipeline_parallel_size,
+                                      parallel_config.enable_expert_parallel)

vllm/worker/cpu_worker.py

Lines changed: 2 additions & 1 deletion
@@ -390,7 +390,8 @@ def init_distributed_environment(self) -> None:
 
         ensure_model_parallel_initialized(
             parallel_config.tensor_parallel_size,
-            parallel_config.pipeline_parallel_size)
+            parallel_config.pipeline_parallel_size,
+            parallel_config.enable_expert_parallel)
 
     def get_cache_block_size_bytes(self) -> int:
         """Return the size in bytes of a single KV cache block.

vllm/worker/hpu_worker.py

Lines changed: 4 additions & 2 deletions
@@ -416,7 +416,8 @@ def init_worker_distributed_environment(
                                  backend='hccl')
 
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
-                                      parallel_config.pipeline_parallel_size)
+                                      parallel_config.pipeline_parallel_size,
+                                      parallel_config.enable_expert_parallel)
 
     if torch.distributed.is_initialized():
         torch_world_size = torch.distributed.get_world_size()
@@ -442,7 +443,8 @@ def init_worker_distributed_environment(
         torch.distributed.all_reduce(dummy_tensor_hpu)
         assert dummy_tensor_hpu.item() == parallel_config.world_size
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
-                                      parallel_config.pipeline_parallel_size)
+                                      parallel_config.pipeline_parallel_size,
+                                      parallel_config.enable_expert_parallel)
 
 
 def raise_if_cache_size_invalid(num_gpu_blocks, block_size, max_model_len,
