@@ -117,6 +117,7 @@ class Fp8MoeBackend(Enum):
     DEEPGEMM = 3
     MARLIN = 4
     TRITON = 5
+    AITER = 6


 def get_fp8_moe_backend(
@@ -189,6 +190,10 @@ def get_fp8_moe_backend(
         logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local")
         return Fp8MoeBackend.DEEPGEMM

+    if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE:
+        logger.info_once("Using ROCm AITER backend for FP8 MoE", scope="local")
+        return Fp8MoeBackend.AITER
+
     # default to Triton
     logger.info_once("Using Triton backend for FP8 MoE")
     return Fp8MoeBackend.TRITON
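
Note on the change above: the order of these checks matters. AITER is selected only when both ROCm AITER flags are enabled, and only after the earlier backend checks (e.g., DeepGEMM) have fallen through, with Triton remaining the final default. The following is a minimal, self-contained sketch of that precedence pattern; it reads plain os.environ instead of vLLM's envs module and covers only a subset of backends, so treat it as illustrative rather than the actual implementation.

# Illustrative sketch only (not vLLM's implementation): ordered, env-gated
# backend resolution where the first matching branch wins and Triton is the
# fallback. Flag names mirror the ones used above; the reading logic is
# simplified to plain os.environ.
import os
from enum import Enum


class Fp8MoeBackendSketch(Enum):
    DEEPGEMM = 3
    TRITON = 5
    AITER = 6


def _flag(name: str) -> bool:
    return os.environ.get(name, "0") == "1"


def pick_backend(has_deep_gemm: bool) -> Fp8MoeBackendSketch:
    if has_deep_gemm:
        return Fp8MoeBackendSketch.DEEPGEMM
    # AITER is considered only after the earlier checks fall through,
    # and only when both ROCm AITER flags are set.
    if _flag("VLLM_ROCM_USE_AITER") and _flag("VLLM_ROCM_USE_AITER_MOE"):
        return Fp8MoeBackendSketch.AITER
    return Fp8MoeBackendSketch.TRITON
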
@@ -888,16 +893,10 @@ def create_weights(
             layer.w13_input_scale = None
             layer.w2_input_scale = None

-        self.rocm_aiter_moe_enabled = False
-
     def process_weights_after_loading(self, layer: Module) -> None:
         if getattr(layer, "_already_called_process_weights_after_loading", False):
             return

-        # Lazy import to avoid importing triton too early.
-
-        self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
-
         # TODO (rob): refactor block quant into separate class.
         if self.block_quant:
             assert self.quant_config.activation_scheme == "dynamic"
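
With this removal there is no separately cached rocm_aiter_moe_enabled flag that has to be kept in sync with the backend choice; the call sites below check self.fp8_backend directly. A hypothetical sketch of that single-source-of-truth pattern (class and property names are invented for illustration, not vLLM's API):

# Hypothetical sketch: derive backend predicates from one stored field instead
# of caching a parallel boolean that can drift out of sync.
from enum import Enum


class Fp8MoeBackendSketch(Enum):
    MARLIN = 4
    TRITON = 5
    AITER = 6


class MoeMethodSketch:
    def __init__(self, fp8_backend: Fp8MoeBackendSketch) -> None:
        self.fp8_backend = fp8_backend  # single source of truth

    @property
    def uses_aiter(self) -> bool:
        # Equivalent to the `self.fp8_backend == Fp8MoeBackend.AITER` checks
        # used throughout this diff; no lazily initialized flag required.
        return self.fp8_backend == Fp8MoeBackendSketch.AITER
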
@@ -932,7 +931,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
             replace_parameter(layer, "w13_weight_scale_inv", w13_weight_scale_inv)
             replace_parameter(layer, "w2_weight", w2_weight)
             replace_parameter(layer, "w2_weight_scale_inv", w2_weight_scale_inv)
-            if self.rocm_aiter_moe_enabled:
+            if self.fp8_backend == Fp8MoeBackend.AITER:
                 # reshaping weights is required for aiter moe kernel.
                 shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                     layer.w13_weight.data, layer.w2_weight.data
@@ -1026,7 +1025,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 )
                 start += shard_size

-            if self.rocm_aiter_moe_enabled:
+            if self.fp8_backend == Fp8MoeBackend.AITER:
                 shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                     layer.w13_weight, layer.w2_weight
                 )
@@ -1072,6 +1071,8 @@ def process_weights_after_loading(self, layer: Module) -> None:
             self.moe_quant_config = config

             self.kernel = mk.FusedMoEModularKernel(
+                # TODO(rob): we can use the generic MoEPrepareAndFinalizeNoEP
+                # with the changes to defer input quantization
                 FlashInferAllGatherMoEPrepareAndFinalize(
                     use_dp=(self.moe.dp_size > 1),
                     use_deepseek_fp8_block_scale=self.block_quant,
@@ -1093,6 +1094,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
             Fp8MoeBackend.DEEPGEMM,
             Fp8MoeBackend.TRITON,
             Fp8MoeBackend.MARLIN,
+            Fp8MoeBackend.AITER,
         ]:
             from vllm.model_executor.layers.fused_moe import (
                 TritonOrDeepGemmExperts,
@@ -1103,32 +1105,41 @@ def process_weights_after_loading(self, layer: Module) -> None:
             from vllm.model_executor.layers.fused_moe.prepare_finalize import (
                 MoEPrepareAndFinalizeNoEP,
             )
+            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+                AiterExperts,
+            )

             config = self.get_fused_moe_quant_config(layer)
             assert config is not None
             self.moe_quant_config = config
-            use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
-            allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
-            moe_kernel = (
-                MarlinExperts(quant_config=self.moe_quant_config)
-                if use_marlin
-                else TritonOrDeepGemmExperts(
-                    quant_config=self.moe_quant_config,
-                    allow_deep_gemm=allow_deep_gemm,
-                )
-            )

-            self.kernel = mk.FusedMoEModularKernel(
-                MoEPrepareAndFinalizeNoEP(), moe_kernel
-            )
+            if self.fp8_backend == Fp8MoeBackend.AITER:
+                self.kernel = mk.FusedMoEModularKernel(
+                    # TODO: make defer_input_quant an attr of the AiterExperts
+                    MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
+                    AiterExperts(quant_config=self.moe_quant_config),
+                )
+            elif self.fp8_backend == Fp8MoeBackend.MARLIN:
+                self.kernel = mk.FusedMoEModularKernel(
+                    MoEPrepareAndFinalizeNoEP(),
+                    MarlinExperts(quant_config=self.moe_quant_config),
+                )
+            else:
+                self.kernel = mk.FusedMoEModularKernel(
+                    MoEPrepareAndFinalizeNoEP(),
+                    TritonOrDeepGemmExperts(
+                        quant_config=self.moe_quant_config,
+                        allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
+                    ),
+                )
             self.use_inplace = True

     def maybe_make_prepare_finalize(
         self,
         routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
     ) -> mk.FusedMoEPrepareAndFinalize | None:
         if (
-            self.rocm_aiter_moe_enabled
+            self.fp8_backend == Fp8MoeBackend.AITER
             or self.fp8_backend == Fp8MoeBackend.MARLIN
             or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
         ):
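
Note on the hunk above: each backend now composes the same two-part modular kernel, a prepare/finalize stage plus a backend-specific experts implementation, and for AITER the prepare stage defers input quantization to the experts kernel. Below is a self-contained toy sketch of that composition pattern; the class names and signatures are stand-ins invented for illustration, not vLLM's real MoEPrepareAndFinalizeNoEP or experts interfaces.

# Toy stand-ins for the prepare/finalize + experts composition used above.
from dataclasses import dataclass


@dataclass
class PrepareFinalizeSketch:
    defer_input_quant: bool = False  # AITER-style: let the experts kernel quantize

    def prepare(self, x: list[float]) -> list[float]:
        if self.defer_input_quant:
            return x  # pass activations through unquantized
        return [round(v, 1) for v in x]  # stand-in for input quantization


class ExpertsSketch:
    def __call__(self, x: list[float]) -> list[float]:
        return [v * 2.0 for v in x]  # stand-in for the fused experts GEMMs


class ModularKernelSketch:
    def __init__(self, prep: PrepareFinalizeSketch, experts: ExpertsSketch) -> None:
        self.prep, self.experts = prep, experts

    def __call__(self, x: list[float]) -> list[float]:
        return self.experts(self.prep.prepare(x))


# AITER-style composition: quantization deferred to the experts kernel.
aiter_like = ModularKernelSketch(PrepareFinalizeSketch(defer_input_quant=True), ExpertsSketch())
# Triton/Marlin-style composition: inputs quantized in the prepare stage.
default_like = ModularKernelSketch(PrepareFinalizeSketch(), ExpertsSketch())
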
@@ -1161,11 +1172,10 @@ def select_gemm_impl(
             TritonOrDeepGemmExperts,
         )

-        assert (
-            self.fp8_backend != Fp8MoeBackend.MARLIN
-        ) and not self.rocm_aiter_moe_enabled, (
-            "Marlin and ROCm AITER are not supported with all2all yet."
-        )
+        if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]:
+            raise NotImplementedError(
+                "Marlin and ROCm AITER are not supported with all2all yet."
+            )

         assert self.moe_quant_config is not None

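
Raising NotImplementedError instead of asserting is the safer choice here: assert statements are stripped when Python runs with -O, so the old guard could silently disappear, while an explicit raise always fires on unsupported configurations. A minimal illustration, using placeholder strings rather than vLLM's enums:

# Minimal illustration (not vLLM code): an assert-based guard vanishes under
# `python -O`, an explicit raise does not.
def guarded_with_assert(backend: str) -> None:
    assert backend not in ("marlin", "aiter"), "not supported with all2all yet"


def guarded_with_raise(backend: str) -> None:
    if backend in ("marlin", "aiter"):
        raise NotImplementedError("not supported with all2all yet")


# Under `python -O`, guarded_with_assert("aiter") returns silently;
# guarded_with_raise("aiter") still raises NotImplementedError.
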
@@ -1313,37 +1323,18 @@ def apply(
             hidden_states=x,
             router_logits=router_logits,
         )
-
-        if self.rocm_aiter_moe_enabled:
-            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
-                rocm_aiter_fused_experts,
-            )
-
-            # TODO(rob): convert this to MK.
-            result = rocm_aiter_fused_experts(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                activation=layer.activation,
-                apply_router_weight_on_input=layer.apply_router_weight_on_input,
-                expert_map=layer.expert_map,
-                quant_config=self.moe_quant_config,
-            )
-        else:
-            result = self.kernel(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_weights,
-                topk_ids,
-                inplace=self.use_inplace,
-                activation=layer.activation,
-                global_num_experts=layer.global_num_experts,
-                expert_map=layer.expert_map,
-                apply_router_weight_on_input=layer.apply_router_weight_on_input,
-            )
+        result = self.kernel(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights,
+            topk_ids,
+            inplace=self.use_inplace,
+            activation=layer.activation,
+            global_num_experts=layer.global_num_experts,
+            expert_map=layer.expert_map,
+            apply_router_weight_on_input=layer.apply_router_weight_on_input,
+        )

         return result

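
With the AITER special case removed, apply() always invokes the pre-built self.kernel: the backend decision is made once when weights are processed, not on every forward call. A small strategy-pattern sketch of that idea, using toy callables rather than vLLM's kernels:

# Strategy-pattern sketch (toy code): resolve the backend to a callable once at
# setup time so the per-call path below contains no backend branching.
from typing import Callable

Kernel = Callable[[list[float]], list[float]]


def build_kernel(backend: str) -> Kernel:
    if backend == "aiter":
        return lambda xs: [x * 2.0 for x in xs]  # stand-in for AiterExperts
    return lambda xs: [x + 1.0 for x in xs]  # stand-in for Triton/DeepGEMM


class LayerSketch:
    def __init__(self, backend: str) -> None:
        self.kernel = build_kernel(backend)  # chosen once, up front

    def apply(self, xs: list[float]) -> list[float]:
        return self.kernel(xs)  # uniform call path, no per-call branch
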
@@ -1456,15 +1447,10 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs):
         layer.w13_input_scale = None
         layer.w2_input_scale = None

-        self.rocm_aiter_moe_enabled = False
-
     def process_weights_after_loading(self, layer: Module) -> None:
         if getattr(layer, "_already_called_process_weights_after_loading", False):
             return

-        # Lazy import to avoid importing triton too early.
-        self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
-
         # If checkpoint is fp16, quantize in place.
         fp8_dtype = current_platform.fp8_dtype()
         w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
@@ -1481,15 +1467,15 @@ def process_weights_after_loading(self, layer: Module) -> None:
         replace_parameter(layer, "w2_weight", w2_weight)

         # Reshuffle weights for AITER if needed.
-        if self.rocm_aiter_moe_enabled:
+        if self.fp8_backend == Fp8MoeBackend.AITER:
             shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                 layer.w13_weight, layer.w2_weight
             )
             replace_parameter(layer, "w13_weight", shuffled_w13)
             replace_parameter(layer, "w2_weight", shuffled_w2)

         # Reshuffle weights for MARLIN if needed.
-        if self.fp8_backend == Fp8MoeBackend.MARLIN:
+        elif self.fp8_backend == Fp8MoeBackend.MARLIN:
             prepare_moe_fp8_layer_for_marlin(
                 layer, False, input_dtype=self.marlin_input_dtype
             )