11# SPDX-License-Identifier: Apache-2.0
22from functools import cache
3- from typing import List , Optional , Tuple
3+ from typing import Optional
44
55import torch
66
@@ -16,13 +16,6 @@ def is_rocm_aiter_moe_enabled() -> bool:
1616 and envs .VLLM_ROCM_USE_AITER
1717
1818
19- def is_rocm_aiter_2stage_moe_enabled () -> bool :
20- return current_platform .is_rocm () \
21- and envs .VLLM_ROCM_USE_AITER_2STAGE_MOE \
22- and envs .VLLM_ROCM_USE_AITER_MOE \
23- and envs .VLLM_ROCM_USE_AITER
24-
25-
2619def rocm_aiter_asm_moe_tkw1_impl (
2720 hidden_states : torch .Tensor ,
2821 w1 : torch .Tensor ,
@@ -76,23 +69,6 @@ def rocm_aiter_asm_moe_tkw1_fake(
7669 return torch .empty_like (hidden_states )
7770
7871
79- def rocm_aiter_ck_moe_impl (hidden_states : torch .Tensor , w1 : torch .Tensor ,
80- w2 : torch .Tensor , topk_weights : torch .Tensor ,
81- topk_ids : torch .Tensor ) -> torch .Tensor :
82- from aiter import ck_moe
83- return ck_moe (hidden_states = hidden_states ,
84- w1 = w1 ,
85- w2 = w2 ,
86- topk_weights = topk_weights ,
87- topk_ids = topk_ids )
88-
89-
90- def rocm_aiter_ck_moe_fake (hidden_states : torch .Tensor , w1 : torch .Tensor ,
91- w2 : torch .Tensor , topk_weights : torch .Tensor ,
92- topk_ids : torch .Tensor ) -> torch .Tensor :
93- return torch .empty_like (hidden_states )
94-
95-
9672def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl (
9773 topk_ids : torch .Tensor ,
9874 topk_weights : torch .Tensor ,
@@ -215,10 +191,9 @@ def rocm_aiter_ck_moe_2stages_impl(
215191 fc2_scale : Optional [torch .Tensor ] = None ,
216192 a1_scale : Optional [torch .Tensor ] = None ,
217193 a2_scale : Optional [torch .Tensor ] = None ,
218- block_size : Optional [List [int ]] = None ,
194+ block_size : Optional [list [int ]] = None ,
219195 expert_mask : Optional [torch .Tensor ] = None ,
220196) -> torch .Tensor :
221-
222197 from aiter .fused_moe_bf16_asm import ck_moe_2stages
223198 return ck_moe_2stages (a1 = hidden_states ,
224199 w1 = w1 ,
@@ -243,7 +218,7 @@ def rocm_aiter_ck_moe_2stages_fake(
243218 fc2_scale : Optional [torch .Tensor ] = None ,
244219 a1_scale : Optional [torch .Tensor ] = None ,
245220 a2_scale : Optional [torch .Tensor ] = None ,
246- block_size : Optional [List [int ]] = None ,
221+ block_size : Optional [list [int ]] = None ,
247222 expert_mask : Optional [torch .Tensor ] = None ,
248223) -> torch .Tensor :
249224 return torch .empty_like (hidden_states )
@@ -308,14 +283,6 @@ def rocm_aiter_biased_grouped_topk_fake(
308283 dispatch_key = current_platform .dispatch_key ,
309284 )
310285
311- direct_register_custom_op (
312- op_name = "rocm_aiter_ck_moe" ,
313- op_func = rocm_aiter_ck_moe_impl ,
314- mutates_args = [],
315- fake_impl = rocm_aiter_ck_moe_fake ,
316- dispatch_key = current_platform .dispatch_key ,
317- )
318-
319286 direct_register_custom_op (
320287 op_name = "rocm_aiter_fmoe_fp8_blockscale_g1u1" ,
321288 op_func = rocm_aiter_fmoe_fp8_blockscale_g1u1_impl ,
@@ -390,31 +357,20 @@ def rocm_aiter_biased_group_topk(
390357
391358
392359def rocm_aiter_fused_experts (
393- hidden_states : torch .Tensor ,
394- w1 : torch .Tensor ,
395- w2 : torch .Tensor ,
396- topk_weights : torch .Tensor ,
397- topk_ids : torch .Tensor ,
398- inplace : bool = False ,
399- activation : str = "silu" ,
400- apply_router_weight_on_input : bool = False ,
401- use_fp8_w8a8 : bool = False ,
402- use_int8_w8a8 : bool = False ,
403- use_int8_w8a16 : bool = False ,
404- use_int4_w4a16 : bool = False ,
405- per_channel_quant : bool = False ,
406- global_num_experts : int = - 1 ,
407- expert_map : Optional [torch .Tensor ] = None ,
408- w1_scale : Optional [torch .Tensor ] = None ,
409- w2_scale : Optional [torch .Tensor ] = None ,
410- w1_zp : Optional [torch .Tensor ] = None ,
411- w2_zp : Optional [torch .Tensor ] = None ,
412- a1_scale : Optional [torch .Tensor ] = None ,
413- a2_scale : Optional [torch .Tensor ] = None ,
414- block_shape : Optional [List [int ]] = None ,
415- allow_deep_gemm : bool = False ,
416- use_ck_moe_2stages : bool = False ,
417- ) -> torch .Tensor :
360+ hidden_states : torch .Tensor ,
361+ w1 : torch .Tensor ,
362+ w2 : torch .Tensor ,
363+ topk_weights : torch .Tensor ,
364+ topk_ids : torch .Tensor ,
365+ activation : str = "silu" ,
366+ apply_router_weight_on_input : bool = False ,
367+ use_fp8_w8a8 : bool = False ,
368+ per_channel_quant : bool = False ,
369+ w1_scale : Optional [torch .Tensor ] = None ,
370+ w2_scale : Optional [torch .Tensor ] = None ,
371+ a1_scale : Optional [torch .Tensor ] = None ,
372+ a2_scale : Optional [torch .Tensor ] = None ,
373+ block_shape : Optional [list [int ]] = None ) -> torch .Tensor :
418374
419375 from vllm .model_executor .layers .quantization .utils .fp8_utils import (
420376 per_token_group_quant_fp8 )
@@ -465,7 +421,7 @@ def rocm_aiter_fused_experts(
465421 fc2_smooth_scale = None ,
466422 a16 = False ,
467423 per_tensor_quant_scale = None ,
468- expert_mask = expert_map ,
424+ expert_mask = None ,
469425 activation_str = activation )
470426
471427 # w8a8 per-tensor activation per-tensor weight
@@ -475,7 +431,7 @@ def rocm_aiter_fused_experts(
475431
476432 # - faster static per-tensor-activation static per-tensor-weight
477433 # fp8 quantization w8a8
478- if use_ck_moe_2stages and a1_scale is not None and a2_scale is not None :
434+ if a1_scale is not None and a2_scale is not None :
479435 return torch .ops .vllm .rocm_aiter_ck_moe_2stages (
480436 hidden_states = hidden_states ,
481437 w1 = w1 ,
@@ -514,28 +470,19 @@ def rocm_aiter_fused_experts(
514470 topk_ids = topk_ids .to (torch .int32 )
515471 topk_weights = torch .ones_like (topk_weights , dtype = torch .float32 )
516472
517- # faster w16a16
518- if use_ck_moe_2stages :
519- return torch .ops .vllm .rocm_aiter_ck_moe_2stages (
520- hidden_states = hidden_states ,
521- w1 = w1 ,
522- w2 = w2 ,
523- topk_weights = topk_weights ,
524- topk_ids = topk_ids )
525-
526- # w16a16 fallback to rocm_aiter_ck_moe w16a16
527- return torch .ops .vllm .rocm_aiter_ck_moe (hidden_states = hidden_states ,
528- w1 = w1 ,
529- w2 = w2 ,
530- topk_weights = topk_weights ,
531- topk_ids = topk_ids )
473+ return torch .ops .vllm .rocm_aiter_ck_moe_2stages (
474+ hidden_states = hidden_states ,
475+ w1 = w1 ,
476+ w2 = w2 ,
477+ topk_weights = topk_weights ,
478+ topk_ids = topk_ids )
532479
533480
534481def rocm_aiter_topk_softmax (topk_weights : torch .Tensor ,
535482 topk_indices : torch .Tensor ,
536483 token_expert_indices : torch .Tensor ,
537484 gating_output : torch .Tensor ,
538- renormalize : bool ) -> Tuple [torch .Tensor , ...]:
485+ renormalize : bool ) -> tuple [torch .Tensor , ...]:
539486 torch .ops .vllm .rocm_aiter_topk_softmax (topk_weights , topk_indices ,
540487 token_expert_indices , gating_output ,
541488 renormalize )
@@ -560,7 +507,7 @@ def shuffle_weights(*tensors: torch.Tensor,
560507
561508
def expand_weights(*tensors: torch.Tensor,
                   expansion_dims: list[int]) -> tuple[torch.Tensor, ...]:
    """Append two trailing axes to each tensor and broadcast the middle one.

    Each input tensor (assumed 1-D of length N — the double ``unsqueeze``
    followed by a 3-element ``expand`` only type-checks for rank-1 inputs;
    TODO confirm against callers) becomes a broadcast view of shape
    ``(N, dim, 1)``, where ``dim`` is taken position-by-position from
    ``expansion_dims``.

    Args:
        *tensors: Tensors to expand, one per entry in ``expansion_dims``.
        expansion_dims: Target size of the inserted middle dimension for
            the tensor at the same position.

    Returns:
        A tuple of expanded tensor views, in input order.
    """
    assert len(tensors) == len(expansion_dims), \
        "Number of tensors must match the number of expansion dimensions."

    expanded = []
    for tensor, dim in zip(tensors, expansion_dims):
        # (N,) -> (N, 1, 1), then broadcast the middle axis to `dim`
        # without copying data (expand returns a view).
        with_axes = tensor.unsqueeze(-1).unsqueeze(-1)
        expanded.append(with_axes.expand((-1, dim, -1)))
    return tuple(expanded)