Skip to content

Conversation

@Muzammiluddin-Syed-ECE
Copy link
Contributor

The current implementation of the WMMA intrinsic ops as they are defined in the ROCDL tablegen is incorrect. They represent as operands what should be attributes such as clamp, opsel, signA/signB. This change performs a refactoring to bring it in line with what we expect.

@llvmbot
Copy link
Member

llvmbot commented Nov 7, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir-llvm

Author: Muzammiluddin Syed (Muzammiluddin-Syed-ECE)

Changes

The current implementation of the WMMA intrinsic ops as they are defined in the ROCDL tablegen is incorrect. They represent as operands what should be attributes such as clamp, opsel, signA/signB. This change performs a refactoring to bring it in line with what we expect.


Patch is 59.09 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/167041.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+137-40)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+50-37)
  • (modified) mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir (+8-8)
  • (modified) mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir (+10-10)
  • (modified) mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir (+11-11)
  • (modified) mlir/test/Target/LLVMIR/rocdl.mlir (+70-72)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 5241f9a6f2b43..3cc906030b1a4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -582,51 +582,148 @@ def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.f
 
 //===---------------------------------------------------------------------===//
 // WMMA intrinsics
-class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands,
-                        list<Trait> traits = []> :
-  ROCDL_IntrOp<mnemonic, [0], overloadedOperands, traits, 1>,
-  Arguments<(ins Variadic<LLVM_Type>:$args)> {
-  let assemblyFormat =
-    "$args attr-dict `:` functional-type($args, $res)";
+class ROCDL_WMMA_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
+    [0], [0], [], 1, 0, 0, 0, [], []>,
+  Arguments<(ins 
+             LLVM_ScalarOrVectorOf<AB>:$A, 
+             LLVM_ScalarOrVectorOf<AB>:$B,
+             LLVM_ScalarOrVectorOf<CD>:$C)> {
+  let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+  let assemblyFormat = [{
+    $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+  }];
+}
+
+class ROCDL_WMMA_Opsel_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
+    [0], [1], [], 1, 0, 0, 0, [3], ["opsel"]>,
+  Arguments<(ins 
+             LLVM_ScalarOrVectorOf<AB>:$A,
+             LLVM_ScalarOrVectorOf<AB>:$B,
+             LLVM_ScalarOrVectorOf<CD>:$C, 
+             DefaultValuedAttr<I1Attr, "0">:$opsel)> {
+  let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+  let assemblyFormat = [{
+    $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+  }];
+}
+
+class ROCDL_WMMA_IU_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
+    [0], [1], [], 1, 0, 0, 0, [0, 2, 5], ["signA", "signB", "clamp"]>,
+  Arguments<(ins 
+             DefaultValuedAttr<I1Attr, "0">:$signA,
+             LLVM_ScalarOrVectorOf<AB>:$A, 
+             DefaultValuedAttr<I1Attr, "0">:$signB,
+             LLVM_ScalarOrVectorOf<AB>:$B,
+             LLVM_ScalarOrVectorOf<CD>:$C, 
+             DefaultValuedAttr<I1Attr, "0">:$clamp)> {
+  let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+  let assemblyFormat = [{
+    $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+  }];
+}
+
+class ROCDL_WMMA_ModsAll_Reuse_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
+    [0], [1], [], 1, 0, 0, 0, [0, 2, 4, 6, 7], ["signA", "signB","modC","reuseA","reuseB"]>,
+  Arguments<(ins 
+             DefaultValuedAttr<I1Attr, "0">:$signA,
+             LLVM_ScalarOrVectorOf<AB>:$A, 
+             DefaultValuedAttr<I1Attr, "0">:$signB,
+             LLVM_ScalarOrVectorOf<AB>:$B,
+             DefaultValuedAttr<I16Attr, "0">:$modC, 
+             LLVM_ScalarOrVectorOf<CD>:$C, 
+             DefaultValuedAttr<I1Attr, "0">:$reuseA, 
+             DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
+  let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+  let assemblyFormat = [{
+    $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+  }];
+}
+
+class ROCDL_WMMA_ModsC_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
+    [0], [0], [], 1, 0, 0, 0, [2, 4, 5], ["modC","reuseA","reuseB"]>,
+  Arguments<(ins 
+             LLVM_ScalarOrVectorOf<AB>:$A, 
+             LLVM_ScalarOrVectorOf<AB>:$B,
+             DefaultValuedAttr<I16Attr, "0">:$modC, 
+             LLVM_ScalarOrVectorOf<CD>:$C, 
+             DefaultValuedAttr<I1Attr, "0">:$reuseA, 
+             DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
+  let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+  let assemblyFormat = [{
+    $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+  }];
+}
+
+class ROCDL_WMMA_ModsAll_Diff_IntrOp<string mnemonic, Type AB, Type C, Type D> : ROCDL_IntrOp<mnemonic,
+    [0], [1, 5], [], 1, 0, 0, 0, [0, 2, 4, 6, 7], ["signA", "signB","modC","reuseA","reuseB"]>,
+  Arguments<(ins 
+             DefaultValuedAttr<I1Attr, "0">:$signA,
+             LLVM_ScalarOrVectorOf<AB>:$A, 
+             DefaultValuedAttr<I1Attr, "0">:$signB,
+             LLVM_ScalarOrVectorOf<AB>:$B,
+             DefaultValuedAttr<I16Attr, "0">:$modC, 
+             LLVM_ScalarOrVectorOf<C>:$C, 
+             DefaultValuedAttr<I1Attr, "0">:$reuseA, 
+             DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
+  let results = (outs LLVM_ScalarOrVectorOf<D>:$res);
+  let assemblyFormat = [{
+    $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+  }];
+}
+
+class ROCDL_WMMA_ModsAB_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
+    [0], [1], [], 1, 0, 0, 0, [0, 2, 5, 6], ["signA", "signB", "reuseA","reuseB"]>,
+  Arguments<(ins 
+             DefaultValuedAttr<I1Attr, "0">:$signA,
+             LLVM_ScalarOrVectorOf<AB>:$A, 
+             DefaultValuedAttr<I1Attr, "0">:$signB,
+             LLVM_ScalarOrVectorOf<AB>:$B,
+             LLVM_ScalarOrVectorOf<CD>:$C, 
+             DefaultValuedAttr<I1Attr, "0">:$reuseA, 
+             DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
+  let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+  let assemblyFormat = [{
+    $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+  }];
 }
 
 // Available from gfx11
-def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.f16", [0]>;
-def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf16", [0]>;
-def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x16.f16", [0]>;
-def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x16.bf16", [0]>;
-def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu8", [1]>;
-def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]>;
+def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.f16", /*Type AB=*/F16, /*Type CD=*/F32>;
+def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.bf16", AnyInteger, F32>;
+def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_WMMA_Opsel_IntrOp<"wmma.f16.16x16x16.f16", F16, F16>;
+def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_WMMA_Opsel_IntrOp<"wmma.bf16.16x16x16.bf16", AnyInteger, AnyInteger>;
+def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_WMMA_IU_IntrOp<"wmma.i32.16x16x16.iu8", AnyInteger, AnyInteger>;
+def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_WMMA_IU_IntrOp<"wmma.i32.16x16x16.iu4", AnyInteger, AnyInteger>;
 // Available from gfx12
