
Conversation

@kvederni
Contributor

This change adds more MMA intrinsics for F8F6F4 and FP64 types. The implementation is based on PTX ISA version 9.0.
New restrictions were added for dtype/ctype combinations for MMA and sparse MMA intrinsics. MLIR restrictions for dtype/ctype MMA intrinsics were aligned with NVVM IR.

[NVPTX] Added restrictions for dtype/ctype combinations.
[MLIR] Aligned MMA restrictions with NVVM IR.

The MMA instructions are described in PTX ISA 9.0 at https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-mma
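
To make the new variants concrete, here is a sketch of how one of the kind::f8f6f4 intrinsics would be used from LLVM IR. The fragment shapes (4 x i32 for A, 2 x i32 for B, 4 x float for C/D at m16n8k32) come from the WMMA_REGS additions in the diff below; the exact intrinsic suffix is produced by MMA_SIGNATURE, which the truncated diff does not show, so the d.a.b.c ordering here is an assumption rather than a copy of the generated name.

; Illustrative only: the intrinsic name is assumed, not taken from the
; generated records.
declare { float, float, float, float }
  @llvm.nvvm.mma.m16n8k32.row.col.kind.f8f6f4.f32.e4m3.e5m2.f32(
    i32, i32, i32, i32,          ; A fragment: 4 x i32 of packed e4m3 values
    i32, i32,                    ; B fragment: 2 x i32 of packed e5m2 values
    float, float, float, float)  ; C fragment: 4 x f32

define { float, float, float, float } @mma_f8f6f4_example(
    i32 %a0, i32 %a1, i32 %a2, i32 %a3, i32 %b0, i32 %b1,
    float %c0, float %c1, float %c2, float %c3) {
  ; Expected lowering:
  ;   mma.sync.aligned.m16n8k32.row.col.kind::f8f6f4.f32.e4m3.e5m2.f32
  %r = call { float, float, float, float }
      @llvm.nvvm.mma.m16n8k32.row.col.kind.f8f6f4.f32.e4m3.e5m2.f32(
          i32 %a0, i32 %a1, i32 %a2, i32 %a3, i32 %b0, i32 %b1,
          float %c0, float %c1, float %c2, float %c3)
  ret { float, float, float, float } %r
}

Per the predicates in the patch, the kind::f8f6f4 variants require sm_120a and PTX 8.7.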
@llvmbot
Member

llvmbot commented Aug 29, 2025

@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-backend-nvptx
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-llvm-ir

Author: Kirill Vedernikov (kvederni)

Changes

This change adds more MMA intrinsics for F8F6F4 and FP64 types. The implementation is based on PTX ISA version 9.0.
New restrictions were added for dtype/ctype combinations for MMA and sparse MMA intrinsics. MLIR restrictions for dtype/ctype MMA intrinsics were aligned with NVVM IR.


Patch is 22.31 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/156040.diff

5 Files Affected:

  • (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+90-12)
  • (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+21-9)
  • (modified) llvm/test/CodeGen/NVPTX/wmma.py (+104-11)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+3-2)
  • (modified) mlir/test/Target/LLVMIR/nvvmir.mlir (-26)
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 7b40841e45d0d..9015245f99983 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -272,6 +272,10 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
       !eq(gft,"m16n8k16:d:f32") : !listsplat(llvm_float_ty, 4),
       !eq(gft,"m16n8k4:c:f32") : !listsplat(llvm_float_ty, 4),
       !eq(gft,"m16n8k4:d:f32") : !listsplat(llvm_float_ty, 4),
+      !eq(gft,"m16n8k32:c:f16") : !listsplat(llvm_v2f16_ty, 2),
+      !eq(gft,"m16n8k32:c:f32") : !listsplat(llvm_float_ty, 4),
+      !eq(gft,"m16n8k32:d:f16") : !listsplat(llvm_v2f16_ty, 2),
+      !eq(gft,"m16n8k32:d:f32") : !listsplat(llvm_float_ty, 4),
 
       // wmma fp16 -> fp16/fp32 @  m16n16k16/m8n32k16/m32n8k16
       // All other supported geometries use the same fragment format for f32 and
@@ -298,6 +302,21 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
       !eq(gft,"m8n8k4:c:f64") : !listsplat(llvm_double_ty, 2),
       !eq(gft,"m8n8k4:d:f64") : !listsplat(llvm_double_ty, 2),
 
+      !eq(gft,"m16n8k4:a:f64") : !listsplat(llvm_double_ty, 2),
+      !eq(gft,"m16n8k4:b:f64") : [llvm_double_ty],
+      !eq(gft,"m16n8k4:c:f64") : !listsplat(llvm_double_ty, 4),
+      !eq(gft,"m16n8k4:d:f64") : !listsplat(llvm_double_ty, 4),
+
+      !eq(gft,"m16n8k8:a:f64") : !listsplat(llvm_double_ty, 4),
+      !eq(gft,"m16n8k8:b:f64") : !listsplat(llvm_double_ty, 2),
+      !eq(gft,"m16n8k8:c:f64") : !listsplat(llvm_double_ty, 4),
+      !eq(gft,"m16n8k8:d:f64") : !listsplat(llvm_double_ty, 4),
+
+      !eq(gft,"m16n8k16:a:f64") : !listsplat(llvm_double_ty, 8),
+      !eq(gft,"m16n8k16:b:f64") : !listsplat(llvm_double_ty, 4),
+      !eq(gft,"m16n8k16:c:f64") : !listsplat(llvm_double_ty, 4),
+      !eq(gft,"m16n8k16:d:f64") : !listsplat(llvm_double_ty, 4),
+
       // wmma bf16 -> s32 @ m16n16k16/m8n32k16/m32n8k16
       !eq(gft,"m16n16k16:a:bf16") : !listsplat(llvm_i32_ty, 4),
       !eq(gft,"m16n16k16:b:bf16") : !listsplat(llvm_i32_ty, 4),
@@ -378,6 +397,26 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
       !eq(gft,"m16n8k64:c:s32") : !listsplat(llvm_i32_ty, 4),
       !eq(gft,"m16n8k64:d:s32") : !listsplat(llvm_i32_ty, 4),
 
+      // mma e4m3/e5m2 -> f16/f32 @ m16n8k16
+      !eq(gft,"m16n8k16:a:e4m3") : !listsplat(llvm_i32_ty, 2),
+      !eq(gft,"m16n8k16:a:e5m2") : !listsplat(llvm_i32_ty, 2),
+      !eq(gft,"m16n8k16:b:e4m3") : [llvm_i32_ty],
+      !eq(gft,"m16n8k16:b:e5m2") : [llvm_i32_ty],
+      // mma e4m3/e5m2/e3m2/e2m3/e2m1 -> f32 @ m16n8k32
+      !eq(gft,"m16n8k32:a:e4m3") : !listsplat(llvm_i32_ty, 4),
+      !eq(gft,"m16n8k32:a:e5m2") : !listsplat(llvm_i32_ty, 4),
+      !eq(gft,"m16n8k32:a:e3m2") : !listsplat(llvm_i32_ty, 4),
+      !eq(gft,"m16n8k32:a:e2m3") : !listsplat(llvm_i32_ty, 4),
+      !eq(gft,"m16n8k32:a:e2m1") : !listsplat(llvm_i32_ty, 4),
+      !eq(gft,"m16n8k32:b:e4m3") : !listsplat(llvm_i32_ty, 2),
+      !eq(gft,"m16n8k32:b:e5m2") : !listsplat(llvm_i32_ty, 2),
+      !eq(gft,"m16n8k32:b:e3m2") : !listsplat(llvm_i32_ty, 2),
+      !eq(gft,"m16n8k32:b:e2m3") : !listsplat(llvm_i32_ty, 2),
+      !eq(gft,"m16n8k32:b:e2m1") : !listsplat(llvm_i32_ty, 2),
+      // mma e2m1 -> f32 @m16n8k64
+      !eq(gft,"m16n8k64:a:e2m1") : !listsplat(llvm_i32_ty, 4),
+      !eq(gft,"m16n8k64:b:e2m1") : !listsplat(llvm_i32_ty, 2),
+
       // wmma/mma b1 -> s32 @ m8n8k128(b1)
       !eq(gft,"m8n8k128:a:b1") : [llvm_i32_ty],
       !eq(gft,"m8n8k128:b:b1") : [llvm_i32_ty],
@@ -468,7 +507,7 @@ class WMMA_NAME<string ALayout, string BLayout, int Satfinite, string Rnd, strin
                   # !if(Satfinite, "_satfinite", "");
 }
 
-class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op,
+class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op, string Kind,
                WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
   string signature = MMA_SIGNATURE<A, B, C, D>.ret;
   string record = "int_nvvm_mma"
@@ -476,6 +515,7 @@ class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op,
                   # "_" # A.geom
                   # "_" # ALayout
                   # "_" # BLayout
+                  # !if(!ne(Kind, ""), !strconcat("_", !subst("::", "_", Kind)), "")
                   # !if(Satfinite, "_satfinite", "")
                   # signature;
 }
@@ -601,7 +641,7 @@ class NVVM_MMA_OPS {
             ["m16n8k16", "m16n8k8"],
             ["bf16"], [], ["f32"], []>.ret;
   list<list<WMMA_REGS>> f64_mma_ops = MMA_OPS<
-            ["m8n8k4"],
+            ["m8n8k4", "m16n8k4", "m16n8k8", "m16n8k16"],
             ["f64"], [], ["f64"], []>.ret;
   list<list<WMMA_REGS>> fp_mma_ops = MMA_OPS<
             ["m8n8k4", "m16n8k8", "m16n8k16"],
@@ -609,6 +649,18 @@ class NVVM_MMA_OPS {
   list<list<WMMA_REGS>> int_mma_ops = MMA_OPS<
             ["m8n8k16", "m16n8k16", "m16n8k32"],
             ["s8", "u8"], ["s8", "u8"], ["s32"], []>.ret;
+  // m16n8k32 fp8 variants are intersected with f8f6f4 variants
+  // and processed there
+  list<list<WMMA_REGS>> fp8_mma_ops = MMA_OPS<
+            ["m16n8k16"],
+            ["e4m3", "e5m2"], ["e4m3", "e5m2"],
+            ["f16", "f32"], ["f16", "f32"]>.ret;
+  // it also contains e4m3/e5m2 from fp8 variants
+  list<list<WMMA_REGS>> f8f6f4_mma_ops = MMA_OPS<
+            ["m16n8k32"],
+            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+            ["f16", "f32"], ["f16", "f32"]>.ret;
   list<list<WMMA_REGS>> subint_mma_ops = MMA_OPS<
             ["m8n8k32", "m16n8k32", "m16n8k64"],
             ["s4", "u4"], ["s4", "u4"], ["s32"], []>.ret;
@@ -617,7 +669,8 @@ class NVVM_MMA_OPS {
             ["b1"], [], ["s32"], []>.ret;
   list<list<WMMA_REGS>> all_mma_ops = !listconcat(
             tf32_mma_ops, bf16_mma_ops, f64_mma_ops,
-            fp_mma_ops, int_mma_ops, subint_mma_ops, bit_mma_ops);
+            fp_mma_ops, fp8_mma_ops, f8f6f4_mma_ops,
+            int_mma_ops, subint_mma_ops, bit_mma_ops);
 
   list<list<WMMA_REGS>> bf16_mma_sp_ops = MMA_OPS<
             ["m16n8k16", "m16n8k32"],
@@ -770,7 +823,8 @@ class NVVM_MMA_B1OPS<list<WMMA_REGS> frags> {
 // if NVVM_MMA_SUPPORTED<...>.ret then
 //   def : FOO<>; // The record will only be defined for supported ops.
 //
-class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b, int satf> {
+class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b,
+                         string kind, int satf> {
   // MMA ops check both layouts.
   string layout = layout_a # ":" # layout_b;
   string a_type = frags[0].ptx_elt_type;
@@ -805,10 +859,31 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
          !or(!ne(a_type, b_type),
              !ne(c_type, d_type))): false,
 
-    // m16n8k8 requires C and D to be the same type.
-    !and(!eq(geom, "m16n8k8"),
+    // m16n8k16/m16n8k32 requires C and D to be the same type
+    !and(!or(!eq(geom, "m16n8k16"),
+             !eq(geom, "m16n8k32")),
          !ne(c_type, d_type)): false,
 
+    // Limit kind to valid types and geometries
+    !and(!ne(kind, ""),
+         !or(!ne(geom, "m16n8k32"),
+             !and(!ne(a_type, "e4m3"),
+                  !ne(a_type, "e5m2"),
+                  !ne(a_type, "e3m2"),
+                  !ne(a_type, "e2m3"),
+                  !ne(a_type, "e2m1")))): false,
+
+    // Limit m16n8k16/m16n8k32 with no kind to valid types
+    !and(!eq(kind, ""),
+         !or(!eq(geom, "m16n8k16"),
+             !eq(geom, "m16n8k32")),
+             !or(!eq(a_type, "e3m2"),
+                 !eq(a_type, "e2m3"),
+                 !eq(a_type, "e2m1"),
+                 !eq(b_type, "e3m2"),
+                 !eq(b_type, "e2m3"),
+                 !eq(b_type, "e2m1"))): false,
+
     // All other are OK.
     true: true
   );
@@ -882,9 +957,10 @@ class NVVM_MMA_SP_SUPPORTED<list<WMMA_REGS> frags, string metadata,
              !eq(a_type, "tf32")),
          !ne(a_type, b_type)): false,
 
-    // m16n8k16 and m16n8k32 requires C and D to be the same type.
+    // m16n8k16, m16n8k32 and m16n8k64 requires C and D to be the same type.
     !and(!or(!eq(geom, "m16n8k16"),
-             !eq(geom, "m16n8k32")),
+             !eq(geom, "m16n8k32"),
+             !eq(geom, "m16n8k64")),
          !ne(c_type, d_type)): false,
 
     !and(!eq(kind, ""),
@@ -2143,10 +2219,12 @@ foreach layout_a = ["row", "col"] in {
     foreach satf = [0, 1] in {
       foreach op = NVVM_MMA_OPS.all_mma_ops in {
         foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
-          if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
-            def MMA_NAME<layout_a, layout_b, satf, b1op, op[0], op[1], op[2], op[3]>.record
-              : NVVM_MMA<op[0], op[1], op[2], op[3]>;
-          }
+          foreach kind = ["", "kind::f8f6f4"] in {
+            if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, kind, satf>.ret then {
+                def MMA_NAME<layout_a, layout_b, satf, b1op, kind, op[0], op[1], op[2], op[3]>.record
+                : NVVM_MMA<op[0], op[1], op[2], op[3]>;
+            }
+          } // kind
         } // b1op
       } // op
     } // satf
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index c544911bdf1e3..8f58c31d7e1c7 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -4461,6 +4461,10 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "
         !eq(ptx_elt_type, "e2m1"),
         !ne(kind, "")) : [hasSM120a, hasPTX<87>],
 
+    !and(!or(!eq(ptx_elt_type,"e4m3"),
+             !eq(ptx_elt_type,"e5m2")),
+         !eq(geom, "m16n8k16")) : [hasSM<89>, hasPTX<87>],
+
     !or(!eq(ptx_elt_type, "e4m3"),
         !eq(ptx_elt_type, "e5m2")) : [hasSM<89>, hasPTX<84>],
 
@@ -4476,6 +4480,11 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "
     !and(!eq(geom, "m8n8k4"),
          !eq(ptx_elt_type, "f64")) : [hasSM<80>, hasPTX<70>],
 
+    !and(!or(!eq(geom, "m16n8k4"),
+             !eq(geom, "m16n8k8"),
+             !eq(geom, "m16n8k16")),
+         !eq(ptx_elt_type, "f64")) : [hasSM<90>, hasPTX<78>],
+
     // fp16 -> fp16/fp32 @ m8n32k16/m32n8k16
     !and(!or(!eq(geom, "m8n32k16"),
              !eq(geom, "m32n8k16")),
@@ -4760,8 +4769,8 @@ defset list<WMMA_INSTR> WMMAs  = {
 // MMA
 class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
                WMMA_REGINFO FragC, WMMA_REGINFO FragD,
-               string ALayout, string BLayout, int Satfinite, string b1op>
-  : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, b1op, FragA, FragB, FragC, FragD>.record,
+               string ALayout, string BLayout, int Satfinite, string b1op, string Kind>
+  : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, b1op, Kind, FragA, FragB, FragC, FragD>.record,
                         [FragA.Ins, FragB.Ins, FragC.Ins]>,
     // Requires does not seem to have effect on Instruction w/o Patterns.
     // We set it here anyways and propagate to the Pat<> we construct below.
@@ -4776,6 +4785,7 @@ class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
                   # FragA.geom
                   # "." # ALayout
                   # "." # BLayout
+                  # !if(!ne(Kind, ""), "." # Kind, "")
                   # !if(Satfinite, ".satfinite", "")
                   # TypeList
                   # b1op # "\n\t\t"
@@ -4792,13 +4802,15 @@ defset list<WMMA_INSTR> MMAs  = {
       foreach satf = [0, 1] in {
         foreach op = NVVM_MMA_OPS.all_mma_ops in {
           foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
-            if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
-              def : MMA<WMMA_REGINFO<op[0], "mma">,
-                        WMMA_REGINFO<op[1], "mma">,
-                        WMMA_REGINFO<op[2], "mma">,
-                        WMMA_REGINFO<op[3], "mma">,
-                        layout_a, layout_b, satf, b1op>;
-            }
+            foreach kind = ["", "kind::f8f6f4"] in {
+              if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, kind, satf>.ret then {
+                def : MMA<WMMA_REGINFO<op[0], "mma", "", kind>,
+                          WMMA_REGINFO<op[1], "mma", "", kind>,
+                          WMMA_REGINFO<op[2], "mma", "", kind>,
+                          WMMA_REGINFO<op[3], "mma", "", kind>,
+                          layout_a, layout_b, satf, b1op, kind>;
+              }
+            } // kind
           } // b1op
         } // op
       } // satf
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 6d73bce46da7c..1c32856c1ce20 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -90,6 +90,21 @@ def __init__(self, geom, frag, ptx_elt_type, is_mma_sparse=False):
             "m16n8k32:b:s8": 2,
             "m16n8k32:c:s32": 4,
             "m16n8k32:d:s32": 4,
+            # e4m3/e5m2/e3m2/e2m3/e2m1 -> f16/f32 @ m16n8k16/m16n8k32
+            "m16n8k16:a:e4m3": 2,
+            "m16n8k16:a:e5m2": 2,
+            "m16n8k32:a:e4m3": 4,
+            "m16n8k32:a:e5m2": 4,
+            "m16n8k32:a:e3m2": 4,
+            "m16n8k32:a:e2m3": 4,
+            "m16n8k32:a:e2m1": 4,
+            "m16n8k16:b:e4m3": 1,
+            "m16n8k16:b:e5m2": 1,
+            "m16n8k32:b:e4m3": 2,
+            "m16n8k32:b:e5m2": 2,
+            "m16n8k32:b:e3m2": 2,
+            "m16n8k32:b:e2m3": 2,
+            "m16n8k32:b:e2m1": 2,
             # mma sp
             "m16n8k32:a:bf16": 4,
             "m16n8k32:a:f16": 4,
@@ -182,6 +197,18 @@ def __init__(self, geom, frag, ptx_elt_type, is_mma_sparse=False):
             "m8n8k4:b:f64": 1,
             "m8n8k4:c:f64": 2,
             "m8n8k4:d:f64": 2,
+            "m16n8k4:a:f64": 2,
+            "m16n8k4:b:f64": 1,
+            "m16n8k4:c:f64": 4,
+            "m16n8k4:d:f64": 4,
+            "m16n8k8:a:f64": 4,
+            "m16n8k8:b:f64": 2,
+            "m16n8k8:c:f64": 4,
+            "m16n8k8:d:f64": 4,
+            "m16n8k16:a:f64": 8,
+            "m16n8k16:b:f64": 4,
+            "m16n8k16:c:f64": 4,
+            "m16n8k16:d:f64": 4,
             # tf32 -> s32 @ m16n16k8
             "m16n16k8:a:tf32": 4,
             "m16n16k8:b:tf32": 4,
@@ -324,7 +351,9 @@ def get_wmma_ops():
 
 def get_mma_ops():
     return (
-        make_mma_ops(["m8n8k4"], ["f64"], [], ["f64"], [])
+        make_mma_ops(
+            ["m8n8k4", "m16n8k4", "m16n8k8", "m16n8k16"], ["f64"], [], ["f64"], []
+        )
         + make_mma_ops(["m16n8k4", "m16n8k8"], ["tf32"], [], ["f32"], [])
         + make_mma_ops(["m16n8k16", "m16n8k8"], ["bf16"], [], ["f32"], [])
         + make_mma_ops(
@@ -341,6 +370,20 @@ def get_mma_ops():
             ["m8n8k32", "m16n8k32", "m16n8k64"], ["s4", "u4"], ["s4", "u4"], ["s32"], []
         )
         + make_mma_ops(["m8n8k128", "m16n8k128", "m16n8k256"], ["b1"], [], ["s32"], [])
+        + make_mma_ops(
+            ["m16n8k16"],
+            ["e4m3", "e5m2"],
+            ["e4m3", "e5m2"],
+            ["f16", "f32"],
+            ["f16", "f32"],
+        )
+        + make_mma_ops(
+            ["m16n8k32"],
+            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+            ["f16", "f32"],
+            ["f16", "f32"],
+        )
     )
 
 
@@ -492,7 +535,7 @@ def is_wmma_variant_supported(op, layout_a, layout_b, rnd, satf):
     return True
 
 
-def is_mma_variant_supported(op, layout_a, layout_b, satf):
+def is_mma_variant_supported(op, layout_a, layout_b, kind, satf):
     if not (
         is_type_supported(op.a.mma_type.ptx_type) and is_mma_geom_supported(op.a.geom)
     ):
@@ -516,13 +559,49 @@ def is_mma_variant_supported(op, layout_a, layout_b, satf):
     ):
         return False
 
+    if (
+        op.a.geom != "m8n8k4"
+        and op.a.mma_type.ptx_type == "f64"
+        and (ptx_version < 78 or gpu_arch < 90)
+    ):
+        return False
+
     # C and D type must be the same
-    if op.a.geom == "m16n8k16" and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type:
+    if (
+        op.a.geom in ["m16n8k16", "m16n8k32"]
+        and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type
+    ):
+        return False
+
+    if (
+        op.a.geom in ["m16n8k16", "m16n8k32"]
+        and any(x in ["e4m3", "e5m2"] for x in (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type))
+        and ptx_version < 87
+    ):
+        return False
+
+    if kind != "" and (ptx_version < 87 or gpu_arch < 120 or not aa):
+        return False
+
+    if (
+        kind != ""
+        and (
+            op.a.geom != "m16n8k32"
+            or op.a.mma_type.ptx_type not in ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"]
+        )
+    ):
+        return False
+
+    if (kind == ""
+        and op.a.geom in ["m16n8k16", "m16n8k32"]
+        and any(x in ["e3m2", "e2m3", "e2m1"] for x in (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type))
+    ):
         return False
 
     # Require row/col layout for all MMA except m8n8k4 on FP16
     if not (op.a.geom == "m8n8k4" and op.a.mma_type.ptx_type == "f16"):
         return layout_a == "row" and layout_b == "col"
+
     return True
 
 
@@ -937,7 +1016,12 @@ def common_mma_test_gen(params, op, intrinsic_template, instruction_template):
 """
 
     test_params = params
-    test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
+    test_params["intrinsic"] = (
+        Template(intrinsic_template)
+        .substitute(params)
+        .replace("::", ".")
+        .replace("_", ".")
+    )
     test_params["function"] = test_params["intrinsic"].replace(".", "_")
     test_params["instruction"] = Template(instruction_template).substitute(params)
     test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
@@ -1002,16 +1086,24 @@ def gen_wmma_mma_tests():
 
 
 def gen_mma_tests():
-    mma_intrinsic_template = "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${satf}.${intrinsic_signature}"
-    mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${satf}.${ptx_signature}${b1op}"
+    mma_intrinsic_template = (
+        "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${kind}${satf}.${intrinsic_signature}"
+    )
+    mma_instruction_template = (
+        "mma.sync${aligned}.${geom}.${alayout}.${blayout}${kind}${satf}.${ptx_signature}${b1op}"
+    )
 
     generated_items = []
 
-    for op, alayout, blayout, satf in product(
-        get_mma_ops(), ["row", "col"], ["row", "col"], [".satfinite", ""]
+    for op, alayout, blayout, kind, satf in product(
+        get_mma_ops(),
+        ["row", "col"],
+        ["row", "col"],
+        ["", ".kind::f8f6f4"],
+        [".satfinite", ""],
     ):
 
-        if not is_mma_variant_supported(op, alayout, blayout, satf):
+        if not is_mma_variant_supported(op, alayout, blayout, kind, satf):
             continue
 
         for b1op in get_b1_ops(op.a.mma_type.ptx_type):
@@ -1024,6 +1116,7 @@ def gen_mma_tests():
                 "satf": satf,
                 "geom": op.a.geom,
                 "b1op": b1op,
+                "kind": kind,
             }
 
             intrinsic_template = mma_intrinsic_template
@@ -1105,9 +1198,9 @@ def is_mma_sp_variant_supported(op, metadata, kind, satf):
     ):
         return False
 
-    # C and D type must be the same for m16n8k16/m16n8k32
+    # C and D type must be the same for m16n8k16/m16n8k32/m16n8k64
     if (
-        op.a.geom in ["m16n8k16", "m16n8k32"]
+        op.a.geom in ["m16n8k16", "m16n8k32", "m16n8k64"]
         and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type
     ):
         return False
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 9528da05c9fd6..c1da1cf5d0c28 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1763,8 +1763,9 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
          !or(!ne(a_type, b_type),
              !ne(c_type, d_type))): false,
 
-    // m16n8k8 requires C and D to be the same type.
-    !and(!eq(geom, "m16n8k8"),
+    // m16n8k16/m16n8k32 requires C and D to be the same type
+    !and(!or(!eq(geom, "m16n8k16"),
+             !...
[truncated]
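
As a complement to the truncated diff, here is a sketch of one of the new FP64 shapes. The fragment sizes (A: 8 x f64, B: 4 x f64, C/D: 4 x f64 at m16n8k16) are taken from the WMMA_REGS additions above; the intrinsic name assumes the same convention as the existing llvm.nvvm.mma.m8n8k4.row.col.f64 and is illustrative rather than copied from the generated records.

; Illustrative declaration; expected to lower to
;   mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64
; The new f64 geometries are gated on sm_90 and PTX 7.8 per WMMA_REGINFO.
declare { double, double, double, double }
  @llvm.nvvm.mma.m16n8k16.row.col.f64(
    double, double, double, double,
    double, double, double, double,  ; A fragment: 8 x f64
    double, double, double, double,  ; B fragment: 4 x f64
    double, double, double, double)  ; C fragment: 4 x f64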

@github-actions

github-actions bot commented Aug 29, 2025

✅ With the latest revision this PR passed the Python code formatter.

@durga4github
Contributor

The MLIR side of the changes LGTM.

@Artem-B, please help with another round of review here.

Contributor

@durga4github durga4github left a comment

MLIR changes LGTM

Member

@Artem-B Artem-B left a comment

ptxas test changes LGTM.

However, the MLIR tests do not appear to do what they were intended to do -- trigger a verifier error and catch it.

Member

@Artem-B Artem-B left a comment

LGTM

@schwarzschild-radius schwarzschild-radius merged commit bd8a7f9 into llvm:main Oct 6, 2025
9 checks passed
@llvm-ci
Collaborator

llvm-ci commented Oct 6, 2025

LLVM Buildbot has detected a new failure on builder openmp-offload-amdgpu-runtime-2 running on rocm-worker-hw-02 while building llvm,mlir at step 6 "test-openmp".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/10/builds/14838

Here is the relevant piece of the build log for the reference
Step 6 (test-openmp) failure: test (failure)
******************** TEST 'libarcher :: races/lock-nested-unrelated.c' FAILED ********************
Exit Code: 1

Command Output (stdout):
--
# RUN: at line 13
/home/botworker/builds/openmp-offload-amdgpu-runtime-2/llvm.build/./bin/clang -fopenmp  -gdwarf-4 -O1 -fsanitize=thread  -I /home/botworker/builds/openmp-offload-amdgpu-runtime-2/llvm.src/openmp/tools/archer/tests -I /home/botworker/builds/openmp-offload-amdgpu-runtime-2/llvm.build/runtimes/runtimes-bins/openmp/runtime/src -L /home/botworker/builds/openmp-offload-amdgpu-runtime-2/llvm.build/runtimes/runtimes-bins/openmp/runtime/src -Wl,-rpath,/home/botworker/builds/openmp-offload-amdgpu-runtime-2/llvm.build/runtimes/runtimes-bins/openmp/runtime/src   /home/botworker/builds/openmp-offload-amdgpu-runtime-2/llvm.src/openmp/tools/archer/tests/races/lock-nested-unrelated.c -o /home/botworker/builds/openmp-offload-amdgpu-runtime-2/llvm.build/runtimes/runtimes-bins/openmp/tools/archer/tests/races/Output/lock-nested-unrelated.c.tmp -latomic && env TSAN_OPTIONS='ignore_noninstrumented_modules=0:ignore_noninstrumented_modules=1' /home/botworker/builds/openmp-offload-amdgpu-runtime-2/llvm.src/openmp/tools/archer/tests/deflake.bash /home/botworker/builds/openmp-offload-amdgpu-runtime-2/llvm.build/runtimes/runtimes-bins/openmp/tools/archer/tests/races/Output/lock-nested-unrelated.c.tmp 2>&1 | tee /home/botworker/builds/openmp-offload-amdgpu-runtime-2/llvm.build/runtimes/runtimes-bins/openmp/tools/archer/tests/races/Output/lock-nested-unrelated.c.tmp.log | /home/botworker/builds/openmp-offload-amdgpu-runtime-2/llvm.build/./bin/FileCheck /home/botworker/builds/openmp-offload-amdgpu-runtime-2/llvm.src/openmp/tools/archer/tests/races/lock-nested-unrelated.c
# executed command: /home/botworker/builds/openmp-offload-amdgpu-runtime-2/llvm.build/./bin/clang -fopenmp -gdwarf-4 -O1 -fsanitize=thread -I /home/botworker/builds/openmp-offload-amdgpu-runtime-2/llvm.src/openmp/tools/archer/tests -I /home/botworker/builds/openmp-offload-amdgpu-runtime-2/llvm.build/runtimes/runtimes-bins/openmp/runtime/src -L /home/botworker/builds/openmp-offload-amdgpu-runtime-2/llvm.build/runtimes/runtimes-bins/openmp/runtime/src -Wl,-rpath,/home/botworker/builds/openmp-offload-amdgpu-runtime-2/llvm.build/runtimes/runtimes-bins/openmp/runtime/src /home/botworker/builds/openmp-offload-amdgpu-runtime-2/llvm.src/openmp/tools/archer/tests/races/lock-nested-unrelated.c -o /home/botworker/builds/openmp-offload-amdgpu-runtime-2/llvm.build/runtimes/runtimes-bins/openmp/tools/archer/tests/races/Output/lock-nested-unrelated.c.tmp -latomic
# note: command had no output on stdout or stderr
# executed command: env TSAN_OPTIONS=ignore_noninstrumented_modules=0:ignore_noninstrumented_modules=1 /home/botworker/builds/openmp-offload-amdgpu-runtime-2/llvm.src/openmp/tools/archer/tests/deflake.bash /home/botworker/builds/openmp-offload-amdgpu-runtime-2/llvm.build/runtimes/runtimes-bins/openmp/tools/archer/tests/races/Output/lock-nested-unrelated.c.tmp
# note: command had no output on stdout or stderr
# executed command: tee /home/botworker/builds/openmp-offload-amdgpu-runtime-2/llvm.build/runtimes/runtimes-bins/openmp/tools/archer/tests/races/Output/lock-nested-unrelated.c.tmp.log
# note: command had no output on stdout or stderr
# executed command: /home/botworker/builds/openmp-offload-amdgpu-runtime-2/llvm.build/./bin/FileCheck /home/botworker/builds/openmp-offload-amdgpu-runtime-2/llvm.src/openmp/tools/archer/tests/races/lock-nested-unrelated.c
# .---command stderr------------
# | /home/botworker/builds/openmp-offload-amdgpu-runtime-2/llvm.src/openmp/tools/archer/tests/races/lock-nested-unrelated.c:47:11: error: CHECK: expected string not found in input
# | // CHECK: ThreadSanitizer: reported {{[1-7]}} warnings
# |           ^
# | <stdin>:23:5: note: scanning from here
# | DONE
# |     ^
# | <stdin>:24:1: note: possible intended match here
# | ThreadSanitizer: thread T4 finished with ignores enabled, created at:
# | ^
# | 
# | Input file: <stdin>
# | Check file: /home/botworker/builds/openmp-offload-amdgpu-runtime-2/llvm.src/openmp/tools/archer/tests/races/lock-nested-unrelated.c
# | 
# | -dump-input=help explains the following input dump.
# | 
# | Input was:
# | <<<<<<
# |             .
# |             .
# |             .
# |            18:  #0 pthread_create /home/botworker/builds/openmp-offload-amdgpu-runtime-2/llvm.src/compiler-rt/lib/tsan/rtl/tsan_interceptors_posix.cpp:1075:3 (lock-nested-unrelated.c.tmp+0xa3eaa) 
# |            19:  #1 __kmp_create_worker z_Linux_util.cpp (libomp.so+0xcbb42) 
# |            20:  
# |            21: SUMMARY: ThreadSanitizer: data race /home/botworker/builds/openmp-offload-amdgpu-runtime-2/llvm.src/openmp/tools/archer/tests/races/lock-nested-unrelated.c:33:8 in main.omp_outlined_debug__ 
# |            22: ================== 
# |            23: DONE 
# | check:47'0         X error: no match found
# |            24: ThreadSanitizer: thread T4 finished with ignores enabled, created at: 
# | check:47'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# | check:47'1     ?                                                                      possible intended match
# |            25:  #0 pthread_create /home/botworker/builds/openmp-offload-amdgpu-runtime-2/llvm.src/compiler-rt/lib/tsan/rtl/tsan_interceptors_posix.cpp:1075:3 (lock-nested-unrelated.c.tmp+0xa3eaa) 
# | check:47'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# |            26:  #1 __kmp_create_worker z_Linux_util.cpp (libomp.so+0xcbb42) 
# | check:47'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# |            27:  
...

@llvm-ci
Collaborator

llvm-ci commented Oct 6, 2025

LLVM Buildbot has detected a new failure on builder arc-builder running on arc-worker while building llvm,mlir at step 6 "test-build-unified-tree-check-all".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/3/builds/22965

Here is the relevant piece of the build log for the reference
Step 6 (test-build-unified-tree-check-all) failure: test (failure)
******************** TEST 'LLVM :: CodeGen/X86/sse2-intrinsics-fast-isel.ll' FAILED ********************
Exit Code: 1

Command Output (stdout):
--
# RUN: at line 2
/buildbot/worker/arc-folder/build/bin/llc < /buildbot/worker/arc-folder/llvm-project/llvm/test/CodeGen/X86/sse2-intrinsics-fast-isel.ll -show-mc-encoding -fast-isel -mtriple=i386-unknown-unknown -mattr=+sse2 | /buildbot/worker/arc-folder/build/bin/FileCheck /buildbot/worker/arc-folder/llvm-project/llvm/test/CodeGen/X86/sse2-intrinsics-fast-isel.ll --check-prefixes=CHECK,X86,SSE,X86-SSE
# executed command: /buildbot/worker/arc-folder/build/bin/llc -show-mc-encoding -fast-isel -mtriple=i386-unknown-unknown -mattr=+sse2
# .---command stderr------------
# | LLVM ERROR: Cannot select: intrinsic %llvm.x86.sse2.clflush
# | PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace and instructions to reproduce the bug.
# | Stack dump:
# | 0.	Program arguments: /buildbot/worker/arc-folder/build/bin/llc -show-mc-encoding -fast-isel -mtriple=i386-unknown-unknown -mattr=+sse2
# | 1.	Running pass 'Function Pass Manager' on module '<stdin>'.
# | 2.	Running pass 'X86 DAG->DAG Instruction Selection' on function '@test_mm_clflush'
# |  #0 0x0000000002380428 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (/buildbot/worker/arc-folder/build/bin/llc+0x2380428)
# |  #1 0x000000000237d335 SignalHandler(int, siginfo_t*, void*) Signals.cpp:0:0
# |  #2 0x00007f76facc7630 __restore_rt sigaction.c:0:0
# |  #3 0x00007f76f9a173d7 raise (/usr/lib64/libc.so.6+0x363d7)
# |  #4 0x00007f76f9a18ac8 abort (/usr/lib64/libc.so.6+0x37ac8)
# |  #5 0x000000000072474d llvm::json::operator==(llvm::json::Value const&, llvm::json::Value const&) (.cold) JSON.cpp:0:0
# |  #6 0x0000000002108ca9 llvm::SelectionDAGISel::CannotYetSelect(llvm::SDNode*) (/buildbot/worker/arc-folder/build/bin/llc+0x2108ca9)
# |  #7 0x000000000210d83a llvm::SelectionDAGISel::SelectCodeCommon(llvm::SDNode*, unsigned char const*, unsigned int) (/buildbot/worker/arc-folder/build/bin/llc+0x210d83a)
# |  #8 0x0000000000968b37 (anonymous namespace)::X86DAGToDAGISel::Select(llvm::SDNode*) X86ISelDAGToDAG.cpp:0:0
# |  #9 0x00000000021044ef llvm::SelectionDAGISel::DoInstructionSelection() (/buildbot/worker/arc-folder/build/bin/llc+0x21044ef)
# | #10 0x00000000021143f8 llvm::SelectionDAGISel::CodeGenAndEmitDAG() (/buildbot/worker/arc-folder/build/bin/llc+0x21143f8)
# | #11 0x000000000211852e llvm::SelectionDAGISel::SelectAllBasicBlocks(llvm::Function const&) (/buildbot/worker/arc-folder/build/bin/llc+0x211852e)
# | #12 0x0000000002119185 llvm::SelectionDAGISel::runOnMachineFunction(llvm::MachineFunction&) (/buildbot/worker/arc-folder/build/bin/llc+0x2119185)
# | #13 0x0000000002103cff llvm::SelectionDAGISelLegacy::runOnMachineFunction(llvm::MachineFunction&) (/buildbot/worker/arc-folder/build/bin/llc+0x2103cff)
# | #14 0x000000000121a4a7 llvm::MachineFunctionPass::runOnFunction(llvm::Function&) (.part.0) MachineFunctionPass.cpp:0:0
# | #15 0x000000000189589b llvm::FPPassManager::runOnFunction(llvm::Function&) (/buildbot/worker/arc-folder/build/bin/llc+0x189589b)
# | #16 0x0000000001895c41 llvm::FPPassManager::runOnModule(llvm::Module&) (/buildbot/worker/arc-folder/build/bin/llc+0x1895c41)
# | #17 0x0000000001896855 llvm::legacy::PassManagerImpl::run(llvm::Module&) (/buildbot/worker/arc-folder/build/bin/llc+0x1896855)
# | #18 0x0000000000805f22 compileModule(char**, llvm::LLVMContext&) llc.cpp:0:0
# | #19 0x000000000072cfe6 main (/buildbot/worker/arc-folder/build/bin/llc+0x72cfe6)
# | #20 0x00007f76f9a03555 __libc_start_main (/usr/lib64/libc.so.6+0x22555)
# | #21 0x00000000007fc146 _start (/buildbot/worker/arc-folder/build/bin/llc+0x7fc146)
# `-----------------------------
# error: command failed with exit status: -6
# executed command: /buildbot/worker/arc-folder/build/bin/FileCheck /buildbot/worker/arc-folder/llvm-project/llvm/test/CodeGen/X86/sse2-intrinsics-fast-isel.ll --check-prefixes=CHECK,X86,SSE,X86-SSE
# .---command stderr------------
# | /buildbot/worker/arc-folder/llvm-project/llvm/test/CodeGen/X86/sse2-intrinsics-fast-isel.ll:399:14: error: SSE-LABEL: expected string not found in input
# | ; SSE-LABEL: test_mm_bsrli_si128:
# |              ^
# | <stdin>:170:21: note: scanning from here
# | test_mm_bslli_si128: # @test_mm_bslli_si128
# |                     ^
# | <stdin>:178:9: note: possible intended match here
# |  .globl test_mm_bsrli_si128 # 
# |         ^
...

aokblast pushed a commit to aokblast/llvm-project that referenced this pull request Oct 6, 2025
…6040)

This change adds more MMA intrinsics for F8F6F4 and FP64 types. The implementation is based on [PTX ISA version 9.0](https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-mma). New restrictions were added for dtype/ctype combinations for MMA and sparse MMA intrinsics. MLIR restrictions for dtype/ctype MMA intrinsics were aligned with NVVM IR.