Conversation

@kvederni
Contributor

This change adds dense and sparse MMA intrinsics with block scaling. The implementation is based on PTX ISA version 9.0. Tests for the new intrinsics are added for PTX 8.7 and sm_120a and are generated by llvm/test/CodeGen/NVPTX/wmma-ptx87-sm120a.py. The tests have been verified with ptxas from the CUDA 13.0 release.
Initial support for the dense MMA intrinsics with block scaling was contributed by @schwarzschild-radius.
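For reference while reading the diff below, here is a minimal Python sketch of how the new dense block-scale intrinsic record names compose, mirroring the MMA_BLOCK_SCALE_NAME class added in this patch. The signature suffix in the example is an assumption for illustration; the real one is produced by MMA_SIGNATURE.

# Sketch: compose a dense block-scale MMA intrinsic record name, mirroring
# MMA_BLOCK_SCALE_NAME from IntrinsicsNVVM.td. Layouts are fixed to row/col.
def mma_block_scale_name(geom, kind, scale_vec_size, signature, stype):
    return ("int_nvvm_mma_block_scale"
            + "_" + geom
            + "_row_col"
            + "_" + kind
            + scale_vec_size.replace(".", "_")  # ".scale_2x" -> "_scale_2x"
            + signature                         # from MMA_SIGNATURE (illustrative here)
            + "_" + stype)

# Example (the "_f32_e2m1" signature suffix is an assumption for illustration):
print(mma_block_scale_name("m16n8k64", "mxf4", ".scale_2x", "_f32_e2m1", "ue8m0"))
# -> int_nvvm_mma_block_scale_m16n8k64_row_col_mxf4_scale_2x_f32_e2m1_ue8m0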

@llvmbot
Member

llvmbot commented Oct 15, 2025

@llvm/pr-subscribers-backend-nvptx

@llvm/pr-subscribers-llvm-ir

Author: Kirill Vedernikov (kvederni)

Changes

This change adds dense and sparse MMA intrinsics with block scaling. The implementation is based on PTX ISA version 9.0. Tests for the new intrinsics are added for PTX 8.7 and sm_120a and are generated by llvm/test/CodeGen/NVPTX/wmma-ptx87-sm120a.py. The tests have been verified with ptxas from the CUDA 13.0 release.
Initial support for the dense MMA intrinsics with block scaling was contributed by @schwarzschild-radius.


Patch is 29.23 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/163561.diff

3 Files Affected:

  • (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+189-1)
  • (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+129-1)
  • (modified) llvm/test/CodeGen/NVPTX/wmma.py (+340-2)
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 3af1750ffcf3f..6256baa50a1c6 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -277,6 +277,10 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
       !eq(gft,"m16n8k32:d:f16") : !listsplat(llvm_v2f16_ty, 2),
       !eq(gft,"m16n8k32:d:f32") : !listsplat(llvm_float_ty, 4),
 
+      // mma.block_scale e2m1 (mxf4, mxf4nvf4) -> f32 @ m16n8k64
+      !eq(gft,"m16n8k64:c:f32") : !listsplat(llvm_float_ty, 4),
+      !eq(gft,"m16n8k64:d:f32") : !listsplat(llvm_float_ty, 4),
+
       // wmma fp16 -> fp16/fp32 @  m16n16k16/m8n32k16/m32n8k16
       // All other supported geometries use the same fragment format for f32 and
       // f16, so we only need to consider {fragment, type}.
@@ -520,6 +524,18 @@ class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op, strin
                   # signature;
 }
 
+class MMA_BLOCK_SCALE_NAME<string Kind, string SType, string ScaleVecSize,
+                           WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
+  string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+  string record = "int_nvvm_mma_block_scale"
+                  # "_" # A.geom
+                  # "_row_col"
+                  # "_" # Kind
+                  # !subst(".", "_", ScaleVecSize)
+                  # signature
+                  # "_" # SType;
+}
+
 class MMA_SP_NAME<string Metadata, string Kind, int Satfinite,
                   WMMA_REGS A, WMMA_REGS B,
                   WMMA_REGS C, WMMA_REGS D> {
@@ -533,6 +549,19 @@ class MMA_SP_NAME<string Metadata, string Kind, int Satfinite,
                   # signature;
 }
 
+class MMA_SP_BLOCK_SCALE_NAME<string Kind, string SType, string ScaleVecSize,
+                              WMMA_REGS A, WMMA_REGS B,
+                              WMMA_REGS C, WMMA_REGS D> {
+  string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+  string record = "int_nvvm_mma_sp_ordered_metadata_block_scale"
+                  # "_" # A.geom
+                  # "_row_col"
+                  # "_" # Kind
+                  # !subst(".", "_", ScaleVecSize)
+                  # signature
+                  # "_" # SType;
+}
+
 class LDMATRIX_NAME<WMMA_REGS Frag, int Trans> {
   string intr = "llvm.nvvm.ldmatrix.sync.aligned"
                 # "." # Frag.geom
@@ -672,6 +701,18 @@ class NVVM_MMA_OPS {
             fp_mma_ops, fp8_mma_ops, f8f6f4_mma_ops,
             int_mma_ops, subint_mma_ops, bit_mma_ops);
 
+  list<list<WMMA_REGS>> mxf4_mma_ops = MMA_OPS<
+    ["m16n8k64"], ["e2m1"], ["e2m1"], ["f32"], ["f32"]
+  >.ret;
+
+  list<list<WMMA_REGS>> mxf8f6f4_mma_ops = MMA_OPS<
+    ["m16n8k32"], ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"],
+                  ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"], ["f32"], ["f32"]
+  >.ret;
+
+  list<list<WMMA_REGS>> all_mma_block_scale_ops = !listconcat(
+            mxf4_mma_ops, mxf8f6f4_mma_ops);
+
   list<list<WMMA_REGS>> bf16_mma_sp_ops = MMA_OPS<
             ["m16n8k16", "m16n8k32"],
             ["bf16"], [], ["f32"], [], true>.ret;
@@ -696,6 +737,18 @@ class NVVM_MMA_OPS {
             bf16_mma_sp_ops, tf32_mma_sp_ops, fp_mma_sp_ops, fp8_mma_sp_ops,
             subint_mma_sp_ops, int_mma_sp_ops);
 
+  // Combines the available geoms and types for the mxf4 and mxf4nvf4 kinds.
+  list<list<WMMA_REGS>> mxf4xx_mma_sp_ops = MMA_OPS<
+            ["m16n8k128"],
+            ["e2m1"], ["e2m1"], ["f32"], [], true>.ret;
+  list<list<WMMA_REGS>> mxf8f6f4_mma_sp_ops = MMA_OPS<
+            ["m16n8k64"],
+            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+            ["f32"], [], true>.ret;
+  list<list<WMMA_REGS>> all_mma_sp_block_scale_ops = !listconcat(
+            mxf4xx_mma_sp_ops, mxf8f6f4_mma_sp_ops);
+
   list<WMMA_REGS> ldst_ab_ops = MMA_LDST_OPS<
             ["m16n16k16", "m32n8k16", "m8n32k16"],
             ["a", "b"], ["f16", "u8", "s8", "bf16"]>.ret;
@@ -889,6 +942,32 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
   );
 }
 
+class NVVM_MMA_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind, string stype, string scale_vec_size> {
+  string geom = frags[0].geom;
+
+  bit ret = !cond(
+    !and(!eq(geom, "m16n8k64"),
+         !eq(kind, "mxf4"),
+         !or(!eq(scale_vec_size, ""),
+             !eq(scale_vec_size, ".scale_2x")),
+         !eq(stype, "ue8m0")) : true,
+    !and(!eq(geom, "m16n8k64"),
+         !eq(kind, "mxf4nvf4"),
+         !eq(scale_vec_size, ".scale_2x"),
+         !eq(stype, "ue8m0")) : true,
+    !and(!eq(geom, "m16n8k64"),
+         !eq(kind, "mxf4nvf4"),
+         !eq(scale_vec_size, ".scale_4x"),
+         !eq(stype, "ue4m3")) : true,
+    !and(!eq(geom, "m16n8k32"),
+         !eq(kind, "mxf8f6f4"),
+         !or(!eq(scale_vec_size, ""),
+             !eq(scale_vec_size, ".scale_1x")),
+         !eq(stype, "ue8m0")) : true,
+    true: false
+  );
+}
+
 // Returns true if the fragment is valid for ldmatrix ops is supported;
 // false otherwise.
 // E.g.
@@ -987,6 +1066,51 @@ class NVVM_MMA_SP_SUPPORTED<list<WMMA_REGS> frags, string metadata,
 }
 
 
+// Returns true if this combination of kind/scale_vec_size/stype
+// for MMA.SP ops is supported;
+// false otherwise.
+// E.g.
+// if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<...>.ret then
+//   def : FOO<>; // The record will only be defined for supported ops.
+//
+class NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind,
+                                        string stype, string scale_vec_size> {
+  // MMA.SP ops check both layouts.
+  string a_type = frags[0].ptx_elt_type;
+  string b_type = frags[1].ptx_elt_type;
+  string c_type = frags[2].ptx_elt_type;
+  string d_type = frags[3].ptx_elt_type;
+  string geom = frags[0].geom;
+
+  bit ret = !cond(
+    !and(!eq(geom, "m16n8k128"),
+         !eq(kind, "mxf4"),
+         !eq(stype, "ue8m0"),
+         !or(!eq(scale_vec_size, ""),
+             !eq(scale_vec_size, ".scale_2x"))): true,
+
+    !and(!eq(geom, "m16n8k128"),
+         !eq(kind, "mxf4nvf4"),
+         !eq(stype, "ue8m0"),
+         !eq(scale_vec_size, ".scale_2x")): true,
+
+    !and(!eq(geom, "m16n8k128"),
+         !eq(kind, "mxf4nvf4"),
+         !eq(stype, "ue4m3"),
+         !eq(scale_vec_size, ".scale_4x")): true,
+
+    !and(!eq(geom, "m16n8k64"),
+         !eq(kind, "mxf8f6f4"),
+         !eq(stype, "ue8m0"),
+         !or(!eq(scale_vec_size, ""),
+             !eq(scale_vec_size, ".scale_1x"))): true,
+
+    // All other are NOT OK.
+    true: false
+  );
+}
+
+
 class SHFL_INFO<bit sync, string mode, string type, bit return_pred> {
   string Suffix = !if(sync, "sync_", "")
                   # mode # "_"
@@ -2340,6 +2464,31 @@ foreach layout_a = ["row", "col"] in {
   } // layout_b
 } // layout_a
 
+class NVVM_MMA_BLOCK_SCALE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
+  : Intrinsic<D.regs,
+              !listconcat(A.regs, B.regs, C.regs,
+              [
+                llvm_i32_ty,              // scale-a-data
+                llvm_i16_ty, llvm_i16_ty, // byte-id-a, thread-id-a
+                llvm_i32_ty,              // scale-b-data,
+                llvm_i16_ty, llvm_i16_ty, // byte-id-b, thread-id-b
+              ]),
+              [IntrNoMem, IntrNoCallback]>;
+
+foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
+  foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
+    foreach stype = ["ue8m0", "ue4m3"] in {
+      foreach op = NVVM_MMA_OPS.all_mma_block_scale_ops in {
+        if NVVM_MMA_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
+          def MMA_BLOCK_SCALE_NAME<kind, stype, scale_vec_size,
+                                   op[0], op[1], op[2], op[3]>.record
+            : NVVM_MMA_BLOCK_SCALE<op[0], op[1], op[2], op[3]>;
+        }
+      } // op
+    } // stype
+  } // scale_vec_size
+} // kind
+
 // MMA.SP
 class NVVM_MMA_SP<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
   : Intrinsic<D.regs,
@@ -2387,6 +2536,45 @@ foreach metadata = ["sp", "sp::ordered_metadata"] in {
   } // kind
 } // metadata
 
+// MMA.SP BLOCK SCALE
+class NVVM_MMA_SP_BLOCK_SCALE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
+  : Intrinsic<D.regs,
+              !listconcat(A.regs, B.regs, C.regs,
+              [
+                llvm_i32_ty, // metadata
+                llvm_i32_ty, // sparsity selector
+                llvm_i32_ty, // scale-a-data
+                llvm_i16_ty, llvm_i16_ty, // byte-id-a, thread-id-a
+                llvm_i32_ty, // scale-b-data
+                llvm_i16_ty, llvm_i16_ty, // byte-id-b, thread-id-b
+              ])> {
+    int pos = !size(!listconcat(A.regs, B.regs, C.regs, [llvm_i32_ty]));
+
+    // The range [0, num_threads) is for the sparsity selector, which indicates
+    // the threads that contribute metadata.
+    // According to PTX ISA 9.0, the sparsity selector is always 0
+    // for sparse MMA block scale instructions.
+    int num_threads = 1;
+    let IntrProperties = [IntrNoMem, IntrNoCallback, ImmArg<ArgIndex<pos>>,
+                          Range<ArgIndex<pos>, 0, num_threads>];
+}
+
+// According to PTX ISA 9.0
+// a_layout = ["row"], b_layout = ["col"], spvariant = ["sp::ordered_metadata"]
+foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
+  foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
+    foreach stype = ["ue8m0", "ue4m3"] in {
+      foreach op = NVVM_MMA_OPS.all_mma_sp_block_scale_ops in {
+        if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
+          def MMA_SP_BLOCK_SCALE_NAME<kind, stype, scale_vec_size,
+                                      op[0], op[1], op[2], op[3]>.record
+            : NVVM_MMA_SP_BLOCK_SCALE<op[0], op[1], op[2], op[3]>;
+        }
+      } // op
+    } // stype
+  } // scale_vec_size
+} // kind
+
 // LDMATRIX
 class NVVM_LDMATRIX<WMMA_REGS Frag, int Transposed>
   : Intrinsic<Frag.regs, [llvm_anyptr_ty],
@@ -2984,4 +3172,4 @@ foreach sp = [0, 1] in {
   }
 }
 
-} // let TargetPrefix = "nvvm"
\ No newline at end of file
+} // let TargetPrefix = "nvvm"
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 22cf3a7eef2c1..febb9b41c84f3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -4886,6 +4886,67 @@ defset list<WMMA_INSTR> MMAs  = {
 } // defset
 }
 
+// MMA.block_scale
+class MMA_BLOCK_SCALE<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
+                      WMMA_REGINFO FragC, WMMA_REGINFO FragD,
+                      string Kind, string SType, string ScaleVecSize>
+  : WMMA_INSTR<MMA_BLOCK_SCALE_NAME<Kind, SType, ScaleVecSize,
+                                    FragA, FragB, FragC, FragD>.record,
+                                    [FragA.Ins, FragB.Ins, FragC.Ins,
+                                     (ins B32:$scale_a, B16:$byte_id_a,
+                                          B16:$thread_id_a, B32:$scale_b,
+                                          B16:$byte_id_b, B16:$thread_id_b)]>,
+    // Requires does not seem to have effect on Instruction w/o Patterns.
+    // We set it here anyways and propagate to the Pat<> we construct below.
+  Requires<FragA.Predicates> {
+  let OutOperandList = FragD.Outs;
+  let InOperandList  = !con(Args, (ins MmaCode:$ptx));
+  string TypeList = !interleave([FragD.ptx_elt_type,
+                                 FragA.ptx_elt_type,
+                                 FragB.ptx_elt_type,
+                                 FragC.ptx_elt_type], ".");
+  string ScaleVecSizeStr = !cond(
+    !eq(ScaleVecSize, "") : "",
+    !eq(ScaleVecSize, ".scale_1x") : ".scale_vec::1X",
+    !eq(ScaleVecSize, ".scale_2x") : ".scale_vec::2X",
+    !eq(ScaleVecSize, ".scale_4x") : ".scale_vec::4X"
+  );
+  let AsmString = "mma.sync.aligned."
+                  # FragA.geom
+                  # ".row.col"
+                  # ".kind::" # Kind
+                  # ".block_scale"
+                  # ScaleVecSizeStr
+                  # "." # TypeList
+                  # "." # SType # " \n\t\t"
+                  # FragD.regstring # ",\n\t\t"
+                  # FragA.regstring # ",\n\t\t"
+                  # FragB.regstring # ",\n\t\t"
+                  # FragC.regstring # ",\n\t\t"
+                  # "$scale_a, {{$byte_id_a, $thread_id_a}}" # ",\n\t\t"
+                  # "$scale_b, {{$byte_id_b, $thread_id_b}};";
+}
+
+let isConvergent = true in {
+defset list<WMMA_INSTR> MMA_BLOCK_SCALEs  = {
+  foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
+    foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
+      foreach stype = ["ue8m0", "ue4m3"] in {
+        foreach op = NVVM_MMA_OPS.all_mma_block_scale_ops in {
+          if NVVM_MMA_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
+            def : MMA_BLOCK_SCALE<WMMA_REGINFO<op[0], "mma.block_scale", "", kind>,
+                                  WMMA_REGINFO<op[1], "mma.block_scale", "", kind>,
+                                  WMMA_REGINFO<op[2], "mma.block_scale", "", kind>,
+                                  WMMA_REGINFO<op[3], "mma.block_scale", "", kind>,
+                                  kind, stype, scale_vec_size>;
+          }
+        } // op
+      } // stype
+    } // scale_vec_size
+  } // kind
+} // defset
+}
+
 // MMA SP
 class MMA_SP<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
              WMMA_REGINFO FragC, WMMA_REGINFO FragD,
@@ -4942,6 +5003,72 @@ defset list<WMMA_INSTR> MMA_SPs = {
 } // defset
 }
 
+// MMA SP BLOCK SCALE
+class MMA_SP_BLOCK_SCALE<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
+                         WMMA_REGINFO FragC, WMMA_REGINFO FragD,
+                         string Kind, string SType, string ScaleVecSize>
+  : WMMA_INSTR<MMA_SP_BLOCK_SCALE_NAME<Kind, SType, ScaleVecSize,
+                                       FragA, FragB, FragC, FragD>.record,
+               [FragA.Ins, FragB.Ins, FragC.Ins,
+                (ins B32:$metadata, i32imm:$selector,
+                     B32:$scale_a, B16:$byte_id_a, B16:$thread_id_a,
+                     B32:$scale_b, B16:$byte_id_b, B16:$thread_id_b)]>,
+    // Requires does not seem to have effect on Instruction w/o Patterns.
+    // We set it here anyways and propagate to the Pat<> we construct below.
+    Requires<!listconcat(FragA.Predicates,
+                         FragB.Predicates,
+                         FragC.Predicates,
+                         FragD.Predicates)> {
+  let OutOperandList = FragD.Outs;
+  let InOperandList = !con(Args, (ins MmaCode:$ptx));
+  string TypeList = "." # FragD.ptx_elt_type
+                    # "." # FragA.ptx_elt_type
+                    # "." # FragB.ptx_elt_type
+                    # "." # FragC.ptx_elt_type;
+  string ScaleVecSizeStr = !cond(
+    !eq(ScaleVecSize, "") : "",
+    !eq(ScaleVecSize, ".scale_1x") : ".scale_vec::1X",
+    !eq(ScaleVecSize, ".scale_2x") : ".scale_vec::2X",
+    !eq(ScaleVecSize, ".scale_4x") : ".scale_vec::4X"
+  );
+  let AsmString = "mma.sp::ordered_metadata.sync.aligned."
+                  # FragA.geom
+                  # ".row.col"
+                  # ".kind::" # Kind
+                  # ".block_scale"
+                  # ScaleVecSizeStr
+                  # TypeList
+                  # "." # SType # "\n\t\t"
+                  # FragD.regstring # ",\n\t\t"
+                  # FragA.regstring # ",\n\t\t"
+                  # FragB.regstring # ",\n\t\t"
+                  # FragC.regstring # ",\n\t\t"
+                  # "$metadata" # ",\n\t\t"
+                  # "$selector" # ",\n\t\t"
+                  # "$scale_a, {{$byte_id_a, $thread_id_a}}" # ",\n\t\t"
+                  # "$scale_b, {{$byte_id_b, $thread_id_b}};";
+}
+
+let isConvergent = true in {
+defset list<WMMA_INSTR> MMA_SP_BLOCK_SCALEs = {
+  foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
+    foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
+      foreach stype = ["ue8m0", "ue4m3"] in {
+      foreach op = NVVM_MMA_OPS.all_mma_sp_block_scale_ops in {
+        if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
+          def : MMA_SP_BLOCK_SCALE<WMMA_REGINFO<op[0], "mma.sp", "sp::ordered_metadata", kind>,
+                       WMMA_REGINFO<op[1], "mma.sp", "sp::ordered_metadata", kind>,
+                       WMMA_REGINFO<op[2], "mma.sp", "sp::ordered_metadata", kind>,
+                       WMMA_REGINFO<op[3], "mma.sp", "sp::ordered_metadata", kind>,
+                       kind, stype, scale_vec_size>;
+        }
+      } // op
+      } // stype
+    } // scale_vec_size
+  } // kind
+} // defset
+}
+
 //
 // ldmatrix.sync.aligned.m8n8[|.trans][|.shared].b16
 //
@@ -5023,7 +5150,8 @@ class MMA_PAT<WMMA_INSTR wi>
         Requires<wi.Predicates>;
 
 // Build intrinsic->instruction patterns for all MMA instructions.
-foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs, MMA_SPs) in
+foreach mma = !listconcat(MMAs, MMA_BLOCK_SCALEs, WMMAs, MMA_LDSTs, LDMATRIXs,
+                          STMATRIXs, MMA_SPs, MMA_SP_BLOCK_SCALEs) in
   def : MMA_PAT<mma>;
 
 multiclass MAPA<string suffix, Intrinsic Intr> {
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 8427ae4ad72da..81c78219075f3 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -131,7 +131,7 @@ def __init__(self, geom, frag, ptx_elt_type, is_mma_sparse=False):
             "m16n8k64:b:e5m2": 4,
             "m16n8k64:b:e3m2": 4,
             "m16n8k64:b:e2m3": 4,
-            "m16n8k64:b:e2m1": 4,
+            "m16n8k64:b:e2m1": 4 if is_mma_sparse else 2,
             "m16n8k64:c:f16": 2,
             "m16n8k64:c:f32": 4,
             "m16n8k64:d:f16": 2,
@@ -1131,6 +1131,163 @@ def gen_mma_tests():
     return generated_items
 
 
+def get_mma_block_scale_ops():
+    return (
+        make_mma_ops(["m16n8k64"], ["e2m1"], [], ["f32"], [])
+        + make_mma_ops(
+            ["m16n8k32"],
+            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+            ["f32"],
+            [],
+        )
+    )
+
+
+def is_mma_block_scale_geom_supported(geom):
+    # geometries for FP.
+    if geom in [
+        "m16n8k32",
+        "m16n8k64",
+    ]:
+        return True
+    raise ValueError(f"Unexpected MMA block scale geometry: {geom}")
+
+
+def is_mma_block_scale_variant_supported(op, kind, scale_vec_size, stype):
+    if not (
+        is_type_supported(op.a.mma_type.ptx_type)
+        and is_mma_block_scale_geom_supported(op.a.geom)
+    ):
+        return False
+
+    if (
+        op.a.geom == "m16n8k64"
+        and kind == "mxf4"
+        and stype == "ue8m0"
+        and scale_vec_size in ["", ".scale_vec::2X"]
+    ):
+        return True
+
+    if (
+        op.a.geom == "m16n8k64"
+        and kind == "mxf4nvf4"
+        and stype == "ue8m0"
+        and scale_vec_size == ".scale_vec::2X"
+    ):
+        return True
+
+    if (
+        op.a.geom == "m16n8k64"
+        and kind == "mxf4nvf4"
+        and stype == "ue4m3"
+        and scale_vec_size == ".scale_vec::4X"
+    ):
+        return True
+
+    if (
+        op.a.geom == "m16n8k32"
+        and kind == "mxf8f6f4"
+        and stype == "ue8m0"
+        and scale_vec_size in ["", ".scale_vec::1X"]
+    ):
+        return True
+
+    return False
+
+
+def common_mma_block_scale_test_gen(params, op, intrinsic_template, instruction_template):
+    mma_block_scale_template = """
+declare ${ret_ty} @${intrinsic}(
+        ${args});
+
+; CHECK-LABEL: .func {{.*}}test_${function}(
+define ${ret_ty} @test_${function}(
+        ${args}) {
+; CHECK: ${instruction}
+; CHECK-NEXT: ${check_d}
+; CHECK-NEXT: ${check_a}
+; CHECK-NEXT: ${check_b}
+; CHECK-NEXT: ${check_c}
+; CHECK-NEXT: ${check_scale_a_data}
+; CHECK-NEXT: ${check_byte_id_a}
+; CHECK-NEXT: ${check_thread_id_a}
+; CHECK-NEXT: ${check_scale_b_data}
+; CHECK-NEXT: ${check_byte_id_b}
+; CHECK-NEXT: ${check_thread_id_b}
+  %r = call ${ret_ty} @${intrinsic}(
+        ${args});
+  ret ${ret_ty} %r;
+}
+"""
+
+    test_params = params
+    test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
+    test_params["function"] = test_params["intrinsic"].replace(".", "_")
+    test_...
[truncated]
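As an aid to reviewing the gating logic above, here is a small Python sketch enumerating the kind/stype/scale_vec_size combinations accepted by NVVM_MMA_BLOCK_SCALE_SUPPORTED (dense) and NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED (sparse). The geometries and gating rules are transcribed from the diff; the sketch itself is a review aid, not part of the patch.

# Valid (kind, stype, scale_vec_size) combinations, as encoded by
# NVVM_MMA_BLOCK_SCALE_SUPPORTED (dense) and NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED
# (sparse) in this patch. An empty scale_vec_size means the qualifier is omitted.
DENSE_GEOM = {"mxf4": "m16n8k64", "mxf4nvf4": "m16n8k64", "mxf8f6f4": "m16n8k32"}
SPARSE_GEOM = {"mxf4": "m16n8k128", "mxf4nvf4": "m16n8k128", "mxf8f6f4": "m16n8k64"}

def supported(kind, stype, scale_vec_size):
    if kind == "mxf4":
        return stype == "ue8m0" and scale_vec_size in ("", ".scale_2x")
    if kind == "mxf4nvf4":
        return ((stype == "ue8m0" and scale_vec_size == ".scale_2x") or
                (stype == "ue4m3" and scale_vec_size == ".scale_4x"))
    if kind == "mxf8f6f4":
        return stype == "ue8m0" and scale_vec_size in ("", ".scale_1x")
    return False

for kind in ("mxf4", "mxf4nvf4", "mxf8f6f4"):
    for stype in ("ue8m0", "ue4m3"):
        for svs in ("", ".scale_1x", ".scale_2x", ".scale_4x"):
            if supported(kind, stype, svs):
                print(f"{kind:9} {stype} {svs or '(none)':10} "
                      f"dense={DENSE_GEOM[kind]} sparse={SPARSE_GEOM[kind]}")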

@github-actions

github-actions bot commented Oct 15, 2025

✅ With the latest revision this PR passed the Python code formatter.

Member

Artem-B commented Oct 23, 2025

Please also make sure that the ptxas tests still work with the ptxas from cuda-12.8.
Other than that, the change LGTM.
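If you want to rerun that check by hand, a hypothetical Python helper along these lines could assemble a generated PTX file with a specific ptxas, roughly mirroring what the lit %ptxas-verify substitution does. The ptxas path, the default architecture, and the file name in the usage line are assumptions; adjust them for your install.

import subprocess

# Hypothetical helper: verify that a PTX file assembles with a given ptxas
# binary. The output is discarded; a non-zero exit status (raised as an
# exception by check=True) means the PTX did not assemble with this toolkit.
def verify_ptx(ptx_path, ptxas="/usr/local/cuda-12.8/bin/ptxas", arch="sm_120a"):
    subprocess.run([ptxas, f"-arch={arch}", "-c", "-o", "/dev/null", ptx_path],
                   check=True)

verify_ptx("wmma-ptx87-sm120a.ptx")  # hypothetical file name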
