
Commit b57b967

Authored by Robert Shaw (robertgshaw2-redhat) and a co-author
[MoE Refactor][7/N] AITER MK (vllm-project#31102)
Signed-off-by: Robert Shaw <[email protected]>
Co-authored-by: Robert Shaw <[email protected]>
1 parent 6d518ff commit b57b967

File tree

4 files changed (+144, -66 lines)


vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 5 additions & 1 deletion
@@ -2132,6 +2132,7 @@ def apply(
             torch.float16,
             torch.bfloat16,
             torch.float8_e4m3fn,
+            torch.float8_e4m3fnuz,
         ]
 
         E, num_tokens, N, K, top_k_num = self.moe_problem_size(
@@ -2156,7 +2157,10 @@ def apply(
             compute_type = tl.float16
         elif hidden_states.dtype == torch.float32:
             compute_type = tl.float32
-        elif hidden_states.dtype == torch.float8_e4m3fn:
+        elif (
+            hidden_states.dtype == torch.float8_e4m3fn
+            or hidden_states.dtype == torch.float8_e4m3fnuz
+        ):
             compute_type = tl.bfloat16
         else:
             raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
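For reference, the dtype dispatch in the second hunk can be restated as a standalone sketch. The helper name select_compute_type is hypothetical and only the branches visible in this hunk are reproduced; torch.float8_e4m3fnuz is the ROCm fp8 e4m3 variant, which is why it now shares the bf16 compute-type path with torch.float8_e4m3fn.

import torch
import triton.language as tl


def select_compute_type(dtype: torch.dtype):
    # Hypothetical restatement of the dtype branch in apply() above:
    # both fp8 e4m3 flavors (OCP fn and ROCm fnuz) run the Triton kernel
    # with a bf16 compute type.
    if dtype == torch.float16:
        return tl.float16
    if dtype == torch.float32:
        return tl.float32
    if dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz):
        return tl.bfloat16
    raise ValueError(f"Unsupported compute_type: {dtype}")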

vllm/model_executor/layers/fused_moe/prepare_finalize.py

Lines changed: 9 additions & 0 deletions
@@ -13,6 +13,10 @@
 
 
 class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
+    def __init__(self, defer_input_quant: bool = False) -> None:
+        super().__init__()
+        self.defer_input_quant = defer_input_quant
+
     @property
     def activation_format(self) -> mk.FusedMoEActivationFormat:
         return mk.FusedMoEActivationFormat.Standard
@@ -48,6 +52,11 @@ def prepare(
             # Note: do not use inplace for shared experts overlap
             a1 = a1 * topk_weights.to(a1.dtype)
 
+        # Defer input quant to moe kernel for backends (e.g. AITER, FI)
+        # which use a single kernel call for quant + experts.
+        if self.defer_input_quant:
+            return a1, None, None, None, None
+
         a1q, a1q_scale = moe_kernel_quantize_input(
             a1,
             quant_config.a1_scale,
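A standalone sketch of the control flow this adds (names other than defer_input_quant are hypothetical; in the real class the up-front path goes through moe_kernel_quantize_input and prepare() returns a larger tuple, as shown in the hunk above):

import torch


def prepare_activations(a1: torch.Tensor, defer_input_quant: bool):
    # Deferred path: hand back the raw activations with no scale so a fused
    # backend (e.g. AITER) can quantize inside its own quant + experts kernel.
    if defer_input_quant:
        return a1, None

    # Up-front path (stand-in for moe_kernel_quantize_input): a simple
    # per-tensor dynamic fp8 quantization for illustration only.
    finfo = torch.finfo(torch.float8_e4m3fn)
    a1q_scale = a1.abs().amax().clamp(min=1e-12).float() / finfo.max
    a1q = (a1 / a1q_scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return a1q, a1q_scale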

vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py

Lines changed: 79 additions & 0 deletions
@@ -5,11 +5,15 @@
 
 import torch
 
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.model_executor.layers.fused_moe.config import (
     FUSED_MOE_UNQUANTIZED_CONFIG,
     FusedMoEQuantConfig,
 )
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+    TopKWeightAndReduceNoOP,
+)
 
 
 class QuantMethod(IntEnum):
@@ -263,3 +267,78 @@ def rocm_aiter_fused_experts(
         a2_scale=quant_config.a2_scale,
         doweight_stage1=apply_router_weight_on_input,
     )
+
+
+class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
+    def __init__(self, quant_config):
+        super().__init__(quant_config)
+
+    @property
+    def activation_formats(
+        self,
+    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
+        return (
+            mk.FusedMoEActivationFormat.Standard,
+            mk.FusedMoEActivationFormat.Standard,
+        )
+
+    def supports_expert_map(self):
+        return True
+
+    def supports_chunking(self):
+        return False
+
+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        return TopKWeightAndReduceNoOP()
+
+    def workspace_shapes(
+        self,
+        M: int,
+        N: int,
+        K: int,
+        topk: int,
+        global_num_experts: int,
+        local_num_experts: int,
+        expert_tokens_meta: mk.ExpertTokensMetadata | None,
+    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
+        # Workspaces are managed internally by AITER.
+        workspace1 = (0,)
+        workspace2 = (0,)
+        output = (M, K)
+        return (workspace1, workspace2, output)
+
+    def apply(
+        self,
+        output: torch.Tensor,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: torch.Tensor | None,
+        a1q_scale: torch.Tensor | None,
+        a2_scale: torch.Tensor | None,
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+        expert_tokens_meta: mk.ExpertTokensMetadata | None,
+        apply_router_weight_on_input: bool,
+    ):
+        assert a1q_scale is None
+        assert a2_scale is None
+        assert expert_tokens_meta is None
+
+        result = rocm_aiter_fused_experts(
+            hidden_states=hidden_states,
+            w1=w1,
+            w2=w2,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            expert_map=expert_map,
+            quant_config=self.quant_config,
+        )
+        assert result.shape == output.shape
+        output.copy_(result)
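The fp8.py change further down wires this class into the modular kernel; a minimal wiring sketch mirroring that change (assumes a vLLM tree containing this commit; the quant_config construction is elided):

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import AiterExperts

quant_config = ...  # the FusedMoEQuantConfig built by the FP8 MoE method (elided)

kernel = mk.FusedMoEModularKernel(
    # Input quantization is deferred so AITER can fuse quant + experts.
    MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
    AiterExperts(quant_config=quant_config),
)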

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 51 additions & 65 deletions
@@ -117,6 +117,7 @@ class Fp8MoeBackend(Enum):
     DEEPGEMM = 3
     MARLIN = 4
     TRITON = 5
+    AITER = 6
 
 
 def get_fp8_moe_backend(
@@ -189,6 +190,10 @@ def get_fp8_moe_backend(
         logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local")
         return Fp8MoeBackend.DEEPGEMM
 
+    if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE:
+        logger.info_once("Using ROCm AITER backend for FP8 MoE", scope="local")
+        return Fp8MoeBackend.AITER
+
     # default to Triton
     logger.info_once("Using Triton backend for FP8 MoE")
     return Fp8MoeBackend.TRITON
@@ -888,16 +893,10 @@ def create_weights(
         layer.w13_input_scale = None
         layer.w2_input_scale = None
 
-        self.rocm_aiter_moe_enabled = False
-
     def process_weights_after_loading(self, layer: Module) -> None:
         if getattr(layer, "_already_called_process_weights_after_loading", False):
             return
 
-        # Lazy import to avoid importing triton too early.
-
-        self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
-
         # TODO (rob): refactor block quant into separate class.
         if self.block_quant:
             assert self.quant_config.activation_scheme == "dynamic"
@@ -932,7 +931,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
             replace_parameter(layer, "w13_weight_scale_inv", w13_weight_scale_inv)
             replace_parameter(layer, "w2_weight", w2_weight)
             replace_parameter(layer, "w2_weight_scale_inv", w2_weight_scale_inv)
-            if self.rocm_aiter_moe_enabled:
+            if self.fp8_backend == Fp8MoeBackend.AITER:
                 # reshaping weights is required for aiter moe kernel.
                 shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                     layer.w13_weight.data, layer.w2_weight.data
@@ -1026,7 +1025,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
            )
            start += shard_size
 
-        if self.rocm_aiter_moe_enabled:
+        if self.fp8_backend == Fp8MoeBackend.AITER:
            shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                layer.w13_weight, layer.w2_weight
            )
@@ -1072,6 +1071,8 @@ def process_weights_after_loading(self, layer: Module) -> None:
            self.moe_quant_config = config
 
            self.kernel = mk.FusedMoEModularKernel(
+                # TODO(rob): we can use the generic MoEPrepareAndFinalizeNoEP
+                # with the changes to defer input quantization
                FlashInferAllGatherMoEPrepareAndFinalize(
                    use_dp=(self.moe.dp_size > 1),
                    use_deepseek_fp8_block_scale=self.block_quant,
@@ -1093,6 +1094,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
            Fp8MoeBackend.DEEPGEMM,
            Fp8MoeBackend.TRITON,
            Fp8MoeBackend.MARLIN,
+            Fp8MoeBackend.AITER,
        ]:
            from vllm.model_executor.layers.fused_moe import (
                TritonOrDeepGemmExperts,
@@ -1103,32 +1105,41 @@ def process_weights_after_loading(self, layer: Module) -> None:
            from vllm.model_executor.layers.fused_moe.prepare_finalize import (
                MoEPrepareAndFinalizeNoEP,
            )
+            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+                AiterExperts,
+            )
 
            config = self.get_fused_moe_quant_config(layer)
            assert config is not None
            self.moe_quant_config = config
-            use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
-            allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
-            moe_kernel = (
-                MarlinExperts(quant_config=self.moe_quant_config)
-                if use_marlin
-                else TritonOrDeepGemmExperts(
-                    quant_config=self.moe_quant_config,
-                    allow_deep_gemm=allow_deep_gemm,
-                )
-            )
 
-            self.kernel = mk.FusedMoEModularKernel(
-                MoEPrepareAndFinalizeNoEP(), moe_kernel
-            )
+            if self.fp8_backend == Fp8MoeBackend.AITER:
+                self.kernel = mk.FusedMoEModularKernel(
+                    # TODO: make defer_input_quant an attr of the AiterExperts
+                    MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
+                    AiterExperts(quant_config=self.moe_quant_config),
+                )
+            elif self.fp8_backend == Fp8MoeBackend.MARLIN:
+                self.kernel = mk.FusedMoEModularKernel(
+                    MoEPrepareAndFinalizeNoEP(),
+                    MarlinExperts(quant_config=self.moe_quant_config),
+                )
+            else:
+                self.kernel = mk.FusedMoEModularKernel(
+                    MoEPrepareAndFinalizeNoEP(),
+                    TritonOrDeepGemmExperts(
+                        quant_config=self.moe_quant_config,
+                        allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
+                    ),
+                )
            self.use_inplace = True
 
    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        if (
-            self.rocm_aiter_moe_enabled
+            self.fp8_backend == Fp8MoeBackend.AITER
            or self.fp8_backend == Fp8MoeBackend.MARLIN
            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
        ):
@@ -1161,11 +1172,10 @@ def select_gemm_impl(
            TritonOrDeepGemmExperts,
        )
 
-        assert (
-            self.fp8_backend != Fp8MoeBackend.MARLIN
-        ) and not self.rocm_aiter_moe_enabled, (
-            "Marlin and ROCm AITER are not supported with all2all yet."
-        )
+        if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]:
+            raise NotImplementedError(
+                "Marlin and ROCm AITER are not supported with all2all yet."
+            )
 
        assert self.moe_quant_config is not None
 
@@ -1313,37 +1323,18 @@ def apply(
            hidden_states=x,
            router_logits=router_logits,
        )
-
-        if self.rocm_aiter_moe_enabled:
-            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
-                rocm_aiter_fused_experts,
-            )
-
-            # TODO(rob): convert this to MK.
-            result = rocm_aiter_fused_experts(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                activation=layer.activation,
-                apply_router_weight_on_input=layer.apply_router_weight_on_input,
-                expert_map=layer.expert_map,
-                quant_config=self.moe_quant_config,
-            )
-        else:
-            result = self.kernel(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_weights,
-                topk_ids,
-                inplace=self.use_inplace,
-                activation=layer.activation,
-                global_num_experts=layer.global_num_experts,
-                expert_map=layer.expert_map,
-                apply_router_weight_on_input=layer.apply_router_weight_on_input,
-            )
+        result = self.kernel(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights,
+            topk_ids,
+            inplace=self.use_inplace,
+            activation=layer.activation,
+            global_num_experts=layer.global_num_experts,
+            expert_map=layer.expert_map,
+            apply_router_weight_on_input=layer.apply_router_weight_on_input,
+        )
 
        return result
 
@@ -1456,15 +1447,10 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs):
        layer.w13_input_scale = None
        layer.w2_input_scale = None
 
-        self.rocm_aiter_moe_enabled = False
-
    def process_weights_after_loading(self, layer: Module) -> None:
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return
 
-        # Lazy import to avoid importing triton too early.
-        self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
-
        # If checkpoint is fp16, quantize in place.
        fp8_dtype = current_platform.fp8_dtype()
        w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
@@ -1481,15 +1467,15 @@ def process_weights_after_loading(self, layer: Module) -> None:
        replace_parameter(layer, "w2_weight", w2_weight)
 
        # Reshuffle weights for AITER if needed.
-        if self.rocm_aiter_moe_enabled:
+        if self.fp8_backend == Fp8MoeBackend.AITER:
            shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                layer.w13_weight, layer.w2_weight
            )
            replace_parameter(layer, "w13_weight", shuffled_w13)
            replace_parameter(layer, "w2_weight", shuffled_w2)
 
        # Rushuffle weights for MARLIN if needed.
-        if self.fp8_backend == Fp8MoeBackend.MARLIN:
+        elif self.fp8_backend == Fp8MoeBackend.MARLIN:
            prepare_moe_fp8_layer_for_marlin(
                layer, False, input_dtype=self.marlin_input_dtype
            )
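A simplified restatement of where the new backend sits in the selection order of get_fp8_moe_backend (only the branches visible in this diff are reproduced; the enum and function below are stand-ins, and the earlier checks are collapsed into boolean arguments):

from enum import Enum, auto


class Backend(Enum):
    # Stand-in for Fp8MoeBackend; only the members relevant to this hunk.
    DEEPGEMM = auto()
    AITER = auto()
    TRITON = auto()


def pick_backend(use_deep_gemm: bool, use_aiter: bool, use_aiter_moe: bool) -> Backend:
    # Mirrors the order after this change: DeepGEMM first, then the new ROCm
    # AITER check (both env flags must be set), then the Triton default.
    if use_deep_gemm:
        return Backend.DEEPGEMM
    if use_aiter and use_aiter_moe:
        return Backend.AITER
    return Backend.TRITON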

0 commit comments