-def ROCDL_wmma_f32_16x16x16_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>;
-def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_bf8", [1]>;
-def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
-def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_fp8", [1]>;
-def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x32.iu4", [1]>;
+def ROCDL_wmma_f32_16x16x16_fp8_fp8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.fp8_fp8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.fp8_bf8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.bf8_bf8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.bf8_fp8", AnyInteger, F32>;
+def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_WMMA_IU_IntrOp<"wmma.i32.16x16x32.iu4", AnyInteger, AnyInteger>;
 // Available from gfx1250
-def ROCDL_wmma_f32_16x16x4_f32 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x4.f32", [1]>;
-def ROCDL_wmma_f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.bf16", [1]>;
-def ROCDL_wmma_f32_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.f16", [1]>;
-def ROCDL_wmma_f16_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x32.f16", [1]>;
-def ROCDL_wmma_bf16_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x32.bf16", [1]>;
-def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16f32.16x16x32.bf16", [1,5]>;
-def ROCDL_wmma_f32_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_fp8", [0]>;
-def ROCDL_wmma_f32_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_bf8", [0]>;
-def ROCDL_wmma_f32_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_fp8", [0]>;
-def ROCDL_wmma_f32_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_bf8", [0]>;
-def ROCDL_wmma_f16_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_fp8", [0]>;
-def ROCDL_wmma_f16_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_bf8", [0]>;
-def ROCDL_wmma_f16_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_fp8", [0]>;
-def ROCDL_wmma_f16_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_bf8", [0]>;
-def ROCDL_wmma_f32_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_fp8", [0]>;
-def ROCDL_wmma_f32_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_bf8", [0]>;
-def ROCDL_wmma_f32_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_fp8", [0]>;
-def ROCDL_wmma_f32_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_bf8", [0]>;
-def ROCDL_wmma_f16_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_fp8", [0]>;
-def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_bf8", [0]>;
-def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_fp8", [0]>;
-def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_bf8", [0]>;
-def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x64.iu8", [1]>;
+def ROCDL_wmma_f32_16x16x4_f32 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f32.16x16x4.f32", F32, F32>;
+def ROCDL_wmma_f32_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f32.16x16x32.bf16", BF16, F32>;
+def ROCDL_wmma_f32_16x16x32_f16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f32.16x16x32.f16", F16, F32>;
+def ROCDL_wmma_f16_16x16x32_f16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f16.16x16x32.f16", F16, F16>;
+def ROCDL_wmma_bf16_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.bf16.16x16x32.bf16", BF16, BF16>;
+def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Diff_IntrOp<"wmma.bf16f32.16x16x32.bf16", BF16, /*Type C=*/F32, /*Type D=*/BF16>;
+def ROCDL_wmma_f32_16x16x64_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.fp8_fp8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x64_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.fp8_bf8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x64_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.bf8_fp8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x64_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.bf8_bf8", AnyInteger, F32>;
+def ROCDL_wmma_f16_16x16x64_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.fp8_fp8", AnyInteger, F16>;
+def ROCDL_wmma_f16_16x16x64_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.fp8_bf8", AnyInteger, F16>;
+def ROCDL_wmma_f16_16x16x64_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.bf8_fp8", AnyInteger, F16>;
+def ROCDL_wmma_f16_16x16x64_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.bf8_bf8", AnyInteger, F16>;
+def ROCDL_wmma_f32_16x16x128_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.fp8_fp8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x128_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.fp8_bf8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x128_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.bf8_fp8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x128_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.bf8_bf8", AnyInteger, F32>;
+def ROCDL_wmma_f16_16x16x128_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.fp8_fp8", AnyInteger, F16>;
+def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.fp8_bf8", AnyInteger, F16>;
+def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.bf8_fp8", AnyInteger, F16>;
+def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.bf8_bf8", AnyInteger, F16>;
+def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_WMMA_ModsAB_IntrOp<"wmma.i32.16x16x64.iu8", AnyInteger, AnyInteger>;
 
 //===---------------------------------------------------------------------===//
 // LDS transpose intrinsics (available in GFX950)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 3a307a0756d93..996cf445a15d0 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -79,12 +80,6 @@ static Value createI64Constant(ConversionPatternRewriter &rewriter,
   return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
 }
 
