@@ -472,6 +472,248 @@ def block_scale_mxfp_matmul( #
     tl.store(output_ptrs, accumulator, mask=c_mask)


+@triton.jit
+def _gemm_afp4_wfp4_kernel_preshuffled_scales_cdna4(a_ptr, b_ptr, c_ptr, a_scales_ptr, b_scales_ptr, M, N, K, stride_am,
+                                                    stride_ak, stride_bk, stride_bn, stride_ck, stride_cm, stride_cn,
+                                                    stride_asm, stride_ask, stride_bsn, stride_bsk,
+                                                    # Meta-parameters
+                                                    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+                                                    mfma_nonkdim: tl.constexpr, preshuffle: tl.constexpr):
+    """Kernel for computing the matmul C = A x B.
+    A and B inputs are in the microscale fp4 (mxfp4) format.
+    A_scales and B_scales are in the e8m0 format.
+    A has shape (M, K), B has shape (K, N) and C has shape (M, N).
+    """
+
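+    # Each program instance computes one BLOCK_M x BLOCK_N tile of C; the 1-D launch
+    # grid is decomposed row-major into (pid_m, pid_n) below.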
+    pid = tl.program_id(axis=0)
+
+    num_pid_n = tl.cdiv(N, BLOCK_N)
+    pid_m = pid // num_pid_n
+    pid_n = pid % num_pid_n
+
+    # We assume 32 elements along K share the same scale.
+    SCALE_GROUP_SIZE: tl.constexpr = 32
+
+    if preshuffle:
+        NON_K_PRESHUFFLE_BLOCK_SIZE: tl.constexpr = 32
+    else:
+        NON_K_PRESHUFFLE_BLOCK_SIZE: tl.constexpr = 1
+
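+    # With preshuffling, the scales of 32 non-K rows are flattened into the K dimension of the
+    # scale tensor, so each block loads a (BLOCK // 32) x (BLOCK_K // 32 * 32) scale tile;
+    # without preshuffling the tile is simply BLOCK x (BLOCK_K // 32).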
+    num_k_iter = tl.cdiv(K, BLOCK_K // 2)
+    # Create pointers for the first block of the A and B input matrices.
+    # The BLOCK sizes are in elements; fp4 packs 2 elements per uint8 container.
+    offs_k = tl.arange(0, BLOCK_K // 2)
+    offs_k_split = offs_k
+    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
+    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
+    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k_split[None, :] * stride_ak)
+    b_ptrs = b_ptr + (offs_k_split[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+    # Create pointers for the first block of the A and B scales.
+    offs_asn = (pid_n *
+                (BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE) + tl.arange(0, (BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE))) % N
+    offs_ks = tl.arange(0, BLOCK_K // SCALE_GROUP_SIZE * NON_K_PRESHUFFLE_BLOCK_SIZE)
+
+    # B scales are N x K even though the B operand is K x N.
+    b_scale_ptrs = (b_scales_ptr + offs_asn[:, None] * stride_bsn + offs_ks[None, :] * stride_bsk)
+    offs_asm = (pid_m *
+                (BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE) + tl.arange(0, (BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE))) % M
+    a_scale_ptrs = (a_scales_ptr + offs_asm[:, None] * stride_asm + offs_ks[None, :] * stride_ask)
+    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
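+    # Each loop iteration consumes BLOCK_K fp4 values per row, i.e. BLOCK_K // 2 packed uint8
+    # bytes of A/B and BLOCK_K // SCALE_GROUP_SIZE e8m0 scales per row.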
+    for k in range(0, num_k_iter):
+        if preshuffle:
+            # Here we "undo" the shuffle done in global memory (shuffle_scales_cdna4 function).
+            if mfma_nonkdim == 32:
+                a_scales = tl.load(a_scale_ptrs).reshape(BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE,
+                                                         BLOCK_K // SCALE_GROUP_SIZE // 8, 2, 32, 4,
+                                                         1).permute(0, 3, 1, 4, 2,
+                                                                    5).reshape(BLOCK_M, BLOCK_K // SCALE_GROUP_SIZE)
+                b_scales = tl.load(b_scale_ptrs).reshape(BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE,
+                                                         BLOCK_K // SCALE_GROUP_SIZE // 8, 2, 32, 4,
+                                                         1).permute(0, 3, 1, 4, 2,
+                                                                    5).reshape(BLOCK_N, BLOCK_K // SCALE_GROUP_SIZE)
+            elif mfma_nonkdim == 16:
+                a_scales = tl.load(a_scale_ptrs).reshape(BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE,
+                                                         BLOCK_K // SCALE_GROUP_SIZE // 8, 4, 16, 2, 2,
+                                                         1).permute(0, 5, 3, 1, 4, 2,
+                                                                    6).reshape(BLOCK_M, BLOCK_K // SCALE_GROUP_SIZE)
+                b_scales = tl.load(b_scale_ptrs).reshape(BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE,
+                                                         BLOCK_K // SCALE_GROUP_SIZE // 8, 4, 16, 2, 2,
+                                                         1).permute(0, 5, 3, 1, 4, 2,
+                                                                    6).reshape(BLOCK_N, BLOCK_K // SCALE_GROUP_SIZE)
+        else:
+            a_scales = tl.load(a_scale_ptrs)
+            b_scales = tl.load(b_scale_ptrs)
+
+        a = tl.load(a_ptrs)
+        b = tl.load(b_ptrs, cache_modifier=None)
+
+        accumulator += tl.dot_scaled(a, a_scales, "e2m1", b, b_scales, "e2m1")
+
+        # Advance the ptrs to the next K block.
+        a_ptrs += (BLOCK_K // 2) * stride_ak
+        b_ptrs += (BLOCK_K // 2) * stride_bk
+        if preshuffle:
+            a_scale_ptrs += BLOCK_K * stride_ask
+            b_scale_ptrs += BLOCK_K * stride_bsk
+        else:
+            a_scale_ptrs += (BLOCK_K // SCALE_GROUP_SIZE) * stride_ask
+            b_scale_ptrs += (BLOCK_K // SCALE_GROUP_SIZE) * stride_bsk
+
+    c = accumulator.to(c_ptr.type.element_ty)
+
+    # Write back the block of the output matrix C with masks.
+    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M).to(tl.int64)
+    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N).to(tl.int64)
+    c_ptrs = (c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :])
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+
+    tl.store(c_ptrs, c, mask=c_mask, cache_modifier=".wt")
+
+
+@pytest.mark.parametrize("M, N, K", [(1024, 1024, 1024)])
+@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 256), (64, 64, 512), (32, 32, 64)])
+@pytest.mark.parametrize("mfma_nonkdim", [16, 32])
+@pytest.mark.parametrize("preshuffle", [True, False])
+@pytest.mark.skipif(is_cuda() and torch.cuda.get_device_capability()[0] == 10, reason="Compilation bug for GB200.")
+@pytest.mark.skipif(is_hip() and not is_hip_cdna4(), reason="Scaled dot is not emulated on other archs yet.")
+def test_preshuffle_scale_mxfp_cdna4(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, mfma_nonkdim, preshuffle, device):
+    # This test primarily checks correctness of the efficient scale packing used for MFMA-scaled instructions.
+    #
+    # Scales are stored as 8-bit tensors, where each element scales 32 values of the A or B operand tensor.
+    # Since MFMA instructions are wave-level instructions, each thread supplies a fixed set of operand
+    # values to them.
+    #
+    # For example, in an MFMA instruction with shape 16x16x128:
+    # - 4 threads contribute elements along the K dimension.
+    # - 16 threads contribute elements along the M or N dimension.
+    #
+    # From the perspective of the scales tensor, even if the K dimension is stored contiguously in LDS,
+    # each thread sees its elements along K as strided, due to interleaving with other threads.
+    # This striding limits the ability to load scale values with vectorized memory accesses.
+    #
+    # Our goal is to reorganize the scale tensor so that:
+    # 1. Each thread stores the 4 scale values it needs for 4 MFMA ops in contiguous memory.
+    # 2. Consecutive threads access contiguous memory locations, improving global-memory coalescing
+    #    when bypassing LDS, which is especially beneficial for "skinny" matmuls.
+    #
+    # We consider two MFMA cases: one with non-K dimension 16 and one with 32.
+    # In both, the minimum tile size for preshuffling is 32x32x256.
+    # For example, for a 32x256 operand tile, the corresponding scale tensor has shape 32x8,
+    # where each scale covers 32 elements along the K dimension.
+    #
+    # Each thread holds one scale per MFMA operation. We pack the 4 scale values (for 4 different MFMA ops)
+    # next to each other in memory.
+    #
+    # Case 1: mfma_scaled_16x16x128
+    #
+    # Packing order: mfma_op_0, mfma_op_2, mfma_op_1, mfma_op_3
+    #
+    #          K = 128          K = 128
+    #      +------------+   +------------+
+    # M=16 | MFMA op 0  |   | MFMA op 1  |
+    #      +------------+   +------------+
+    # M=16 | MFMA op 2  |   | MFMA op 3  |
+    #      +------------+   +------------+
+    #
+    # Case 2: mfma_scaled_32x32x64
+    #
+    # Packing order: mfma_op_0, mfma_op_1, mfma_op_2, mfma_op_3
+    #
+    #         K=64         K=64         K=64         K=64
+    #      +--------+   +--------+   +--------+   +--------+
+    # M=32 |  op 0  |   |  op 1  |   |  op 2  |   |  op 3  |
+    #      +--------+   +--------+   +--------+   +--------+
+    if preshuffle and (BLOCK_M < 32 or BLOCK_N < 32 or BLOCK_K < 256):
+        pytest.skip("Minimal tile size for preshuffling is 32x32x256")
+
+    def shuffle_scales_cdna4(scales: torch.Tensor):
+        if not preshuffle:
+            return scales
+
+        scales_shuffled = scales.clone()
+
+        sm, sn = scales_shuffled.shape
+        if mfma_nonkdim == 32:
+            scales_shuffled = scales_shuffled.view(sm // 32, 32, sn // 8, 4, 2, 1)
+            scales_shuffled = scales_shuffled.permute(0, 2, 4, 1, 3, 5).contiguous()
+        elif mfma_nonkdim == 16:
+            scales_shuffled = scales_shuffled.view(sm // 32, 2, 16, sn // 8, 2, 4, 1)
+            scales_shuffled = scales_shuffled.permute(0, 3, 5, 2, 4, 1, 6).contiguous()
+
+        scales_shuffled = scales_shuffled.view(sm // 32, sn * 32)
+        return scales_shuffled
+
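+    # Note: the reshape/permute sequences in the kernel above are the exact inverses of these
+    # host-side views, i.e. permute(0, 3, 1, 4, 2, 5) for mfma_nonkdim == 32 and
+    # permute(0, 5, 3, 1, 4, 2, 6) for mfma_nonkdim == 16 recover the original (row, K-group) order.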
+    def e8m0_to_f32(x):
+        x_f32 = 2**((x - 127).to(torch.float32))
+        x_f32[x_f32 == 128] = float("nan")
+        return x_f32
+
+    def run_torch(x, w, x_scales, w_scales, dtype):
+        # First convert the x and w inputs to f32.
+        SCALE_GROUP_SIZE = 32
+        x_f32 = x.to(torch.float32)
+        w_f32 = w.to(torch.float32)
+        # Next convert the e8m0 scales to f32.
+        x_scales = x_scales.repeat_interleave(SCALE_GROUP_SIZE, dim=1).to(torch.float32)
+        x_scales_f32 = e8m0_to_f32(x_scales)
+        x_f32 = x_f32 * x_scales_f32
+        w_scales = w_scales.repeat_interleave(SCALE_GROUP_SIZE, dim=1).to(torch.float32)
+        w_scales_f32 = e8m0_to_f32(w_scales)
+        w_f32 = w_f32 * w_scales_f32
+        return torch.mm(x_f32, w_f32.T).to(dtype)
+
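+    # Reference path: dequantize both operands to fp32 by repeating each e8m0 scale over its
+    # 32-element group, then use a plain torch.mm (w is generated as (N, K), hence the transpose).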
+    def generate_gemm_afp4wfp4_inputs(M, N, K):
+        torch.manual_seed(5)
+        SCALE_GROUP_SIZE = 32
+
+        x = MXFP4Tensor(size=(M, K), device="cuda").random()
+        w = MXFP4Tensor(size=(N, K), device="cuda").random()
+
+        x_scales = torch.randint(124, 128, (K // SCALE_GROUP_SIZE, M), dtype=torch.uint8, device="cuda")
+        w_scales = torch.randint(124, 128, (K // SCALE_GROUP_SIZE, N), dtype=torch.uint8, device="cuda")
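+        # The scale exponents drawn from [124, 128) decode (via e8m0_to_f32) to 2**-3 .. 2**0,
+        # i.e. 0.125 to 1.0.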
+        x_scales = x_scales.T
+        w_scales = w_scales.T
+        x_scales_shuffled = shuffle_scales_cdna4(x_scales)
+        w_scales_shuffled = shuffle_scales_cdna4(w_scales)
+
+        return (
+            x,
+            w,
+            x_scales,
+            w_scales,
+            x_scales_shuffled,
+            w_scales_shuffled,
+        )
+
+    x_mxfp4, w_mxfp4, x_scales, w_scales, x_scales_triton, w_scales_triton = generate_gemm_afp4wfp4_inputs(M, N, K)
+
+    x = x_mxfp4.to_packed_tensor(dim=1)
+    w = w_mxfp4.to_packed_tensor(dim=1)
+
+    torch_out = run_torch(x_mxfp4, w_mxfp4, x_scales, w_scales, torch.float32)
+    M, K = x.shape
+    N, K = w.shape
+    w = w.T
+    triton_out = torch.empty((M, N), device=x.device)
+
+    kernel_kwargs = {}
+    if is_hip():
+        kernel_kwargs["matrix_instr_nonkdim"] = mfma_nonkdim
+
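+    # Note that M, K and N are re-derived from the packed tensors above, so K here is the packed
+    # (uint8) K dimension. stride_ck is part of the kernel signature but unused, so 0 is passed
+    # for it in the launch below.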
+    grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
+    _gemm_afp4_wfp4_kernel_preshuffled_scales_cdna4[grid](x, w, triton_out, x_scales_triton, w_scales_triton, M, N, K,
+                                                          x.stride(0), x.stride(1), w.stride(0), w.stride(1), 0,
+                                                          triton_out.stride(0), triton_out.stride(1),
+                                                          x_scales_triton.stride(0), x_scales_triton.stride(1),
+                                                          w_scales_triton.stride(0), w_scales_triton.stride(1), BLOCK_M,
+                                                          BLOCK_N, BLOCK_K, mfma_nonkdim, preshuffle, num_warps=8,
+                                                          num_stages=1, **kernel_kwargs)
+    triton_out = triton_out.to(torch.float32)
+    torch.testing.assert_close(torch_out, triton_out)
+
+
 @pytest.mark.parametrize("M, N, K", [(1024, 512, 512), (998, 111, 512), (63, 128, 512)])
 @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 128), (256, 128, 128), (128, 256, 128),
                                                         (128, 128, 256), (128, 256, 256)])