 from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
 
+try:
+    from aiter.ops.triton.moe_op_mxfp4 import _fused_moe_kernel_mxfp4
+except ImportError:
+    _fused_moe_kernel_mxfp4 = None
+
 logger = init_logger(__name__)
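The guarded import means the AITER Triton kernel is only usable when the `aiter` package (with its Triton ops) is installed; otherwise the symbol stays `None`. A minimal sketch of how a caller could gate on that, using only the names from the diff (the helper `has_aiter_mxfp4_kernel` is a hypothetical addition, not part of this change):

```python
# Sketch of the optional-dependency pattern used above; runs even without aiter.
try:
    from aiter.ops.triton.moe_op_mxfp4 import _fused_moe_kernel_mxfp4
except ImportError:  # aiter missing, or built without its Triton ops
    _fused_moe_kernel_mxfp4 = None


def has_aiter_mxfp4_kernel() -> bool:
    """True when the AITER mxfp4 fused-MoE kernel can actually be launched."""
    return _fused_moe_kernel_mxfp4 is not None
```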
@@ -507,6 +512,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
                              use_int8_w8a8: bool,
                              use_int8_w8a16: bool,
                              use_int4_w4a16: bool,
+                             use_mxfp4_w4a4: bool,
                              per_channel_quant: bool,
                              block_shape: Optional[list[int]] = None,
                              B_bias: Optional[torch.Tensor] = None) -> None:
@@ -524,6 +530,9 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
     elif use_int8_w8a16 or use_int4_w4a16:
         assert B_scale is not None
         assert block_shape is None or block_shape[0] == 0
+    elif use_mxfp4_w4a4:
+        assert A_scale is not None
+        assert B_scale is not None
     else:
         assert A_scale is None
         assert B_scale is None
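Unlike the weight-only int8/int4 paths, mxfp4 w4a4 quantizes both activations and weights, so both `A_scale` and `B_scale` are required here. As a rough illustration of the scale layout the OCP MX format implies (one shared E8M0 scale per group of 32 elements), here is a small sketch; the helper name and the assumption that scales run along the last (contraction) dimension are illustrative, not taken from the diff:

```python
import torch

MXFP4_GROUP_SIZE = 32  # OCP MX formats: one shared E8M0 scale per 32 elements


def expected_mx_scale_shape(t: torch.Tensor) -> tuple[int, ...]:
    """Scale-tensor shape for an operand quantized along its last dimension."""
    *lead, k = t.shape
    assert k % MXFP4_GROUP_SIZE == 0, "K must be a multiple of the MX group size"
    return (*lead, k // MXFP4_GROUP_SIZE)
```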
@@ -611,6 +620,55 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
             use_int8_w8a16=use_int8_w8a16,
             **config,
         )
+    elif use_mxfp4_w4a4:
+        ONE = torch.ones(B.size(0), dtype=torch.float32, device=A.device)
+        # overwrite config with a static one for now
+        config = {
+            "BLOCK_SIZE_M": 128,
+            "BLOCK_SIZE_N": 128,
+            "BLOCK_SIZE_K": 128,
+            "GROUP_SIZE_M": 4,
+            "num_warps": 8,
+            "num_stages": 2,
+            "waves_per_eu": 0,
+            "matrix_instr_nonkdim": 16,
+            "kpack": 1,
+        }
+        _fused_moe_kernel_mxfp4[grid](
+            A,
+            B,
+            C,
+            ONE[0],
+            ONE,
+            A_scale,
+            B_scale,
+            topk_weights,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            B.size(1),
+            A.size(1),
+            EM,
+            num_tokens,
+            A.stride(0),
+            A.stride(1),
+            B.stride(0),
+            B.stride(2),
+            B.stride(1),
+            C.stride(1),
+            C.stride(2),
+            A_scale.stride(0),
+            A_scale.stride(1),
+            B_scale.stride(0),
+            B_scale.stride(2),
+            B_scale.stride(1),
+            MUL_ROUTED_WEIGHT=mul_routed_weight,
+            top_k=top_k,
+            compute_type=compute_type,
+            SWIZZLE_MX_A=False,
+            SWIZZLE_MX_B=False,
+            **config,
+        )
     else:
         config = config.copy()
         BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
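Two details stand out in this branch: the autotuned `config` is replaced with a fixed tile/scheduling configuration 'for now', and a vector of float32 ones (`ONE`) is passed in argument slots that appear to carry extra per-tensor scales in the AITER kernel signature, presumably making them no-ops. For orientation, here is a sketch of how the launch grid typically relates to the static block sizes in this kind of fused-MoE tiling (one Triton program per (EM-tile, N-tile) pair); it mirrors the usual `grid` construction in this file but is written from the diff alone:

```python
import triton


def moe_launch_grid(EM: int, N: int, block_m: int = 128, block_n: int = 128):
    """Program count for a fused-MoE launch with the static BLOCK_SIZE_M /
    BLOCK_SIZE_N above (EM = padded routed-token count, N = per-expert
    output width)."""
    return (triton.cdiv(EM, block_m) * triton.cdiv(N, block_n), )
```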
@@ -1570,7 +1628,7 @@ def fused_experts_impl(
     else:
         out_hidden_states = torch.empty_like(hidden_states)
 
-    if use_mxfp4_w4a4:
+    if use_mxfp4_w4a4 and not current_platform.supports_mx():
         # Weight has to be dequantized for mxfp4 emulation.
         w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
         w1_scale = None
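The `supports_mx()` check splits mxfp4 handling into two mutually exclusive paths: hardware without native MX support dequantizes the weights up front and runs the regular kernel (emulation), while MX-capable hardware forwards the flag to the Triton kernel via the `use_mxfp4_w4a4=... and current_platform.supports_mx()` hunks below. A small self-contained sketch of that gating (function and return names are illustrative):

```python
def mxfp4_paths(use_mxfp4_w4a4: bool, supports_mx: bool) -> tuple[bool, bool]:
    """Return (emulate, native). emulate: dequantize weights and run the plain
    kernel; native: pass use_mxfp4_w4a4 through to the Triton kernel.
    At most one of the two is True."""
    emulate = use_mxfp4_w4a4 and not supports_mx
    native = use_mxfp4_w4a4 and supports_mx
    return emulate, native
```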
@@ -1629,6 +1687,8 @@ def fused_experts_impl(
             use_int8_w8a8=use_int8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
             use_int4_w4a16=use_int4_w4a16,
+            use_mxfp4_w4a4=use_mxfp4_w4a4
+            and current_platform.supports_mx(),
             per_channel_quant=per_channel_quant,
             block_shape=block_shape,
             B_bias=w1_bias)
@@ -1687,6 +1747,8 @@ def swiglu_oai(gate_up):
             use_int8_w8a8=use_int8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
             use_int4_w4a16=use_int4_w4a16,
+            use_mxfp4_w4a4=use_mxfp4_w4a4
+            and current_platform.supports_mx(),
             per_channel_quant=per_channel_quant,
             block_shape=block_shape,
             B_bias=w2_bias)
@@ -1994,6 +2056,8 @@ def apply(
             use_int8_w8a8=self.use_int8_w8a8,
             use_int8_w8a16=self.use_int8_w8a16,
             use_int4_w4a16=self.use_int4_w4a16,
+            use_mxfp4_w4a4=self.use_mxfp4_w4a4
+            and current_platform.supports_mx(),
             per_channel_quant=self.per_act_token_quant,
             block_shape=self.block_shape,
             B_bias=None  # TODO support B_bias
@@ -2027,6 +2091,8 @@ def apply(
             use_int8_w8a8=self.use_int8_w8a8,
             use_int8_w8a16=self.use_int8_w8a16,
             use_int4_w4a16=self.use_int4_w4a16,
+            use_mxfp4_w4a4=self.use_mxfp4_w4a4
+            and current_platform.supports_mx(),
             per_channel_quant=self.per_act_token_quant,
             block_shape=self.block_shape,
             B_bias=None  # TODO support B_bias