-static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
-                              bool value) {
-  Type llvmI1 = rewriter.getI1Type();
-  return LLVM::ConstantOp::create(rewriter, loc, llvmI1, value);
-}
-
 /// Returns the linear index used to access an element in the memref.
 static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
                                Location loc, MemRefDescriptor &memRefDescriptor,
@@ -684,12 +679,11 @@ static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
 /// intrinsics having been defined before the AMD backend supported bfloat. We
 /// similarly need to pack 8-bit float types into integers as if they were i8
 /// (which they are for the backend's purposes).
-static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
-                                 Location loc,
-                                 const TypeConverter *typeConverter,
-                                 bool isUnsigned, Value llvmInput,
-                                 Value mlirInput,
-                                 SmallVector<Value, 4> &operands) {
+static void wmmaPushInputOperand(
+    ConversionPatternRewriter &rewriter, Location loc,
+    const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput,
+    Value mlirInput, SmallVector<Value, 4> &operands,
+    SmallVector<NamedAttribute, 4> &attrs, StringRef attrName) {
   Type inputType = llvmInput.getType();
   auto vectorType = dyn_cast<VectorType>(inputType);
   if (!vectorType) {
@@ -697,10 +691,6 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
     return;
   }
   Type elemType = vectorType.getElementType();
-
-  if (elemType.isBF16())
-    llvmInput = LLVM::BitcastOp::create(
-        rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
   if (elemType.getIntOrFloatBitWidth() > 8) {
     operands.push_back(llvmInput);
     return;
@@ -719,8 +709,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
     } else if (elemType.isSignedInteger()) {
       localIsUnsigned = false;
     }
-    Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
-    operands.push_back(sign);
+    attrs.push_back(NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
   }
 
   int64_t numBits =
@@ -751,18 +740,17 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                   Location loc,
                                   const TypeConverter *typeConverter,
                                   Value output, int32_t subwordOffset,
-                                  bool clamp, SmallVector<Value, 4> &operands) {
+                                  bool clamp, SmallVector<Value, 4> &operands,
+                                  SmallVector<NamedAttribute, 4> &attrs) {
   Type inputType = output.getType();
   auto vectorType = dyn_cast<VectorType>(inputType);
   Type elemType = vectorType.getElementType();
-  if (elemType.isBF16())
-    output = LLVM::BitcastOp::create(
-        rewriter, loc, vectorType.clone(rewriter.getI16Type()), output);
   operands.push_back(output);
   if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
-    operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
+    attrs.push_back(
+        NamedAttribute("opsel", rewriter.getBoolAttr(subwordOffset)));
   } else if (elemType.isInteger(32)) {
-    operands.push_back(createI1Constant(rewriter, loc, clamp));
+    attrs.push_back(NamedAttribute("clamp", rewriter.getBoolAttr(clamp)));
   }
 }
 
@@ -1302,6 +1290,7 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   LogicalResult
   matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+
     Location loc = op.getLoc();
     auto outType =
         typeConverter->convertType<VectorType>(op.getDestD().getType());
@@ -1311,11 +1300,33 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
     if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
       return op->emitOpError("WMMA only supported on gfx11 and gfx12");
 
-    // The WMMA operations represent vectors of bf16s as vectors of i16s, so we
-    // need to bitcast bfloats to i16 and then bitcast them back.
+    bool isGFX1250 = chipset == Chipset(12, 5, 0);
+
+    // The WMMA operations represent vectors of bf16s as vectors of i16s
+    // (except on gfx1250), so we need to bitcast bfloats to i16 and then
+    // bitcast them back.
+    auto aType = dyn_cast<VectorType>(adaptor.getSourceA().getType());
+    auto bType = dyn_cast<VectorType>(adaptor.getSourceA().getType());
+    auto destCType = dyn_cast<VectorType>(adaptor.getDestC().getType());
+    bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
+    bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
+    bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
+    bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250;
     VectorType rawOutType = outType;
-    if (outType.getElementType().isBF16())
+    if (castOutToI16)
       rawOutType = outType.clone(rewriter.getI16Type());
+    Value a = adaptor.getSourceA();
+    if (castAToI16)
+      a = LLVM::BitcastOp::create(rewriter, loc,
+                                  aType.clone(rewriter.getI16Type()), a);
+    Value b = adaptor.getSourceB();
+    if (castBToI16)
+      b = LLVM::BitcastOp::create(rewriter, loc,
+                                  bType.clone(rewriter.getI16Type()), b);
+    Value destC = adaptor.getDestC();
+    if (castDestCToI16)
+      destC = LLVM::BitcastOp::create(
+          rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);
 
     std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
 
@@ -1325,21 +1336,23 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
     if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
       return op.emitOpError("subwordOffset not supported on gfx12+");
 
-    OperationState loweredOp(loc, *maybeIntrinsic);
-    loweredOp.addTypes(rawOutType);
-
     SmallVector<Value, 4> operands;
-    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
-                         adaptor.getSourceA(), op.getSourceA(), operands);
-    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
-                         adaptor.getSourceB(), op.getSourceB(), operands);
-    wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
-                          op.getSubwordOffset(), op.getClamp(), operands);
+    SmallVector<NamedAttribute, 4> attrs;
+    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), a,
+                         op.getSourceA(), operands, attrs, "signA");
+    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), b,
+                         op.getSourceB(), operands, attrs, "signB");
+    wmmaPushOutputOperand(rewriter, loc, typeConverter, destC,
+                          op.getSubwordOffset(), op.getClamp(), operands,
+                          attrs);
 
