
Commit 2ecc0d6

tjtanaavllm authored and committed
ck moe 2 stage: cherry pick 612c2ed
Signed-off-by: tjtanaavllm <tunjian.tan@amd.com>
1 parent 021ebeb commit 2ecc0d6

7 files changed, +174 -238 lines changed


tests/kernels/moe/test_moe.py

Lines changed: 7 additions & 0 deletions
@@ -222,9 +222,16 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
     """Make sure our Mixtral MoE implementation agrees with the one from
     huggingface."""

+    # clear the cache before every test
+    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+        is_rocm_aiter_moe_enabled)
+    is_rocm_aiter_moe_enabled.cache_clear()
     if use_rocm_aiter:
         monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

+        if dtype == torch.float32:
+            pytest.skip("AITER ROCm test skip for float32")
+
     # Instantiate our and huggingface's MoE blocks
     config = MixtralConfig()
     hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
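
The added lines clear the functools.cache wrapper on is_rocm_aiter_moe_enabled before each parametrization, so a value memoized under an earlier environment does not leak into the next test. A minimal sketch of why the clear is needed (the use_aiter helper below is illustrative, not vLLM code):

# Illustrative sketch: a @cache-decorated env check memoizes its first result,
# so monkeypatched env changes are invisible until cache_clear() is called.
import os
from functools import cache


@cache
def use_aiter() -> bool:  # stand-in for is_rocm_aiter_moe_enabled
    return os.environ.get("VLLM_ROCM_USE_AITER", "0") == "1"


os.environ["VLLM_ROCM_USE_AITER"] = "0"
assert use_aiter() is False      # first call: memoized as False

os.environ["VLLM_ROCM_USE_AITER"] = "1"
assert use_aiter() is False      # stale cached value is still returned

use_aiter.cache_clear()          # what the test does before each run
assert use_aiter() is True       # re-evaluated against the new env value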

tests/model_executor/test_enabled_custom_ops.py

Lines changed: 2 additions & 22 deletions
@@ -8,10 +8,8 @@
 from vllm.model_executor.layers.activation import (GeluAndMul,
                                                    ReLUSquaredActivation,
                                                    SiluAndMul)
-from vllm.model_executor.layers.fused_moe.fused_moe import (
-    dispatch_fused_experts_func, dispatch_topk_func,
-    torch_vllm_inplace_fused_experts, torch_vllm_outplace_fused_experts,
-    vllm_topk_softmax)
+from vllm.model_executor.layers.fused_moe.fused_moe import (dispatch_topk_func,
+                                                            vllm_topk_softmax)
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
     is_rocm_aiter_moe_enabled)
 from vllm.model_executor.layers.layernorm import (
@@ -142,24 +140,6 @@ def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
     assert topk_func == vllm_topk_softmax


-@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
-@pytest.mark.parametrize("inplace", [True, False])
-def test_fused_experts_dispatch(use_rocm_aiter: str, inplace: bool,
-                                monkeypatch):
-
-    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
-    is_rocm_aiter_moe_enabled.cache_clear()
-    fused_experts_func = dispatch_fused_experts_func(inplace)
-    if current_platform.is_rocm() and int(use_rocm_aiter):
-        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-            rocm_aiter_fused_experts)
-        assert fused_experts_func == rocm_aiter_fused_experts
-    elif inplace:
-        assert fused_experts_func == torch_vllm_inplace_fused_experts
-    else:
-        assert fused_experts_func == torch_vllm_outplace_fused_experts
-
-
 @pytest.mark.parametrize("add_residual", [True, False])
 @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
 @pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 0 additions & 3 deletions
@@ -1098,9 +1098,6 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:


 def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
-    if is_rocm_aiter_moe_enabled():
-        from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
-        return rocm_aiter_fused_experts
     if inplace:
         return torch_vllm_inplace_fused_experts
     return torch_vllm_outplace_fused_experts
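
With the ROCm branch removed, dispatch_fused_experts_func is a pure in-place/out-of-place switch; the AITER kernel is now chosen by the MoE layer itself. A hedged sketch of the behaviour the deleted test used to assert for the non-AITER path:

# Sketch (non-AITER path only): the dispatcher now just picks between the
# in-place and out-of-place Triton implementations.
from vllm.model_executor.layers.fused_moe.fused_moe import (
    dispatch_fused_experts_func, torch_vllm_inplace_fused_experts,
    torch_vllm_outplace_fused_experts)

assert dispatch_fused_experts_func(inplace=True) == torch_vllm_inplace_fused_experts
assert dispatch_fused_experts_func(inplace=False) == torch_vllm_outplace_fused_experts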

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 15 additions & 14 deletions
@@ -86,6 +86,16 @@ def apply(
 class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""

+    def __init__(self):
+        super().__init__()
+
+        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
+        if self.rocm_aiter_moe_enabled:
+            from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
+            self.rocm_aiter_fused_experts = rocm_aiter_fused_experts
+        else:
+            self.rocm_aiter_fused_experts = None  # type: ignore
+
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
@@ -128,18 +138,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
         # Lazy import to avoid importing triton.
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-            is_rocm_aiter_2stage_moe_enabled, is_rocm_aiter_moe_enabled,
             shuffle_weights)
-        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
-        self.rocm_aiter_2stage_moe_enabled = is_rocm_aiter_2stage_moe_enabled()
-        if self.rocm_aiter_moe_enabled:
-            # reshaping weights is required for aiter moe kernel.
-            layout = (32, 32) if self.rocm_aiter_2stage_moe_enabled else (16,
-                                                                          16)

+        if self.rocm_aiter_moe_enabled:
+            # use 2stage ck moe layout
             shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight.data,
                                                         layer.w2_weight.data,
-                                                        layout=layout)
+                                                        layout=(32, 32))

             layer.w13_weight.data = shuffled_w13
             layer.w2_weight.data = shuffled_w2
@@ -221,18 +226,14 @@ def forward_cuda(
             e_score_correction_bias=e_score_correction_bias)

         if self.rocm_aiter_moe_enabled:
-            return rocm_aiter_fused_experts(
+            return self.rocm_aiter_fused_experts(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
-                inplace=True,
                 activation=activation,
-                apply_router_weight_on_input=apply_router_weight_on_input,
-                global_num_experts=global_num_experts,
-                expert_map=expert_map,
-                use_ck_moe_2stages=self.rocm_aiter_2stage_moe_enabled)
+                apply_router_weight_on_input=apply_router_weight_on_input)

         return fused_experts(
             hidden_states=x,
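
The AITER decision now happens once, in UnquantizedFusedMoEMethod.__init__, and forward_cuda simply calls whichever kernel was bound there. A minimal, self-contained sketch of that construction-time binding pattern (class and helper names below are illustrative, not the vLLM implementation):

# Illustrative pattern: resolve the expert kernel once at construction time
# instead of re-dispatching on every forward call.
from typing import Callable

import torch


def _aiter_experts(x: torch.Tensor) -> torch.Tensor:   # stand-in for rocm_aiter_fused_experts
    return x


def _triton_experts(x: torch.Tensor) -> torch.Tensor:  # stand-in for fused_experts
    return x


class MoEMethodSketch:

    def __init__(self, aiter_enabled: bool) -> None:
        self.rocm_aiter_moe_enabled = aiter_enabled
        # Bind the callable once; forward() never re-checks the environment.
        self._experts: Callable[[torch.Tensor], torch.Tensor] = (
            _aiter_experts if aiter_enabled else _triton_experts)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._experts(x)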

vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py

Lines changed: 29 additions & 82 deletions
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 from functools import cache
-from typing import List, Optional, Tuple
+from typing import Optional

 import torch

@@ -16,13 +16,6 @@ def is_rocm_aiter_moe_enabled() -> bool:
         and envs.VLLM_ROCM_USE_AITER


-def is_rocm_aiter_2stage_moe_enabled() -> bool:
-    return current_platform.is_rocm() \
-        and envs.VLLM_ROCM_USE_AITER_2STAGE_MOE \
-        and envs.VLLM_ROCM_USE_AITER_MOE \
-        and envs.VLLM_ROCM_USE_AITER
-
-
 def rocm_aiter_asm_moe_tkw1_impl(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
@@ -76,23 +69,6 @@ def rocm_aiter_asm_moe_tkw1_fake(
     return torch.empty_like(hidden_states)


-def rocm_aiter_ck_moe_impl(hidden_states: torch.Tensor, w1: torch.Tensor,
-                           w2: torch.Tensor, topk_weights: torch.Tensor,
-                           topk_ids: torch.Tensor) -> torch.Tensor:
-    from aiter import ck_moe
-    return ck_moe(hidden_states=hidden_states,
-                  w1=w1,
-                  w2=w2,
-                  topk_weights=topk_weights,
-                  topk_ids=topk_ids)
-
-
-def rocm_aiter_ck_moe_fake(hidden_states: torch.Tensor, w1: torch.Tensor,
-                           w2: torch.Tensor, topk_weights: torch.Tensor,
-                           topk_ids: torch.Tensor) -> torch.Tensor:
-    return torch.empty_like(hidden_states)
-
-
 def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl(
     topk_ids: torch.Tensor,
     topk_weights: torch.Tensor,
@@ -215,10 +191,9 @@ def rocm_aiter_ck_moe_2stages_impl(
     fc2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
-    block_size: Optional[List[int]] = None,
+    block_size: Optional[list[int]] = None,
     expert_mask: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-
     from aiter.fused_moe_bf16_asm import ck_moe_2stages
     return ck_moe_2stages(a1=hidden_states,
                           w1=w1,
@@ -243,7 +218,7 @@ def rocm_aiter_ck_moe_2stages_fake(
     fc2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
-    block_size: Optional[List[int]] = None,
+    block_size: Optional[list[int]] = None,
     expert_mask: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     return torch.empty_like(hidden_states)
@@ -308,14 +283,6 @@ def rocm_aiter_biased_grouped_topk_fake(
     dispatch_key=current_platform.dispatch_key,
 )

-direct_register_custom_op(
-    op_name="rocm_aiter_ck_moe",
-    op_func=rocm_aiter_ck_moe_impl,
-    mutates_args=[],
-    fake_impl=rocm_aiter_ck_moe_fake,
-    dispatch_key=current_platform.dispatch_key,
-)
-
 direct_register_custom_op(
     op_name="rocm_aiter_fmoe_fp8_blockscale_g1u1",
     op_func=rocm_aiter_fmoe_fp8_blockscale_g1u1_impl,
@@ -390,31 +357,20 @@ def rocm_aiter_biased_group_topk(


 def rocm_aiter_fused_experts(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    inplace: bool = False,
-    activation: str = "silu",
-    apply_router_weight_on_input: bool = False,
-    use_fp8_w8a8: bool = False,
-    use_int8_w8a8: bool = False,
-    use_int8_w8a16: bool = False,
-    use_int4_w4a16: bool = False,
-    per_channel_quant: bool = False,
-    global_num_experts: int = -1,
-    expert_map: Optional[torch.Tensor] = None,
-    w1_scale: Optional[torch.Tensor] = None,
-    w2_scale: Optional[torch.Tensor] = None,
-    w1_zp: Optional[torch.Tensor] = None,
-    w2_zp: Optional[torch.Tensor] = None,
-    a1_scale: Optional[torch.Tensor] = None,
-    a2_scale: Optional[torch.Tensor] = None,
-    block_shape: Optional[List[int]] = None,
-    allow_deep_gemm: bool = False,
-    use_ck_moe_2stages: bool = False,
-) -> torch.Tensor:
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        use_fp8_w8a8: bool = False,
+        per_channel_quant: bool = False,
+        w1_scale: Optional[torch.Tensor] = None,
+        w2_scale: Optional[torch.Tensor] = None,
+        a1_scale: Optional[torch.Tensor] = None,
+        a2_scale: Optional[torch.Tensor] = None,
+        block_shape: Optional[list[int]] = None) -> torch.Tensor:

     from vllm.model_executor.layers.quantization.utils.fp8_utils import (
         per_token_group_quant_fp8)
@@ -465,7 +421,7 @@ def rocm_aiter_fused_experts(
             fc2_smooth_scale=None,
             a16=False,
             per_tensor_quant_scale=None,
-            expert_mask=expert_map,
+            expert_mask=None,
             activation_str=activation)

     # w8a8 per-tensor activation per-tensor weight
@@ -475,7 +431,7 @@ def rocm_aiter_fused_experts(

     # - faster static per-tensor-activation static per-tensor-weight
     # fp8 quantization w8a8
-    if use_ck_moe_2stages and a1_scale is not None and a2_scale is not None:
+    if a1_scale is not None and a2_scale is not None:
         return torch.ops.vllm.rocm_aiter_ck_moe_2stages(
             hidden_states=hidden_states,
             w1=w1,
@@ -514,28 +470,19 @@ def rocm_aiter_fused_experts(
         topk_ids = topk_ids.to(torch.int32)
         topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)

-    # faster w16a16
-    if use_ck_moe_2stages:
-        return torch.ops.vllm.rocm_aiter_ck_moe_2stages(
-            hidden_states=hidden_states,
-            w1=w1,
-            w2=w2,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids)
-
-    # w16a16 fallback to rocm_aiter_ck_moe w16a16
-    return torch.ops.vllm.rocm_aiter_ck_moe(hidden_states=hidden_states,
-                                            w1=w1,
-                                            w2=w2,
-                                            topk_weights=topk_weights,
-                                            topk_ids=topk_ids)
+    return torch.ops.vllm.rocm_aiter_ck_moe_2stages(
+        hidden_states=hidden_states,
+        w1=w1,
+        w2=w2,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids)


 def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
                             topk_indices: torch.Tensor,
                             token_expert_indices: torch.Tensor,
                             gating_output: torch.Tensor,
-                            renormalize: bool) -> Tuple[torch.Tensor, ...]:
+                            renormalize: bool) -> tuple[torch.Tensor, ...]:
     torch.ops.vllm.rocm_aiter_topk_softmax(topk_weights, topk_indices,
                                            token_expert_indices, gating_output,
                                            renormalize)
@@ -560,7 +507,7 @@ def shuffle_weights(*tensors: torch.Tensor,


 def expand_weights(*tensors: torch.Tensor,
-                   expansion_dims: List[int]) -> Tuple[torch.Tensor, ...]:
+                   expansion_dims: list[int]) -> tuple[torch.Tensor, ...]:
     """
     Expands the dimensions of input tensors.

@@ -570,12 +517,12 @@ def expand_weights(*tensors: torch.Tensor,
         corresponding to each tensor.

     Returns:
-        A Tuple of tensors with expanded dimensions.
+        A tuple of tensors with expanded dimensions.
     """

     assert len(tensors) == len(expansion_dims), \
         "Number of tensors must match the number of expansion dimensions."

     return tuple(
         tensor.unsqueeze(-1).unsqueeze(-1).expand((-1, dim, -1))
-            for tensor, dim in zip(tensors, expansion_dims))
+        for tensor, dim in zip(tensors, expansion_dims))
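
With rocm_aiter_ck_moe removed, both the unquantized and the static per-tensor fp8 paths end in torch.ops.vllm.rocm_aiter_ck_moe_2stages, and the public rocm_aiter_fused_experts signature shrinks to the arguments the remaining kernels use. A hedged usage sketch for the plain w16a16 case (shapes are illustrative; it assumes a ROCm build of vLLM with AITER installed, VLLM_ROCM_USE_AITER=1, and expert weights already shuffled to the (32, 32) layout):

# Hedged sketch, not a reference invocation: calls the trimmed entry point
# with only the arguments kept by this commit.
import torch

from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
    rocm_aiter_fused_experts)

NUM_EXPERTS, HIDDEN, INTER, TOKENS, TOPK = 8, 4096, 14336, 16, 2
hidden_states = torch.randn(TOKENS, HIDDEN, dtype=torch.bfloat16, device="cuda")
w13 = torch.randn(NUM_EXPERTS, 2 * INTER, HIDDEN, dtype=torch.bfloat16, device="cuda")
w2 = torch.randn(NUM_EXPERTS, HIDDEN, INTER, dtype=torch.bfloat16, device="cuda")
topk_weights = torch.rand(TOKENS, TOPK, dtype=torch.float32, device="cuda")
topk_ids = torch.randint(0, NUM_EXPERTS, (TOKENS, TOPK), dtype=torch.int32, device="cuda")

out = rocm_aiter_fused_experts(hidden_states=hidden_states,
                               w1=w13,
                               w2=w2,
                               topk_weights=topk_weights,
                               topk_ids=topk_ids,
                               activation="silu",
                               apply_router_weight_on_input=False)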
