@@ -421,9 +421,9 @@ def block_scale_mxfp_matmul(  #
         stride_cm, stride_cn,  #
         BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,  #
         NUM_STAGES: tl.constexpr, USE_2D_SCALE_LOAD: tl.constexpr):
-    ## This kernel assumes a_scale and b_scale are coming in with shapes
-    ## [BLOCK_M(or N) // 128, BLOCK_K // 128, 32, 4, 4] for optimial performance
-    ## on nvidia sm100+ HW
+    # This kernel assumes a_scale and b_scale come in with shape
+    # [BLOCK_M(or N) // 128, BLOCK_K // 128, 32, 4, 4] for optimal performance
+    # on NVIDIA sm100+ HW.
     pid = tl.program_id(axis=0)
     num_pid_m = tl.cdiv(M, BLOCK_M)
     pid_m = pid % num_pid_m
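The 5-D shape referenced in the comment above is the blocked-scale layout expected by the SM100 scaled-MMA path. As a reading aid, here is a hedged host-side sketch of how a row-major `(M, K // 32)` e8m0 scale tensor could be viewed in that shape; the exact index order is an assumption for illustration, not something this patch specifies:

```python
import torch

# Hypothetical illustration: split M into (M // 128) x 4 x 32 and the K
# scale-groups into (K // 128) x 4, then reorder to the commented shape
# [M // 128, K // 128, 32, 4, 4]. The permute order is an assumption.
M, K = 256, 8192
scales = torch.randint(124, 128, (M, K // 32), dtype=torch.uint8)
scales_5d = (scales.reshape(M // 128, 4, 32, K // 128, 4)
             .permute(0, 3, 2, 1, 4)
             .contiguous())
assert scales_5d.shape == (M // 128, K // 128, 32, 4, 4)
```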
@@ -482,18 +482,21 @@ def block_scale_mxfp_matmul(  #
 
 
 @triton.jit
-def _gemm_afp4_wfp4_kernel_preshuffled_scales_cdna4(a_ptr, b_ptr, c_ptr, a_scales_ptr, b_scales_ptr, M, N, K, stride_am,
-                                                    stride_ak, stride_bk, stride_bn, stride_ck, stride_cm, stride_cn,
-                                                    stride_asm, stride_ask, stride_bsn, stride_bsk,
-                                                    # Meta-parameters
-                                                    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-                                                    mfma_nonkdim: tl.constexpr, preshuffle: tl.constexpr):
+def _gemm_kernel_preshuffled_scales_cdna4(a_ptr, b_ptr, c_ptr, a_scales_ptr, b_scales_ptr, M, N, K, stride_am,
+                                          stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_asm, stride_ask,
+                                          stride_bsn, stride_bsk,
+                                          # Meta-parameters
+                                          DTYPE_A: tl.constexpr, DTYPE_B: tl.constexpr, BLOCK_M: tl.constexpr,
+                                          BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, mfma_nonkdim: tl.constexpr,
+                                          preshuffle: tl.constexpr, fast_math: tl.constexpr):
     """Kernel for computing the matmul C = A x B.
-    A and B inputs are in the microscale fp4 (mxfp4) format.
     A_scales and B_scales are in e8m0 format.
     A has shape (M, K), B has shape (K, N) and C has shape (M, N)
     """
 
+    PACK_FACTOR_A: tl.constexpr = 2 if DTYPE_A == "e2m1" else 1
+    PACK_FACTOR_B: tl.constexpr = 2 if DTYPE_B == "e2m1" else 1
+
     pid = tl.program_id(axis=0)
 
     num_pid_n = tl.cdiv(N, BLOCK_N)
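The new `PACK_FACTOR_A`/`PACK_FACTOR_B` constants capture that e2m1 (fp4) stores two values per uint8 container, while the 8- and 16-bit operand types use one container per element. A small illustrative sketch of that packing (nibble order here is an assumption, not taken from `MXFP4Tensor`):

```python
import torch

# Pack pairs of 4-bit codes along K into one uint8, halving the container
# width exactly as BLOCK_K // PACK_FACTOR does in the kernel. Low-nibble-first
# ordering is assumed for illustration.
def pack_fp4_pairs(codes: torch.Tensor) -> torch.Tensor:
    lo, hi = codes[:, 0::2], codes[:, 1::2]
    return lo | (hi << 4)

codes = torch.randint(0, 16, (4, 8), dtype=torch.uint8)  # 4x8 fp4 codes
packed = pack_fp4_pairs(codes)                           # -> 4x4 uint8
assert packed.shape == (4, 4)
```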
@@ -502,73 +505,99 @@ def _gemm_afp4_wfp4_kernel_preshuffled_scales_cdna4(a_ptr, b_ptr, c_ptr, a_scale
 
     # We assume 32 elements along K share the same scale.
     SCALE_GROUP_SIZE: tl.constexpr = 32
+    MX_SCALE_BLOCK_K: tl.constexpr = BLOCK_K // SCALE_GROUP_SIZE
 
     if preshuffle:
         NON_K_PRESHUFFLE_BLOCK_SIZE: tl.constexpr = 32
     else:
         NON_K_PRESHUFFLE_BLOCK_SIZE: tl.constexpr = 1
 
-    num_k_iter = tl.cdiv(K, BLOCK_K // 2)
     # Create pointers for first block of A and B input matrices
     # The BLOCK sizes are of the elements and in fp4 we pack 2 per uint8 container.
-    offs_k = tl.arange(0, BLOCK_K // 2)
-    offs_k_split = offs_k
+    offs_ak = tl.arange(0, BLOCK_K // PACK_FACTOR_A)
+    offs_bk = tl.arange(0, BLOCK_K // PACK_FACTOR_B)
     offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
     offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
-    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k_split[None, :] * stride_ak)
-    b_ptrs = b_ptr + (offs_k_split[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak)
+    b_ptrs = b_ptr + (offs_bk[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
 
     # Create pointers for the first block of A and B scales
-    offs_asn = (pid_n *
-                (BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE) + tl.arange(0, (BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE))) % N
-    offs_ks = tl.arange(0, BLOCK_K // SCALE_GROUP_SIZE * NON_K_PRESHUFFLE_BLOCK_SIZE)
+    offs_ks = tl.arange(0, MX_SCALE_BLOCK_K * NON_K_PRESHUFFLE_BLOCK_SIZE)
 
     # B scales are N x K even though B operand is K x N.
-    b_scale_ptrs = (b_scales_ptr + offs_asn[:, None] * stride_bsn + offs_ks[None, :] * stride_bsk)
-    offs_asm = (pid_m *
-                (BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE) + tl.arange(0, (BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE))) % M
-    a_scale_ptrs = (a_scales_ptr + offs_asm[:, None] * stride_asm + offs_ks[None, :] * stride_ask)
+    if a_scales_ptr is not None:
+        offs_asm = (pid_m *
+                    (BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE) + tl.arange(0,
+                                                                         (BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE))) % M
+        a_scale_ptrs = (a_scales_ptr + offs_asm[:, None] * stride_asm + offs_ks[None, :] * stride_ask)
+    if b_scales_ptr is not None:
+        offs_asn = (pid_n *
+                    (BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE) + tl.arange(0,
+                                                                         (BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE))) % N
+        b_scale_ptrs = (b_scales_ptr + offs_asn[:, None] * stride_bsn + offs_ks[None, :] * stride_bsk)
     accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
 
-    for k in range(0, num_k_iter):
+    for k in range(0, tl.cdiv(K, BLOCK_K)):
         if preshuffle:
             # Here we "undo" the shuffle done in global memory (shuffle_scales_cdna4 function).
             if mfma_nonkdim == 32:
-                a_scales = tl.load(a_scale_ptrs).reshape(BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE,
-                                                         BLOCK_K // SCALE_GROUP_SIZE // 8, 2, 32, 4,
-                                                         1).permute(0, 3, 1, 4, 2,
-                                                                    5).reshape(BLOCK_M, BLOCK_K // SCALE_GROUP_SIZE)
-                b_scales = tl.load(b_scale_ptrs).reshape(BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE,
-                                                         BLOCK_K // SCALE_GROUP_SIZE // 8, 2, 32, 4,
-                                                         1).permute(0, 3, 1, 4, 2,
-                                                                    5).reshape(BLOCK_N, BLOCK_K // SCALE_GROUP_SIZE)
+                if a_scales_ptr is not None:
+                    a_scales = tl.load(a_scale_ptrs).reshape(BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE,
+                                                             MX_SCALE_BLOCK_K // 8, 2, 32, 4,
+                                                             1).permute(0, 3, 1, 4, 2,
+                                                                        5).reshape(BLOCK_M, MX_SCALE_BLOCK_K)
+                else:
+                    a_scales = None
+                if b_scales_ptr is not None:
+                    b_scales = tl.load(b_scale_ptrs).reshape(BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE,
+                                                             MX_SCALE_BLOCK_K // 8, 2, 32, 4,
+                                                             1).permute(0, 3, 1, 4, 2,
+                                                                        5).reshape(BLOCK_N, MX_SCALE_BLOCK_K)
+                else:
+                    b_scales = None
             elif mfma_nonkdim == 16:
-                a_scales = tl.load(a_scale_ptrs).reshape(BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE,
-                                                         BLOCK_K // SCALE_GROUP_SIZE // 8, 4, 16, 2, 2,
-                                                         1).permute(0, 5, 3, 1, 4, 2,
-                                                                    6).reshape(BLOCK_M, BLOCK_K // SCALE_GROUP_SIZE)
-                b_scales = tl.load(b_scale_ptrs).reshape(BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE,
-                                                         BLOCK_K // SCALE_GROUP_SIZE // 8, 4, 16, 2, 2,
-                                                         1).permute(0, 5, 3, 1, 4, 2,
-                                                                    6).reshape(BLOCK_N, BLOCK_K // SCALE_GROUP_SIZE)
+                if a_scales_ptr is not None:
+                    a_scales = tl.load(a_scale_ptrs).reshape(BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE,
+                                                             MX_SCALE_BLOCK_K // 8, 4, 16, 2, 2,
+                                                             1).permute(0, 5, 3, 1, 4, 2,
+                                                                        6).reshape(BLOCK_M, MX_SCALE_BLOCK_K)
+                else:
+                    a_scales = None
+                if b_scales_ptr is not None:
+                    b_scales = tl.load(b_scale_ptrs).reshape(BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE,
+                                                             MX_SCALE_BLOCK_K // 8, 4, 16, 2, 2,
+                                                             1).permute(0, 5, 3, 1, 4, 2,
+                                                                        6).reshape(BLOCK_N, MX_SCALE_BLOCK_K)
+                else:
+                    b_scales = None
         else:
-            a_scales = tl.load(a_scale_ptrs)
-            b_scales = tl.load(b_scale_ptrs)
+            if a_scales_ptr is not None:
+                a_scales = tl.load(a_scale_ptrs)
+            else:
+                a_scales = None
+            if b_scales_ptr is not None:
+                b_scales = tl.load(b_scale_ptrs)
+            else:
+                b_scales = None
 
         a = tl.load(a_ptrs)
         b = tl.load(b_ptrs, cache_modifier=None)
 
-        accumulator += tl.dot_scaled(a, a_scales, "e2m1", b, b_scales, "e2m1")
+        accumulator += tl.dot_scaled(a, a_scales, DTYPE_A, b, b_scales, DTYPE_B, fast_math=fast_math)
 
         # Advance the ptrs to the next K block.
-        a_ptrs += (BLOCK_K // 2) * stride_ak
-        b_ptrs += (BLOCK_K // 2) * stride_bk
+        a_ptrs += (BLOCK_K // PACK_FACTOR_A) * stride_ak
+        b_ptrs += (BLOCK_K // PACK_FACTOR_B) * stride_bk
         if preshuffle:
-            a_scale_ptrs += BLOCK_K * stride_ask
-            b_scale_ptrs += BLOCK_K * stride_bsk
+            if a_scales_ptr is not None:
+                a_scale_ptrs += BLOCK_K * stride_ask
+            if b_scales_ptr is not None:
+                b_scale_ptrs += BLOCK_K * stride_bsk
         else:
-            a_scale_ptrs += (BLOCK_K // SCALE_GROUP_SIZE) * stride_ask
-            b_scale_ptrs += (BLOCK_K // SCALE_GROUP_SIZE) * stride_bsk
+            if a_scales_ptr is not None:
+                a_scale_ptrs += MX_SCALE_BLOCK_K * stride_ask
+            if b_scales_ptr is not None:
+                b_scale_ptrs += MX_SCALE_BLOCK_K * stride_bsk
 
     c = accumulator.to(c_ptr.type.element_ty)
 
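The reshape/permute chains above undo, in registers, the layout produced on the host by `shuffle_scales_cdna4`. Below is a hedged round-trip sketch for the `mfma_nonkdim == 32` case; the host-side `shuffle` is written as the inverse of the kernel's index mapping, for illustration only:

```python
import torch

# Round-trip check with illustrative tile sizes (BLOCK_M = 32, BLOCK_K = 256,
# so SK = MX_SCALE_BLOCK_K = 8 scale groups per row).
BLOCK_M, BLOCK_K = 32, 256
SK = BLOCK_K // 32

def shuffle(s):  # assumed host-side layout: inverse of the kernel's unshuffle
    return (s.reshape(BLOCK_M // 32, 32, SK // 8, 4, 2, 1)
            .permute(0, 2, 4, 1, 3, 5)
            .reshape(BLOCK_M // 32, SK * 32))

def unshuffle(t):  # mirrors the tl.load(...).reshape(...).permute(...) above
    return (t.reshape(BLOCK_M // 32, SK // 8, 2, 32, 4, 1)
            .permute(0, 3, 1, 4, 2, 5)
            .reshape(BLOCK_M, SK))

s = torch.randint(124, 128, (BLOCK_M, SK), dtype=torch.uint8)
assert torch.equal(unshuffle(shuffle(s)), s)
```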
@@ -583,11 +612,14 @@ def _gemm_afp4_wfp4_kernel_preshuffled_scales_cdna4(a_ptr, b_ptr, c_ptr, a_scale
 
 @pytest.mark.parametrize("M, N, K", [(1024, 1024, 1024)])
 @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 256), (64, 64, 512), [32, 32, 64]])
+@pytest.mark.parametrize("DTYPE_A, DTYPE_B, FAST_MATH", [("mxfp4", "mxfp4", False), ("fp16", "mxfp8e5", False),
+                                                         ("mxfp8e4", "bf16", False), ("bf16", "mxfp4", True)])
 @pytest.mark.parametrize("mfma_nonkdim", [16, 32])
 @pytest.mark.parametrize("preshuffle", [True, False])
 @pytest.mark.skipif(is_cuda() and torch.cuda.get_device_capability()[0] == 10, reason="Compilation bug for GB200.")
 @pytest.mark.skipif(is_hip() and not is_hip_cdna4(), reason="Scaled dot is not emulated on other archs yet.")
-def test_preshuffle_scale_mxfp_cdna4(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, mfma_nonkdim, preshuffle, device):
+def test_preshuffle_scale_mxfp_cdna4(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, DTYPE_A, DTYPE_B, FAST_MATH, mfma_nonkdim,
+                                     preshuffle, device):
     # This test primarily evaluates correctness for efficient scale packing for MFMA-scaled instructions.
     #
     # Scales are stored as 8-bit tensors, where each element scales 32 values from the A or B operand tensors.
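For context on the scale format: e8m0 is exponent-only, so each byte encodes a bare power of two, and one scale covers 32 consecutive K elements. A minimal sketch of what the test's `e8m0_to_f32` helper (defined outside this diff) is assumed to compute:

```python
import torch

# Assumed behavior: value = 2 ** (byte - 127). The test draws scales from
# [124, 128), i.e. factors between 1/8 and 1, which keeps products well-scaled.
def e8m0_to_f32(x: torch.Tensor) -> torch.Tensor:
    return torch.exp2(x.to(torch.float32) - 127)

print(e8m0_to_f32(torch.tensor([124, 127], dtype=torch.uint8)))
# tensor([0.1250, 1.0000])
```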
@@ -637,6 +669,12 @@ def test_preshuffle_scale_mxfp_cdna4(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, mfma_no
     if preshuffle and (BLOCK_M < 32 or BLOCK_N < 32 or BLOCK_K < 256):
         pytest.skip("Minimal tile size for preshuffling is 32x32x256")
 
+    if not (DTYPE_A.startswith("mx") or DTYPE_B.startswith("mx")):
+        pytest.skip("Requires at least one microscaling operand")
+
+    if is_cuda() and (DTYPE_A == "mxfp8e4" or DTYPE_B == "mxfp8e4"):
+        pytest.skip("Skip fp8e4 on NV backend")
+
     def shuffle_scales_cdna4(scales: torch.Tensor):
         if not preshuffle:
             return scales
@@ -665,63 +703,77 @@ def run_torch(x, w, x_scales, w_scales, dtype):
         x_f32 = x.to(torch.float32)
         w_f32 = w.to(torch.float32)
         # Next convert the e8m0 scales to f32.
-        x_scales = x_scales.repeat_interleave(SCALE_GROUP_SIZE, dim=1).to(torch.float32)
-        x_scales_f32 = e8m0_to_f32(x_scales)
-        x_f32 = x_f32 * x_scales_f32
-        w_scales = w_scales.repeat_interleave(SCALE_GROUP_SIZE, dim=1).to(torch.float32)
-        w_scales_f32 = e8m0_to_f32(w_scales)
-        w_f32 = w_f32 * w_scales_f32
+        if x_scales is not None:
+            x_scales = x_scales.repeat_interleave(SCALE_GROUP_SIZE, dim=1).to(torch.float32)
+            x_scales_f32 = e8m0_to_f32(x_scales)
+            x_f32 = x_f32 * x_scales_f32
+        if w_scales is not None:
+            w_scales = w_scales.repeat_interleave(SCALE_GROUP_SIZE, dim=1).to(torch.float32)
+            w_scales_f32 = e8m0_to_f32(w_scales)
+            w_f32 = w_f32 * w_scales_f32
         return torch.mm(x_f32, w_f32.T).to(dtype)
 
-    def generate_gemm_afp4wfp4_inputs(M, N, K):
+    dtype_to_torch_type = {
+        "fp16": torch.half, "bf16": torch.bfloat16, "mxfp8e5": torch.float8_e5m2, "mxfp8e4": torch.float8_e4m3fn
+    }
+
+    dtype_to_triton_type = {"fp16": "fp16", "bf16": "bf16", "mxfp8e5": "e5m2", "mxfp8e4": "e4m3", "mxfp4": "e2m1"}
+
+    def generate_gemm_input(dim0, dim1, dtype):
         torch.manual_seed(5)
         SCALE_GROUP_SIZE = 32
 
-        x = MXFP4Tensor(size=(M, K), device="cuda").random()
-        w = MXFP4Tensor(size=(N, K), device="cuda").random()
-
-        x_scales = torch.randint(124, 128, (K // SCALE_GROUP_SIZE, M), dtype=torch.uint8, device="cuda")
-        w_scales = torch.randint(124, 128, (K // SCALE_GROUP_SIZE, N), dtype=torch.uint8, device="cuda")
-        x_scales = x_scales.T
-        w_scales = w_scales.T
-        x_scales_shuffled = shuffle_scales_cdna4(x_scales)
-        w_scales_shuffled = shuffle_scales_cdna4(w_scales)
-
-        return (
-            x,
-            w,
-            x_scales,
-            w_scales,
-            x_scales_shuffled,
-            w_scales_shuffled,
-        )
-
-    x_mxfp4, w_mxfp4, x_scales, w_scales, x_scales_triton, w_scales_triton = generate_gemm_afp4wfp4_inputs(M, N, K)
-
-    x = x_mxfp4.to_packed_tensor(dim=1)
-    w = w_mxfp4.to_packed_tensor(dim=1)
-
-    torch_out = run_torch(x_mxfp4, w_mxfp4, x_scales, w_scales, torch.float32)
-    M, K = x.shape
-    N, K = w.shape
+        if dtype == "mxfp4":
+            v = MXFP4Tensor(size=(dim0, dim1), device="cuda").random()
+        elif dtype == "mxfp8e5":
+            v = torch.randint(20, 40, (dim0, dim1), dtype=torch.uint8).view(torch.float8_e5m2).to(device)
+        elif dtype == "mxfp8e4":
+            v = torch.randint(20, 40, (dim0, dim1), dtype=torch.uint8).view(torch.float8_e4m3fn).to(device)
+        elif dtype in ("fp16", "bf16"):
+            v = torch.randn((dim0, dim1), device=device, dtype=dtype_to_torch_type[dtype])
+        else:
+            raise ValueError(f"Unsupported data type: {dtype}")
+
+        if dtype.startswith("mx"):
+            scales = torch.randint(124, 128, (dim0, dim1 // SCALE_GROUP_SIZE), dtype=torch.uint8, device=device)
+            scales_shuffled = shuffle_scales_cdna4(scales)
+        else:
+            scales = None
+            scales_shuffled = None
+
+        return (v, scales, scales_shuffled)
+
+    x, x_scales, x_scales_triton = generate_gemm_input(M, K, DTYPE_A)
+    w, w_scales, w_scales_triton = generate_gemm_input(N, K, DTYPE_B)
+
+    torch_out = run_torch(x, w, x_scales, w_scales, torch.float32)
+
+    if DTYPE_A == "mxfp4":
+        x = x.to_packed_tensor(dim=1)
+
+    if DTYPE_B == "mxfp4":
+        w = w.to_packed_tensor(dim=1)
+
     w = w.T
     triton_out = torch.empty((M, N), device=x.device)
 
+    x_scales_strides = x_scales_triton.stride() if x_scales is not None else (None, None)
+    w_scales_strides = w_scales_triton.stride() if w_scales is not None else (None, None)
+
     kernel_kwargs = {}
     if is_hip():
         kernel_kwargs["matrix_instr_nonkdim"] = mfma_nonkdim
 
     grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
-    k = _gemm_afp4_wfp4_kernel_preshuffled_scales_cdna4[grid](x, w, triton_out, x_scales_triton,
-                                                              w_scales_triton, M, N, K, x.stride(0), x.stride(1),
-                                                              w.stride(0), w.stride(1), 0, triton_out.stride(0),
-                                                              triton_out.stride(1), x_scales_triton.stride(0),
-                                                              x_scales_triton.stride(1), w_scales_triton.stride(0),
-                                                              w_scales_triton.stride(1), BLOCK_M, BLOCK_N, BLOCK_K,
-                                                              mfma_nonkdim, preshuffle, num_warps=8, num_stages=1,
-                                                              **kernel_kwargs)
+    k = _gemm_kernel_preshuffled_scales_cdna4[grid](x, w, triton_out, x_scales_triton, w_scales_triton, M, N, K,
                                                    x.stride(0), x.stride(1), w.stride(0), w.stride(1),
+                                                    triton_out.stride(0), triton_out.stride(1), *x_scales_strides,
+                                                    *w_scales_strides, dtype_to_triton_type[DTYPE_A],
+                                                    dtype_to_triton_type[DTYPE_B], BLOCK_M, BLOCK_N, BLOCK_K,
+                                                    mfma_nonkdim, preshuffle, fast_math=FAST_MATH, num_warps=8,
+                                                    num_stages=1, **kernel_kwargs)
     triton_out = triton_out.to(torch.float32)
-    torch.testing.assert_close(torch_out, triton_out)
+    torch.testing.assert_close(torch_out, triton_out, atol=2e-5, rtol=1e-4)
     if is_hip() and preshuffle:
         assert "tilesPerWarp = [2, 2]" in k.asm["ttgir"]
         assert "ds_read_u8" not in k.asm["amdgcn"]
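With scales now optional, a non-microscaled operand simply forwards `None` scales (and `(None, None)` strides) into the kernel, and only the mx operand is dequantized in the reference path. A hedged host-side sketch of what `run_torch` reduces to for the `("bf16", "mxfp4", True)` parametrization:

```python
import torch

# Only the microscaled operand is expanded by its e8m0 scales before the fp32
# matmul; the bf16 operand participates unscaled. e8m0_to_f32 is the assumed
# helper sketched earlier; w stands in for already-dequantized fp4 values.
SCALE_GROUP_SIZE = 32

def e8m0_to_f32(x):
    return torch.exp2(x.to(torch.float32) - 127)

M, N, K = 8, 8, 64
x = torch.randn(M, K)  # "bf16" operand: x_scales is None
w = torch.randn(N, K)
w_scales = torch.randint(124, 128, (N, K // SCALE_GROUP_SIZE), dtype=torch.uint8)
w_f32 = w * e8m0_to_f32(w_scales).repeat_interleave(SCALE_GROUP_SIZE, dim=1)
ref = torch.mm(x, w_f32.T)
```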
@@ -738,7 +790,7 @@ def test_blocked_scale_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, USE_
         NUM_STAGES = min(NUM_STAGES, 2)
     elif BLOCK_K == 256:
         NUM_STAGES = min(NUM_STAGES, 3)
-    #since the block size are big we use num_warps = 8 to avoid pressure problems.
+    # Since the block sizes are big, we use num_warps = 8 to avoid pressure problems.
     num_warps = 8
     torch.manual_seed(42)
     dtype_src_str = "float8e5"