+    OperationState loweredOp(loc, *maybeIntrinsic);
+    loweredOp.addTypes(rawOutType);
     loweredOp.addOpe...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Nov 7, 2025

@llvm/pr-subscribers-mlir

Author: Muzammiluddin Syed (Muzammiluddin-Syed-ECE)

Changes

The current implementation of the WMMA intrinsic ops as they are defined in the ROCDL tablegen is incorrect. They represent as operands what should be attributes such as clamp, opsel, signA/signB. This change performs a refactoring to bring it in line with what we expect.


Patch is 59.09 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/167041.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+137-40)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+50-37)
  • (modified) mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir (+8-8)
  • (modified) mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir (+10-10)
  • (modified) mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir (+11-11)
  • (modified) mlir/test/Target/LLVMIR/rocdl.mlir (+70-72)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 5241f9a6f2b43..3cc906030b1a4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -582,51 +582,148 @@ def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.f
 
 //===---------------------------------------------------------------------===//
 // WMMA intrinsics
-class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands,
-                        list<Trait> traits = []> :
-  ROCDL_IntrOp<mnemonic, [0], overloadedOperands, traits, 1>,
-  Arguments<(ins Variadic<LLVM_Type>:$args)> {
-  let assemblyFormat =
-    "$args attr-dict `:` functional-type($args, $res)";
+class ROCDL_WMMA_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
+    [0], [0], [], 1, 0, 0, 0, [], []>,
+  Arguments<(ins 
+             LLVM_ScalarOrVectorOf<AB>:$A, 
+             LLVM_ScalarOrVectorOf<AB>:$B,
+             LLVM_ScalarOrVectorOf<CD>:$C)> {
+  let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+  let assemblyFormat = [{
+    $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+  }];
+}
+
+class ROCDL_WMMA_Opsel_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
+    [0], [1], [], 1, 0, 0, 0, [3], ["opsel"]>,
+  Arguments<(ins 
+             LLVM_ScalarOrVectorOf<AB>:$A,
+             LLVM_ScalarOrVectorOf<AB>:$B,
+             LLVM_ScalarOrVectorOf<CD>:$C, 
+             DefaultValuedAttr<I1Attr, "0">:$opsel)> {
+  let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+  let assemblyFormat = [{
+    $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+  }];
+}
+
+class ROCDL_WMMA_IU_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
+    [0], [1], [], 1, 0, 0, 0, [0, 2, 5], ["signA", "signB", "clamp"]>,
+  Arguments<(ins 
+             DefaultValuedAttr<I1Attr, "0">:$signA,
+             LLVM_ScalarOrVectorOf<AB>:$A, 
+             DefaultValuedAttr<I1Attr, "0">:$signB,
+             LLVM_ScalarOrVectorOf<AB>:$B,
+             LLVM_ScalarOrVectorOf<CD>:$C, 
+             DefaultValuedAttr<I1Attr, "0">:$clamp)> {
+  let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+  let assemblyFormat = [{
+    $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+  }];
+}
+
+class ROCDL_WMMA_ModsAll_Reuse_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
+    [0], [1], [], 1, 0, 0, 0, [0, 2, 4, 6, 7], ["signA", "signB","modC","reuseA","reuseB"]>,
+  Arguments<(ins 
+             DefaultValuedAttr<I1Attr, "0">:$signA,
+             LLVM_ScalarOrVectorOf<AB>:$A, 
+             DefaultValuedAttr<I1Attr, "0">:$signB,
+             LLVM_ScalarOrVectorOf<AB>:$B,
+             DefaultValuedAttr<I16Attr, "0">:$modC, 
+             LLVM_ScalarOrVectorOf<CD>:$C, 
+             DefaultValuedAttr<I1Attr, "0">:$reuseA, 
+             DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
+  let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+  let assemblyFormat = [{
+    $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+  }];
+}
+
+class ROCDL_WMMA_ModsC_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
+    [0], [0], [], 1, 0, 0, 0, [2, 4, 5], ["modC","reuseA","reuseB"]>,
+  Arguments<(ins 
+             LLVM_ScalarOrVectorOf<AB>:$A, 
+             LLVM_ScalarOrVectorOf<AB>:$B,
+             DefaultValuedAttr<I16Attr, "0">:$modC, 
+             LLVM_ScalarOrVectorOf<CD>:$C, 
+             DefaultValuedAttr<I1Attr, "0">:$reuseA, 
+             DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
+  let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+  let assemblyFormat = [{
+    $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+  }];
+}
+
+class ROCDL_WMMA_ModsAll_Diff_IntrOp<string mnemonic, Type AB, Type C, Type D> : ROCDL_IntrOp<mnemonic,
+    [0], [1, 5], [], 1, 0, 0, 0, [0, 2, 4, 6, 7], ["signA", "signB","modC","reuseA","reuseB"]>,
+  Arguments<(ins 
+             DefaultValuedAttr<I1Attr, "0">:$signA,
+             LLVM_ScalarOrVectorOf<AB>:$A, 
+             DefaultValuedAttr<I1Attr, "0">:$signB,
+             LLVM_ScalarOrVectorOf<AB>:$B,
+             DefaultValuedAttr<I16Attr, "0">:$modC, 
+             LLVM_ScalarOrVectorOf<C>:$C, 
+             DefaultValuedAttr<I1Attr, "0">:$reuseA, 
+             DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
+  let results = (outs LLVM_ScalarOrVectorOf<D>:$res);
+  let assemblyFormat = [{
+    $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+  }];
+}
+
+class ROCDL_WMMA_ModsAB_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
+    [0], [1], [], 1, 0, 0, 0, [0, 2, 5, 6], ["signA", "signB", "reuseA","reuseB"]>,
+  Arguments<(ins 
+             DefaultValuedAttr<I1Attr, "0">:$signA,
+             LLVM_ScalarOrVectorOf<AB>:$A, 
+             DefaultValuedAttr<I1Attr, "0">:$signB,
+             LLVM_ScalarOrVectorOf<AB>:$B,
+             LLVM_ScalarOrVectorOf<CD>:$C, 
+             DefaultValuedAttr<I1Attr, "0">:$reuseA, 
+             DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
+  let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+  let assemblyFormat = [{
+    $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+  }];
 }
 
 // Available from gfx11
-def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.f16", [0]>;
-def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf16", [0]>;
-def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x16.f16", [0]>;
-def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x16.bf16", [0]>;
-def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu8", [1]>;
-def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]>;
+def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.f16", /*Type AB=*/F16, /*Type CD=*/F32>;
+def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.bf16", AnyInteger, F32>;
+def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_WMMA_Opsel_IntrOp<"wmma.f16.16x16x16.f16", F16, F16>;
+def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_WMMA_Opsel_IntrOp<"wmma.bf16.16x16x16.bf16", AnyInteger, AnyInteger>;
+def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_WMMA_IU_IntrOp<"wmma.i32.16x16x16.iu8", AnyInteger, AnyInteger>;
+def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_WMMA_IU_IntrOp<"wmma.i32.16x16x16.iu4", AnyInteger, AnyInteger>;
 // Available from gfx12
