 class MetaData():
     use_fp8_w8a8 = False
     use_int8_w8a16 = False
+    use_int8_w8a8 = False
 
-    def __init__(self, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, config):
+    def __init__(self, top_k, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, config):
+        self.top_k = top_k
         self.topk_weights = topk_weights
         self.topk_ids = topk_ids
         self.sorted_token_ids = sorted_token_ids
@@ -54,10 +56,15 @@ def set_use_int8_w8a16(self, b_descale):
         self.b_descale = b_descale
         self.a_descale = None
 
+    def set_use_int8_w8a8(self, a_descale, b_descale):
+        self.use_int8_w8a8 = True
+        self.a_descale = a_descale
+        self.b_descale = b_descale
+
     def check_args(self, a, b, o):
         assert a.shape[-1] == b.shape[-1] and b.shape[1] == o.shape[-1]
 
-        assert not (self.use_fp8_w8a8 and self.use_int8_w8a16)
+        assert not (self.use_fp8_w8a8 and self.use_int8_w8a16 and self.use_int8_w8a8)
         if self.use_fp8_w8a8:
             assert self.fp8_type in supported_fp8, f"fp8 type {self.fp8_type} not supported"
 
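Note that check_args (and quantize_input further down) guard the quantization modes with assert not (A and B and C), which only fires when all three flags are set at once. If the intent is to allow at most one mode, a stricter check could look like the sketch below; this is an illustration only, not part of the change:

# Sketch: enforce at most one quantization mode on the MetaData instance.
modes = [self.use_fp8_w8a8, self.use_int8_w8a16, self.use_int8_w8a8]
assert sum(modes) <= 1, "at most one quantization mode may be enabled"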
@@ -89,6 +96,7 @@ def moe_gemm_kernel(
     MUL_ROUTED_WEIGHT: tl.constexpr,
     use_fp8_w8a8: tl.constexpr,
     use_int8_w8a16: tl.constexpr,
+    use_int8_w8a8: tl.constexpr,
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
@@ -146,7 +154,7 @@ def moe_gemm_kernel(
         b_scale_ptrs = B_scale + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
         b_scale = tl.load(b_scale_ptrs)
 
-    if use_fp8_w8a8:
+    if use_fp8_w8a8 or use_int8_w8a8:
         a_scale = tl.load(A_scale)
         b_scale = tl.load(B_scale + off_experts)
 
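For the w8a8 path the scale handling mirrors the fp8 path: A_scale holds a single per-tensor factor for the activations and B_scale one factor per expert, so both loads above are scalar, whereas the int8_w8a16 branch loads a vector of per-column weight scales. Schematically (comments only, shapes are illustrative, not kernel code):

# int8_w8a16: b_scale has shape (BLOCK_SIZE_N,)  -- one scale per output column
# fp8_w8a8 / int8_w8a8: a_scale and b_scale are scalars (per tensor / per expert)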
@@ -163,7 +171,7 @@ def moe_gemm_kernel(
 
         if use_int8_w8a16:
             accumulator = tl.dot(a, b.to(a.dtype), acc=accumulator)
-        elif use_fp8_w8a8:
+        elif use_fp8_w8a8 or use_int8_w8a8:
             accumulator += tl.dot(a, b)
         else:
             accumulator = tl.dot(a, b, acc=accumulator)
@@ -177,7 +185,7 @@ def moe_gemm_kernel(
 
     if use_int8_w8a16:
         accumulator = (accumulator * b_scale).to(Out.dtype.element_ty)
-    elif use_fp8_w8a8:
+    elif use_fp8_w8a8 or use_int8_w8a8:
         accumulator = (accumulator * a_scale * b_scale).to(Out.dtype.element_ty)
     else:
         accumulator = accumulator.to(Out.dtype.element_ty)
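The int8 w8a8 path reuses the fp8 epilogue: activations and weights are quantized symmetrically, the int8 dot products are accumulated inside the K loop, and the accumulator is dequantized once at the end by multiplying with a_scale * b_scale. The same math in a standalone PyTorch sketch (a rough illustration; variable names are not taken from the kernel):

import torch

# One block of tokens against one expert's weight tile.
a = torch.randn(16, 64)
w = torch.randn(32, 64)

a_descale = a.abs().max() / 127.0   # per-tensor dequant factor for activations
w_descale = w.abs().max() / 127.0   # per-expert dequant factor for weights
a_q = torch.clamp(torch.round(a / a_descale), -127, 127).to(torch.int8)
w_q = torch.clamp(torch.round(w / w_descale), -127, 127).to(torch.int8)

# Integer dot products, one dequant multiply at the end -- same order as the epilogue above.
acc = a_q.to(torch.int32) @ w_q.to(torch.int32).T
out = acc.float() * a_descale * w_descale   # approximates a @ w.T up to quantization error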
@@ -278,11 +286,13 @@ def moe_align_block_size(topk_ids: torch.Tensor, block_size: int,
 
 
 def get_config_dtype_str(dtype: torch.dtype, use_int8_w8a16: Optional[bool] = False,
-                         use_fp8_w8a8: Optional[bool] = False):
+                         use_int8_w8a8: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False):
     if use_fp8_w8a8:
         return "fp8_w8a8"
     elif use_int8_w8a16:
         return "int8_w8a16"
+    elif use_int8_w8a8:
+        return "int8_w8a8"
     elif dtype == torch.float:
         # avoiding cases where kernel fails when float32 MoE
         # use fp16/bfloat16 configs
@@ -360,19 +370,19 @@ def moe_gemm(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, metadata: MetaData
     # TODO shard M dim
     metadata.check_args(a, b, c)
 
-    topk_ids, num_tokens_post_padded, topk_weights, sorted_token_ids, expert_ids, config = metadata.topk_ids, metadata.num_tokens_post_padded, metadata.topk_weights, metadata.sorted_token_ids, metadata.expert_ids, metadata.config
+    num_tokens_post_padded, topk_weights, sorted_token_ids, expert_ids, config = metadata.num_tokens_post_padded, metadata.topk_weights, metadata.sorted_token_ids, metadata.expert_ids, metadata.config
 
-    use_fp8_w8a8, use_int8_w8a16 = metadata.use_fp8_w8a8, metadata.use_int8_w8a16
+    use_fp8_w8a8, use_int8_w8a16, use_int8_w8a8 = metadata.use_fp8_w8a8, metadata.use_int8_w8a16, metadata.use_int8_w8a8
     a_descale, b_descale = None, None
     stride_bse = None
     stride_bsn = None
-    if use_fp8_w8a8 or use_int8_w8a16:
+    if use_fp8_w8a8 or use_int8_w8a16 or use_int8_w8a8:
         a_descale, b_descale = metadata.a_descale, metadata.b_descale
         if use_int8_w8a16:
             stride_bse = b_descale.stride(0)
             stride_bsn = b_descale.stride(1)
 
-    _, top_k = topk_ids.shape
+    top_k = metadata.top_k
 
     EM = num_tokens_post_padded.item()
     _, N, K = b.shape
@@ -384,7 +394,7 @@ def moe_gemm(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, metadata: MetaData
         b_descale, a.stride(0), a.stride(1), b.stride(0), b.stride(1), b.stride(2), c.stride(1),
         c.stride(2), stride_bse, stride_bsn, top_k, topk_weights, sorted_token_ids, expert_ids, EM, N,
         K, EVEN_K, MUL_ROUTED_WEIGHT=topk_weights is not None, use_fp8_w8a8=use_fp8_w8a8,
-        use_int8_w8a16=use_int8_w8a16, **config)
+        use_int8_w8a16=use_int8_w8a16, use_int8_w8a8=use_int8_w8a8, **config)
     return c
 
 
@@ -410,8 +420,9 @@ def quantize_tensor(tensor: torch.Tensor, dtype, dim=()) -> tuple[torch.Tensor,
     return tensor_quantized, scale, 1 / scale
 
 
-def quantize_input(a, b, use_fp8_w8a8: tl.constexpr, use_int8_w8a16: tl.constexpr, metatdata: MetaData, fp8_type=None):
-    assert not (use_fp8_w8a8 and use_int8_w8a16)
+def quantize_input(a, b, use_fp8_w8a8: tl.constexpr, use_int8_w8a16: tl.constexpr, use_int8_w8a8: tl.constexpr,
+                   metatdata: MetaData, fp8_type=None):
+    assert not (use_fp8_w8a8 and use_int8_w8a16 and use_int8_w8a8)
     assert not (use_fp8_w8a8 and fp8_type is None)
 
     if use_fp8_w8a8:
@@ -420,14 +431,20 @@ def quantize_input(a, b, use_fp8_w8a8: tl.constexpr, use_int8_w8a16: tl.constexpr,
         metatdata.set_use_fp8_w8a8(a_descale, b_descale, fp8_type)
         return a_quantized, b_quantized
 
+    if use_int8_w8a8:
+        a_quantized, _, a_descale = quantize_tensor(a, dtype=torch.int8)
+        b_quantized, _, b_descale = quantize_tensor(b, dim=(0, ), dtype=torch.int8)
+        metatdata.set_use_int8_w8a8(a_descale, b_descale)
+        return a_quantized, b_quantized
+
     if use_int8_w8a16:
         b_quantized, _, b_descale = quantize_tensor(b, dim=(0, 1), dtype=torch.int8)
         metatdata.set_use_int8_w8a16(b_descale)
         return a, b_quantized
 
 
 def input_helper(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool, use_fp8_w8a8: bool,
-                 use_int8_w8a16: bool, fp8_type, dtype):
+                 use_int8_w8a16: bool, use_int8_w8a8: bool, fp8_type, dtype):
     a = torch.randn((M, K), dtype=dtype, device='cuda')
     b = torch.randn((E, N, K), dtype=dtype, device='cuda')
     c = torch.zeros((M, top_k, N), dtype=dtype, device='cuda')
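In the new branch the activations get a single per-tensor scale (the default dim=()) while the weights are quantized per expert (dim=(0, )), so b_descale carries one value per expert; the w8a16 branch keeps its per-(expert, column) scales. The body of quantize_tensor is not shown in this hunk; a minimal sketch that is consistent with its call sites and its (quantized, scale, descale) return value could look like the following (an assumption for illustration, not the file's actual implementation):

import torch

def quantize_tensor_sketch(tensor: torch.Tensor, dtype=torch.int8, dim=()):
    # `dim` lists the leading axes that keep their own scale: dim=() gives one
    # per-tensor scale, dim=(0,) one per expert, dim=(0, 1) one per (expert, column).
    reduce_dims = tuple(d for d in range(tensor.ndim) if d not in dim)
    max_abs = tensor.abs().amax(dim=reduce_dims)
    qmax = torch.iinfo(dtype).max                      # 127 for int8
    scale = qmax / max_abs                             # multiply to quantize
    view = scale.reshape(scale.shape + (1, ) * len(reduce_dims)) if dim else scale
    quantized = torch.clamp(torch.round(tensor * view), -qmax, qmax).to(dtype)
    return quantized, scale, 1.0 / scale               # descale converts back to float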
@@ -437,7 +454,8 @@ def input_helper(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool,
     softmax_vals = torch.softmax(values, dim=1)
     topk_weights, topk_ids = torch.topk(softmax_vals, k=top_k, dim=1)
 
-    config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, dtype=dtype)
+    config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16,
+                                        use_int8_w8a8=use_int8_w8a8, dtype=dtype)
     get_config_func = functools.partial(
         try_get_optimal_moe_config,
         E,
@@ -446,11 +464,11 @@ def input_helper(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool,
     config = get_config_func(M)
     sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'], E)
 
-    metadata = MetaData(topk_weights if routed_weight else None, topk_ids, sorted_token_ids, expert_ids,
+    metadata = MetaData(top_k, topk_weights if routed_weight else None, topk_ids, sorted_token_ids, expert_ids,
                         num_tokens_post_padded, config)
 
-    if use_fp8_w8a8 or use_int8_w8a16:
-        a, b = quantize_input(a, b, use_fp8_w8a8, use_int8_w8a16, metadata, fp8_type)
+    if use_fp8_w8a8 or use_int8_w8a16 or use_int8_w8a8:
+        a, b = quantize_input(a, b, use_fp8_w8a8, use_int8_w8a16, use_int8_w8a8, metadata, fp8_type)
 
     return a, b, c, metadata
 
@@ -471,7 +489,7 @@ def input_helper(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool,
 def test_correctness(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool, dtype=torch.float16):
     torch.manual_seed(20)
     a, b, c, metadata = input_helper(M, N, K, top_k, E, routed_weight=routed_weight, use_fp8_w8a8=False,
-                                     use_int8_w8a16=False, fp8_type=None, dtype=dtype)
+                                     use_int8_w8a16=False, use_int8_w8a8=False, fp8_type=None, dtype=dtype)
 
     tri_out = moe_gemm(a, b, c, metadata)
 
@@ -508,7 +526,7 @@ def test_correctness_fp8(M: int, N: int, K: int, top_k: int, E: int, routed_weight,
                          dtype=torch.float16):
     torch.manual_seed(20)
     a, b, c, metadata = input_helper(M, N, K, top_k, E, routed_weight=routed_weight, use_fp8_w8a8=use_fp8_w8a8,
-                                     use_int8_w8a16=False, fp8_type=fp8_type, dtype=dtype)
+                                     use_int8_w8a16=False, fp8_type=fp8_type, use_int8_w8a8=False, dtype=dtype)
 
     tri_out = moe_gemm(a, b, c, metadata)
 
@@ -545,11 +563,11 @@ def test_correctness_fp8(M: int, N: int, K: int, top_k: int, E: int, routed_weight,
 ])
 @pytest.mark.parametrize('routed_weight', [True, False])
 @pytest.mark.parametrize('use_int8_w8a16', [True])
-def test_correctness_int8(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool, use_int8_w8a16,
-                          dtype=torch.float16):
+def test_correctness_int8_w8a16(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool, use_int8_w8a16,
+                                dtype=torch.float16):
     torch.manual_seed(20)
     a, b, c, metadata = input_helper(M, N, K, top_k, E, routed_weight=routed_weight, use_fp8_w8a8=False,
-                                     use_int8_w8a16=use_int8_w8a16, fp8_type=None, dtype=dtype)
+                                     use_int8_w8a16=use_int8_w8a16, use_int8_w8a8=False, fp8_type=None, dtype=dtype)
 
     tri_out = moe_gemm(a, b, c, metadata)
 
@@ -560,7 +578,7 @@ def test_correctness_int8(M: int, N: int, K: int, top_k: int, E: int, routed_weight,
     a_expanded = a.unsqueeze(1).repeat(1, top_k, 1)
     # (M, top_k, N, K)
     b_indexed = b[topk_ids]
-    ref_out = torch.einsum("mek,menk->men", a_expanded.to(torch.float32), b_indexed.to(torch.float32))
+    ref_out = torch.einsum("mek,menk->men", a_expanded.float(), b_indexed.float())
     if routed_weight:
         ref_out *= topk_weights.unsqueeze(-1)
 
@@ -571,6 +589,46 @@ def test_correctness_int8(M: int, N: int, K: int, top_k: int, E: int, routed_weight,
     torch.testing.assert_close(tri_out, ref_out, atol=1e-2, rtol=1e-2)
 
 
+@pytest.mark.parametrize("M, N, K, top_k, E", [
+    (64, 14336, 4096, 2, 8),
+    (16, 14336, 1, 2, 4),
+    (1, 14336, 128, 2, 4),
+    (16, 14336, 128, 1, 4),
+    (16, 14336, 128, 1, 1),
+    (64, 7186, 128, 2, 8),
+    (64, 3584, 128, 2, 8),
+    (64, 1792, 128, 2, 8),
+    (64, 64, 128, 2, 8),
+])
+@pytest.mark.parametrize('routed_weight', [True, False])
+@pytest.mark.parametrize('use_int8_w8a8', [True])
+def test_correctness_int8_w8a8(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool, use_int8_w8a8,
+                               dtype=torch.float16):
+    torch.manual_seed(20)
+    a, b, c, metadata = input_helper(M, N, K, top_k, E, routed_weight=routed_weight, use_fp8_w8a8=False,
+                                     use_int8_w8a16=False, use_int8_w8a8=use_int8_w8a8, fp8_type=None, dtype=dtype)
+
+    tri_out = moe_gemm(a, b, c, metadata)
+
+    topk_ids = metadata.topk_ids
+    topk_weights = metadata.topk_weights
+    ref_out = torch.empty_like(c)
+    # Repeat a -> (M, top_k, K)
+    a_expanded = a.unsqueeze(1).repeat(1, top_k, 1)
+    # (M, top_k, N, K)
+    b_indexed = b[topk_ids]
+    ref_out = torch.einsum("mek,menk->men", a_expanded.float(), b_indexed.float())
+    if routed_weight:
+        ref_out *= topk_weights.unsqueeze(-1)
+
+    ref_out = ref_out * metadata.b_descale[topk_ids].unsqueeze(-1)
+    ref_out = ref_out * metadata.a_descale
+    ref_out = ref_out.to(dtype)
+
+    # Validate correctness
+    torch.testing.assert_close(tri_out, ref_out, atol=1e-2, rtol=1e-2)
+
+
 def get_configs():
     configs = [
         {"M": 64, "N": 256, "K": 128, "E": 8, "top_k": 2},
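The reference path in the new test mirrors the kernel's dequantization order: the float32 einsum runs on the raw int8 values, the optional routed weight is applied, and then both dequant factors are multiplied in; since every step is an elementwise multiply, the ordering relative to the routed weight does not matter. Shape-wise (assuming b_descale holds one value per expert):

#   a_expanded:                   (M, top_k, K)      int8 values cast to float
#   b_indexed = b[topk_ids]:      (M, top_k, N, K)
#   ref_out:                      (M, top_k, N)
#   metadata.b_descale[topk_ids]: (M, top_k) -> unsqueeze(-1) -> (M, top_k, 1)
#   metadata.a_descale:           scalar, applied to every element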
@@ -606,8 +664,10 @@ def model_benchmark_configs(args):
 
         E = 8
         top_k = 2
+        # The first moe layer
         moe_configs.append((model_name, M, N1, K1, E, top_k))
-        moe_configs.append((model_name, M, N2, K2, E, top_k))
+        # The second moe layer
+        moe_configs.append((model_name, M * top_k, N2, K2, E, 1))
 
     return moe_configs
 
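The benchmark now models the two GEMMs of an MoE block separately: the first layer routes M tokens to top_k experts each, while the second layer consumes the already-expanded activations, so it is measured as M * top_k rows with a single expert choice per row. The arithmetic work reported for the second layer is unchanged by this reshaping:

# FLOPs for the second layer (the 2 accounts for multiply + accumulate):
#   before: 2 * M * top_k * K2 * N2
#   after:  2 * (M * top_k) * 1 * K2 * N2   -- identical product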
@@ -616,6 +676,7 @@ def run_benchmark(custom, args):
     routed_weight = args.routed_weight
     use_int8_w8a16 = args.int8_w8a16
     use_fp8_w8a8 = args.fp8_w8a8
+    use_int8_w8a8 = args.int8_w8a8
     dtype = arg_to_torch_dtype[args.dtype]
     fp8_type = arg_to_torch_dtype[args.fp8_type]
 
@@ -640,14 +701,15 @@ def run_benchmark(custom, args):
         styles=[('red', '-'), ('blue', '-'),
                 ('yellow', '-')], ylabel='ms / TFLOPS / GB/s', plot_name='moe-gemm-benchmark', args={
                     'dtype': dtype, 'routed_weight': routed_weight, 'use_fp8_w8a8': use_fp8_w8a8, 'use_int8_w8a16':
-                    use_int8_w8a16, 'fp8_type': fp8_type
+                    use_int8_w8a16, 'use_int8_w8a8': use_int8_w8a8, 'fp8_type': fp8_type
                 })
 
     @triton.testing.perf_report([benchmark])
-    def bench_moe_gemm(M, N, K, E, top_k, dtype, routed_weight, metric, use_fp8_w8a8, use_int8_w8a16, fp8_type,
-                       model=None):
+    def bench_moe_gemm(M, N, K, E, top_k, dtype, routed_weight, metric, use_fp8_w8a8, use_int8_w8a16, use_int8_w8a8,
+                       fp8_type, model=None):
         a, b, c, metadata = input_helper(M, N, K, top_k, E, routed_weight=routed_weight, use_fp8_w8a8=use_fp8_w8a8,
-                                         use_int8_w8a16=use_int8_w8a16, fp8_type=fp8_type, dtype=dtype)
+                                         use_int8_w8a16=use_int8_w8a16, use_int8_w8a8=use_int8_w8a8, fp8_type=fp8_type,
+                                         dtype=dtype)
 
         # (M, K) * (top_k, N, K) -> (M, top_k, N). 2 for multiplication and accumulation
         flops = 2.0 * M * top_k * K * N
@@ -658,6 +720,9 @@ def bench_moe_gemm(M, N, K, E, top_k, dtype, routed_weight, metric, use_fp8_w8a8,
         if use_fp8_w8a8:
             a_bytes = b_bytes = torch.tensor([], dtype=fp8_type).element_size()
             c_bytes = torch.tensor([], dtype=dtype).element_size()
+        if use_int8_w8a8:
+            a_bytes = b_bytes = torch.tensor([], dtype=torch.int8).element_size()
+            c_bytes = torch.tensor([], dtype=torch.int8).element_size()
         elif use_int8_w8a16:
             b_bytes = torch.tensor([], dtype=torch.int8).element_size()
             a_bytes = c_bytes = torch.tensor([], dtype=dtype).element_size()
@@ -705,6 +770,7 @@ def parse_args():
     parser.add_argument("-top_k", type=int, default=0, help="top_k experts per token")
     parser.add_argument("-routed_weight", action='store_true', default=False)
     parser.add_argument("-int8_w8a16", action='store_true', default=False)
+    parser.add_argument("-int8_w8a8", action='store_true', default=False)
     parser.add_argument("-fp8_w8a8", action='store_true', default=False)
     parser.add_argument("-dtype", default='fp16')
    parser.add_argument("-fp8_type", default='e5m2fnuz')