@@ -475,6 +475,248 @@ def block_scale_mxfp_matmul(  #
    tl.store(output_ptrs, accumulator, mask=c_mask)


+@triton.jit
+def _gemm_afp4_wfp4_kernel_preshuffled_scales_cdna4(a_ptr, b_ptr, c_ptr, a_scales_ptr, b_scales_ptr, M, N, K, stride_am,
+                                                    stride_ak, stride_bk, stride_bn, stride_ck, stride_cm, stride_cn,
+                                                    stride_asm, stride_ask, stride_bsn, stride_bsk,
+                                                    # Meta-parameters
+                                                    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+                                                    mfma_nonkdim: tl.constexpr, preshuffle: tl.constexpr):
+    """Kernel for computing the matmul C = A x B.
+    A and B inputs are in the microscale fp4 (mxfp4) format.
+    A_scales and B_scales are in e8m0 format.
+    A has shape (M, K), B has shape (K, N), and C has shape (M, N).
+    """
+
+    pid = tl.program_id(axis=0)
+
+    num_pid_n = tl.cdiv(N, BLOCK_N)
+    pid_m = pid // num_pid_n
+    pid_n = pid % num_pid_n
+
+    # We assume 32 elements along K share the same scale.
+    SCALE_GROUP_SIZE: tl.constexpr = 32
+
+    if preshuffle:
+        NON_K_PRESHUFFLE_BLOCK_SIZE: tl.constexpr = 32
+    else:
+        NON_K_PRESHUFFLE_BLOCK_SIZE: tl.constexpr = 1
+
+    num_k_iter = tl.cdiv(K, BLOCK_K // 2)
+    # Create pointers for the first block of the A and B input matrices.
+    # The BLOCK sizes are in elements; in fp4 we pack 2 elements per uint8 container.
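+    # For example, with BLOCK_K = 256 fp4 elements, one K step covers BLOCK_K // 2 = 128
+    # uint8 containers and BLOCK_K // SCALE_GROUP_SIZE = 8 scales per row.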
+    offs_k = tl.arange(0, BLOCK_K // 2)
+    offs_k_split = offs_k
+    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
+    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
+    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k_split[None, :] * stride_ak)
+    b_ptrs = b_ptr + (offs_k_split[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+    # Create pointers for the first block of A and B scales.
+    offs_asn = (pid_n *
+                (BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE) + tl.arange(0, (BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE))) % N
+    offs_ks = tl.arange(0, BLOCK_K // SCALE_GROUP_SIZE * NON_K_PRESHUFFLE_BLOCK_SIZE)
+
+    # B scales are N x K even though the B operand is K x N.
+    b_scale_ptrs = (b_scales_ptr + offs_asn[:, None] * stride_bsn + offs_ks[None, :] * stride_bsk)
+    offs_asm = (pid_m *
+                (BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE) + tl.arange(0, (BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE))) % M
+    a_scale_ptrs = (a_scales_ptr + offs_asm[:, None] * stride_asm + offs_ks[None, :] * stride_ask)
+    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+    for k in range(0, num_k_iter):
+        if preshuffle:
+            # Here we "undo" the shuffle done in global memory (shuffle_scales_cdna4 function).
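+            # The reshape/permute below is the inverse of the host-side view/permute in
+            # shuffle_scales_cdna4, restoring the (BLOCK_{M,N}, BLOCK_K // SCALE_GROUP_SIZE)
+            # scale layout consumed by tl.dot_scaled.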
+            if mfma_nonkdim == 32:
+                a_scales = tl.load(a_scale_ptrs).reshape(BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE,
+                                                         BLOCK_K // SCALE_GROUP_SIZE // 8, 2, 32, 4,
+                                                         1).permute(0, 3, 1, 4, 2,
+                                                                    5).reshape(BLOCK_M, BLOCK_K // SCALE_GROUP_SIZE)
+                b_scales = tl.load(b_scale_ptrs).reshape(BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE,
+                                                         BLOCK_K // SCALE_GROUP_SIZE // 8, 2, 32, 4,
+                                                         1).permute(0, 3, 1, 4, 2,
+                                                                    5).reshape(BLOCK_N, BLOCK_K // SCALE_GROUP_SIZE)
+            elif mfma_nonkdim == 16:
+                a_scales = tl.load(a_scale_ptrs).reshape(BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE,
+                                                         BLOCK_K // SCALE_GROUP_SIZE // 8, 4, 16, 2, 2,
+                                                         1).permute(0, 5, 3, 1, 4, 2,
+                                                                    6).reshape(BLOCK_M, BLOCK_K // SCALE_GROUP_SIZE)
+                b_scales = tl.load(b_scale_ptrs).reshape(BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE,
+                                                         BLOCK_K // SCALE_GROUP_SIZE // 8, 4, 16, 2, 2,
+                                                         1).permute(0, 5, 3, 1, 4, 2,
+                                                                    6).reshape(BLOCK_N, BLOCK_K // SCALE_GROUP_SIZE)
+        else:
+            a_scales = tl.load(a_scale_ptrs)
+            b_scales = tl.load(b_scale_ptrs)
+
+        a = tl.load(a_ptrs)
+        b = tl.load(b_ptrs, cache_modifier=None)
+
+        accumulator += tl.dot_scaled(a, a_scales, "e2m1", b, b_scales, "e2m1")
+
+        # Advance the ptrs to the next K block.
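+        # In the preshuffled layout the scales of 32 non-K rows are folded into the K
+        # dimension, so one K block spans (BLOCK_K // SCALE_GROUP_SIZE) * 32 = BLOCK_K
+        # scale columns; without preshuffling it spans only BLOCK_K // SCALE_GROUP_SIZE.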
+        a_ptrs += (BLOCK_K // 2) * stride_ak
+        b_ptrs += (BLOCK_K // 2) * stride_bk
+        if preshuffle:
+            a_scale_ptrs += BLOCK_K * stride_ask
+            b_scale_ptrs += BLOCK_K * stride_bsk
+        else:
+            a_scale_ptrs += (BLOCK_K // SCALE_GROUP_SIZE) * stride_ask
+            b_scale_ptrs += (BLOCK_K // SCALE_GROUP_SIZE) * stride_bsk
+
+    c = accumulator.to(c_ptr.type.element_ty)
+
+    # Write back the block of the output matrix C with masks.
+    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M).to(tl.int64)
+    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N).to(tl.int64)
+    c_ptrs = (c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :])
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+
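+    # ".wt" requests a write-through store, pushing the result through the cache to
+    # global memory.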
+    tl.store(c_ptrs, c, mask=c_mask, cache_modifier=".wt")
+
+
+@pytest.mark.parametrize("M, N, K", [(1024, 1024, 1024)])
+@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 256), (64, 64, 512), (32, 32, 64)])
+@pytest.mark.parametrize("mfma_nonkdim", [16, 32])
+@pytest.mark.parametrize("preshuffle", [True, False])
+@pytest.mark.skipif(is_cuda() and torch.cuda.get_device_capability()[0] == 10, reason="Compilation bug for GB200.")
+@pytest.mark.skipif(is_hip() and not is_hip_cdna4(), reason="Scaled dot is not emulated on other archs yet.")
+def test_preshuffle_scale_mxfp_cdna4(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, mfma_nonkdim, preshuffle, device):
+    # This test primarily evaluates the correctness of the efficient scale packing used for MFMA-scaled instructions.
+    #
+    # Scales are stored as 8-bit tensors, where each element scales 32 values from the A or B operand tensors.
+    # Since MFMA instructions are wave-level instructions, each thread provides a fixed set of operand values to them.
+    #
+    # For example, in an MFMA instruction with shape 16x16x128:
+    # - 4 threads contribute elements along the K dimension.
+    # - 16 threads contribute elements along the M or N dimension.
+    #
+    # From the perspective of the scales tensor, even if the K dimension is stored contiguously in LDS,
+    # each thread sees its elements along the K dim as strided due to interleaving with other threads.
+    # This striding limits the ability to load scale values using vectorized memory access.
+    #
+    # Our goal is to reorganize the scale tensor so that:
+    # 1. Each thread stores the 4 scale values it needs for 4 MFMA ops in contiguous memory.
+    # 2. Consecutive threads access contiguous memory locations, improving global memory coalescing when bypassing LDS,
+    #    which is especially beneficial for "skinny" matmuls.
+    #
+    # We consider two MFMA cases: one with non-K dimension 16, and one with 32.
+    # In both, the minimum tile size for preshuffling is 32x32x256.
+    # For example, for a 32x256 operand tile, the corresponding scale tensor has shape 32x8,
+    # where each scale covers 32 elements along the K dimension.
+    #
+    # Each thread holds one scale per MFMA operation. We pack the 4 scale values (for 4 different MFMA ops)
+    # next to each other in memory.
+    #
+    # Case 1: mfma_scaled_16x16x128
+    #
+    # Packing order: mfma_op_0, mfma_op_2, mfma_op_1, mfma_op_3
+    #
+    #         K = 128          K = 128
+    #       +------------+   +------------+
+    #  M=16 | MFMA op 0  |   | MFMA op 1  |
+    #       +------------+   +------------+
+    #  M=16 | MFMA op 2  |   | MFMA op 3  |
+    #       +------------+   +------------+
+    #
+    # Case 2: mfma_scaled_32x32x64
+    #
+    # Packing order: mfma_op_0, mfma_op_1, mfma_op_2, mfma_op_3
+    #
+    #         K=64       K=64       K=64       K=64
+    #       +--------+ +--------+ +--------+ +--------+
+    #  M=32 |  op 0  | |  op 1  | |  op 2  | |  op 3  |
+    #       +--------+ +--------+ +--------+ +--------+
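+    #
+    # Illustrative sketch (assuming the usual 64-lane wave numbering; this mirrors
+    # shuffle_scales_cdna4 below): for mfma_nonkdim == 32, one 32x8 e8m0 scale tile s is
+    # flattened via
+    #     s.view(1, 32, 1, 4, 2, 1).permute(0, 2, 4, 1, 3, 5).contiguous().view(1, 256)
+    # into a single 256-byte row, so lane l of a wave reads bytes [4 * l, 4 * l + 4),
+    # i.e. its one scale for each of the 4 MFMA ops, with a single vectorized load.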
+
+    if preshuffle and (BLOCK_M < 32 or BLOCK_N < 32 or BLOCK_K < 256):
+        pytest.skip("Minimum tile size for preshuffling is 32x32x256")
+
+    def shuffle_scales_cdna4(scales: torch.Tensor):
+        if not preshuffle:
+            return scales
+
+        scales_shuffled = scales.clone()
+
+        sm, sn = scales_shuffled.shape
+        if mfma_nonkdim == 32:
+            scales_shuffled = scales_shuffled.view(sm // 32, 32, sn // 8, 4, 2, 1)
+            scales_shuffled = scales_shuffled.permute(0, 2, 4, 1, 3, 5).contiguous()
+        elif mfma_nonkdim == 16:
+            scales_shuffled = scales_shuffled.view(sm // 32, 2, 16, sn // 8, 2, 4, 1)
+            scales_shuffled = scales_shuffled.permute(0, 3, 5, 2, 4, 1, 6).contiguous()
+
+        scales_shuffled = scales_shuffled.view(sm // 32, sn * 32)
+        return scales_shuffled
+
+    def e8m0_to_f32(x):
+        # e8m0 is a biased exponent: value = 2**(e - 127); the encoding 0xFF is NaN.
+        x_f32 = 2**((x - 127).to(torch.float32))
+        x_f32[x == 255] = float("nan")
+        return x_f32
+
+    def run_torch(x, w, x_scales, w_scales, dtype):
+        # First convert the x and w inputs to f32.
+        SCALE_GROUP_SIZE = 32
+        x_f32 = x.to(torch.float32)
+        w_f32 = w.to(torch.float32)
+        # Next convert the e8m0 scales to f32.
+        x_scales = x_scales.repeat_interleave(SCALE_GROUP_SIZE, dim=1).to(torch.float32)
+        x_scales_f32 = e8m0_to_f32(x_scales)
+        x_f32 = x_f32 * x_scales_f32
+        w_scales = w_scales.repeat_interleave(SCALE_GROUP_SIZE, dim=1).to(torch.float32)
+        w_scales_f32 = e8m0_to_f32(w_scales)
+        w_f32 = w_f32 * w_scales_f32
+        return torch.mm(x_f32, w_f32.T).to(dtype)
+
+    def generate_gemm_afp4wfp4_inputs(M, N, K):
+        torch.manual_seed(5)
+        SCALE_GROUP_SIZE = 32
+
+        x = MXFP4Tensor(size=(M, K), device="cuda").random()
+        w = MXFP4Tensor(size=(N, K), device="cuda").random()
+
+        x_scales = torch.randint(124, 128, (K // SCALE_GROUP_SIZE, M), dtype=torch.uint8, device="cuda")
+        w_scales = torch.randint(124, 128, (K // SCALE_GROUP_SIZE, N), dtype=torch.uint8, device="cuda")
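+        # The exponents 124..127 decode to power-of-two scales 1/8, 1/4, 1/2 and 1;
+        # scaling by these values is exact in float32.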
+        x_scales = x_scales.T
+        w_scales = w_scales.T
+        x_scales_shuffled = shuffle_scales_cdna4(x_scales)
+        w_scales_shuffled = shuffle_scales_cdna4(w_scales)
+
+        return (
+            x,
+            w,
+            x_scales,
+            w_scales,
+            x_scales_shuffled,
+            w_scales_shuffled,
+        )
+
+    x_mxfp4, w_mxfp4, x_scales, w_scales, x_scales_triton, w_scales_triton = generate_gemm_afp4wfp4_inputs(M, N, K)
+
+    x = x_mxfp4.to_packed_tensor(dim=1)
+    w = w_mxfp4.to_packed_tensor(dim=1)
+
+    torch_out = run_torch(x_mxfp4, w_mxfp4, x_scales, w_scales, torch.float32)
+    M, K = x.shape
+    N, K = w.shape
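+    # x and w are packed to (M, K // 2) and (N, K // 2) uint8 containers, so K is now the
+    # packed container count that the kernel's K argument expects.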
+    w = w.T
+    triton_out = torch.empty((M, N), device=x.device)
+
+    kernel_kwargs = {}
+    if is_hip():
+        kernel_kwargs["matrix_instr_nonkdim"] = mfma_nonkdim
+
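+    # Launch one program per output tile. The literal 0 below is passed for the kernel's
+    # unused stride_ck parameter.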
+    grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
+    _gemm_afp4_wfp4_kernel_preshuffled_scales_cdna4[grid](x, w, triton_out, x_scales_triton, w_scales_triton, M, N, K,
+                                                          x.stride(0), x.stride(1), w.stride(0), w.stride(1), 0,
+                                                          triton_out.stride(0), triton_out.stride(1),
+                                                          x_scales_triton.stride(0), x_scales_triton.stride(1),
+                                                          w_scales_triton.stride(0), w_scales_triton.stride(1), BLOCK_M,
+                                                          BLOCK_N, BLOCK_K, mfma_nonkdim, preshuffle, num_warps=8,
+                                                          num_stages=1, **kernel_kwargs)
+    triton_out = triton_out.to(torch.float32)
+    torch.testing.assert_close(torch_out, triton_out)
+
+
 @pytest.mark.parametrize("M, N, K", [(1024, 512, 512), (998, 111, 512), (63, 128, 512)])
 @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 128), (256, 128, 128), (128, 256, 128),
                                                        (128, 128, 256), (128, 256, 256)])