-def ROCDL_wmma_f32_16x16x16_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>;
-def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_bf8", [1]>;
-def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
-def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_fp8", [1]>;
-def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x32.iu4", [1]>;
+def ROCDL_wmma_f32_16x16x16_fp8_fp8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.fp8_fp8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.fp8_bf8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.bf8_bf8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.bf8_fp8", AnyInteger, F32>;
+def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_WMMA_IU_IntrOp<"wmma.i32.16x16x32.iu4", AnyInteger, AnyInteger>;
 // Available from gfx1250
-def ROCDL_wmma_f32_16x16x4_f32 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x4.f32", [1]>;
-def ROCDL_wmma_f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.bf16", [1]>;
-def ROCDL_wmma_f32_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.f16", [1]>;
-def ROCDL_wmma_f16_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x32.f16", [1]>;
-def ROCDL_wmma_bf16_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x32.bf16", [1]>;
-def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16f32.16x16x32.bf16", [1,5]>;
-def ROCDL_wmma_f32_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_fp8", [0]>;
-def ROCDL_wmma_f32_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_bf8", [0]>;
-def ROCDL_wmma_f32_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_fp8", [0]>;
-def ROCDL_wmma_f32_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_bf8", [0]>;
-def ROCDL_wmma_f16_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_fp8", [0]>;
-def ROCDL_wmma_f16_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_bf8", [0]>;
-def ROCDL_wmma_f16_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_fp8", [0]>;
-def ROCDL_wmma_f16_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_bf8", [0]>;
-def ROCDL_wmma_f32_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_fp8", [0]>;
-def ROCDL_wmma_f32_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_bf8", [0]>;
-def ROCDL_wmma_f32_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_fp8", [0]>;
-def ROCDL_wmma_f32_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_bf8", [0]>;
-def ROCDL_wmma_f16_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_fp8", [0]>;
-def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_bf8", [0]>;
-def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_fp8", [0]>;
-def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_bf8", [0]>;
-def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x64.iu8", [1]>;
+def ROCDL_wmma_f32_16x16x4_f32 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f32.16x16x4.f32", F32, F32>;
+def ROCDL_wmma_f32_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f32.16x16x32.bf16", BF16, F32>;
+def ROCDL_wmma_f32_16x16x32_f16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f32.16x16x32.f16", F16, F32>;
+def ROCDL_wmma_f16_16x16x32_f16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f16.16x16x32.f16", F16, F16>;
+def ROCDL_wmma_bf16_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.bf16.16x16x32.bf16", BF16, BF16>;
+def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Diff_IntrOp<"wmma.bf16f32.16x16x32.bf16", BF16, /*Type C=*/F32, /*Type D=*/BF16>;
+def ROCDL_wmma_f32_16x16x64_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.fp8_fp8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x64_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.fp8_bf8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x64_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.bf8_fp8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x64_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.bf8_bf8", AnyInteger, F32>;
+def ROCDL_wmma_f16_16x16x64_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.fp8_fp8", AnyInteger, F16>;
+def ROCDL_wmma_f16_16x16x64_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.fp8_bf8", AnyInteger, F16>;
+def ROCDL_wmma_f16_16x16x64_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.bf8_fp8", AnyInteger, F16>;
+def ROCDL_wmma_f16_16x16x64_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.bf8_bf8", AnyInteger, F16>;
+def ROCDL_wmma_f32_16x16x128_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.fp8_fp8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x128_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.fp8_bf8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x128_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.bf8_fp8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x128_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.bf8_bf8", AnyInteger, F32>;
+def ROCDL_wmma_f16_16x16x128_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.fp8_fp8", AnyInteger, F16>;
+def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.fp8_bf8", AnyInteger, F16>;
+def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.bf8_fp8", AnyInteger, F16>;
+def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.bf8_bf8", AnyInteger, F16>;
+def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_WMMA_ModsAB_IntrOp<"wmma.i32.16x16x64.iu8", AnyInteger, AnyInteger>;
 
 //===---------------------------------------------------------------------===//
 // LDS transpose intrinsics (available in GFX950)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 3a307a0756d93..996cf445a15d0 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -79,12 +80,6 @@ static Value createI64Constant(ConversionPatternRewriter &rewriter,
   return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
 }
 
-static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
-                              bool value) {
-  Type llvmI1 = rewriter.getI1Type();
-  return LLVM::ConstantOp::create(rewriter, loc, llvmI1, value);
-}
-
 /// Returns the linear index used to access an element in the memref.
 static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
                                Location loc, MemRefDescriptor &memRefDescriptor,
@@ -684,12 +679,11 @@ static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
 /// intrinsics having been defined before the AMD backend supported bfloat. We
 /// similarly need to pack 8-bit float types into integers as if they were i8
 /// (which they are for the backend's purposes).
-static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
-                                 Location loc,
-                                 const TypeConverter *typeConverter,
-                                 bool isUnsigned, Value llvmInput,
-                                 Value mlirInput,
-                                 SmallVector<Value, 4> &operands) {
+static void wmmaPushInputOperand(
+    ConversionPatternRewriter &rewriter, Location loc,
+    const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput,
+    Value mlirInput, SmallVector<Value, 4> &operands,
+    SmallVector<NamedAttribute, 4> &attrs, StringRef attrName) {
   Type inputType = llvmInput.getType();
   auto vectorType = dyn_cast<VectorType>(inputType);
   if (!vectorType) {
@@ -697,10 +691,6 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
     return;
   }
   Type elemType = vectorType.getElementType();
-
-  if (elemType.isBF16())
-    llvmInput = LLVM::BitcastOp::create(
-        rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
   if (elemType.getIntOrFloatBitWidth() > 8) {
     operands.push_back(llvmInput);
     return;
@@ -719,8 +709,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
     } else if (elemType.isSignedInteger()) {
       localIsUnsigned = false;
     }
-    Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
-    operands.push_back(sign);
+    attrs.push_back(NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
   }
 
   int64_t numBits =
@@ -751,18 +740,17 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                   Location loc,
                                   const TypeConverter *typeConverter,
                                   Value output, int32_t subwordOffset,
-                                  bool clamp, SmallVector<Value, 4> &operands) {
+                                  bool clamp, SmallVector<Value, 4> &operands,
+                                  SmallVector<NamedAttribute, 4> &attrs) {
   Type inputType = output.getType();
   auto vectorType = dyn_cast<VectorType>(inputType);
   Type elemType = vectorType.getElementType();
-  if (elemType.isBF16())
-    output = LLVM::BitcastOp::create(
-        rewriter, loc, vectorType.clone(rewriter.getI16Type()), output);
   operands.push_back(output);
   if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
-    operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
+    attrs.push_back(
+        NamedAttribute("opsel", rewriter.getBoolAttr(subwordOffset)));
   } else if (elemType.isInteger(32)) {
-    operands.push_back(createI1Constant(rewriter, loc, clamp));
+    attrs.push_back(NamedAttribute("clamp", rewriter.getBoolAttr(clamp)));
   }
 }
 
@@ -1302,6 +1290,7 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   LogicalResult
   matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+
     Location loc = op.getLoc();
     auto outType =
         typeConverter->convertType<VectorType>(op.getDestD().getType());
@@ -1311,11 +1300,33 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
     if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
       return op->emitOpError("WMMA only supported on gfx11 and gfx12");
 
-    // The WMMA operations represent vectors of bf16s as vectors of i16s, so we
-    // need to bitcast bfloats to i16 and then bitcast them back.
+    bool isGFX1250 = chipset == Chipset(12, 5, 0);
+
+    // The WMMA operations represent vectors of bf16s as vectors of i16s
+    // (except on gfx1250), so we need to bitcast bfloats to i16 and then
+    // bitcast them back.
+    auto aType = dyn_cast<VectorType>(adaptor.getSourceA().getType());
+    auto bType = dyn_cast<VectorType>(adaptor.getSourceA().getType());
+    auto destCType = dyn_cast<VectorType>(adaptor.getDestC().getType());
+    bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
+    bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
+    bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
+    bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250;
     VectorType rawOutType = outType;
-    if (outType.getElementType().isBF16())
+    if (castOutToI16)
       rawOutType = outType.clone(rewriter.getI16Type());
+    Value a = adaptor.getSourceA();
+    if (castAToI16)
+      a = LLVM::BitcastOp::create(rewriter, loc,
+                                  aType.clone(rewriter.getI16Type()), a);
+    Value b = adaptor.getSourceB();
+    if (castBToI16)
+      b = LLVM::BitcastOp::create(rewriter, loc,
+                                  bType.clone(rewriter.getI16Type()), b);
+    Value destC = adaptor.getDestC();
+    if (castDestCToI16)
+      destC = LLVM::BitcastOp::create(
+          rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);
 
     std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
 
@@ -1325,21 +1336,23 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
     if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
       return op.emitOpError("subwordOffset not supported on gfx12+");
 
-    OperationState loweredOp(loc, *maybeIntrinsic);
-    loweredOp.addTypes(rawOutType);
-
     SmallVector<Value, 4> operands;
-    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
-                         adaptor.getSourceA(), op.getSourceA(), operands);
-    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
-                         adaptor.getSourceB(), op.getSourceB(), operands);
-    wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
-                          op.getSubwordOffset(), op.getClamp(), operands);
+    SmallVector<NamedAttribute, 4> attrs;
+    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), a,
+                         op.getSourceA(), operands, attrs, "signA");
+    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), b,
+                         op.getSourceB(), operands, attrs, "signB");
+    wmmaPushOutputOperand(rewriter, loc, typeConverter, destC,
+                          op.getSubwordOffset(), op.getClamp(), operands,
+                          attrs);
 
+    OperationState loweredOp(loc, *maybeIntrinsic);
+    loweredOp.addTypes(rawOutType);
     loweredOp.addOpe...
[truncated]

…perands where possible

Signed-off-by: Muzammiluddin Syed <[email protected]>
@github-actions
Copy link

github-actions bot commented Nov 7, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Signed-off-by: Muzammiluddin Syed <[email protected]>
- Use SmallVectorImpl for output parameters per LLVM coding standards
- Fix type casting bug (getSourceA -> getSourceB on line 1309)
- Use cast<> instead of dyn_cast<> for guaranteed vector types
- Simplify assembly format with functional-type shorthand

Addresses review comments from @kuhar on PR llvm#167041
Update the Python test to match the new WMMA API where operands are
passed as separate positional arguments and attributes (like opsel) are
passed as keyword arguments with boolean values.

Changes:
- Changed from list-style arguments to positional arguments
- Changed opsel from MLIR Value operand to Python bool attribute
- Removed unnecessary false constant creation

This fixes the CI test failure where the Python bindings were expecting
the old API signature.
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM but wait for @krzysz00 to review as well

Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's ensure we have (working, precise) tests for all the things that just got made attributes with their non-default values, otherwise LGTM here

Signed-off-by: Muzammiluddin Syed <[email protected]>
Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM now

@Muzammiluddin-Syed-ECE Muzammiluddin-Syed-ECE merged commit b1262d1 into llvm:main Nov 14, 2025
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants