Skip to content

Conversation

@Muzammiluddin-Syed-ECE
Copy link
Contributor

@Muzammiluddin-Syed-ECE Muzammiluddin-Syed-ECE commented Oct 7, 2025

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Oct 7, 2025

@llvm/pr-subscribers-mlir-arith
@llvm/pr-subscribers-llvm-ir
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Muzammil (Muzammiluddin-Syed-ECE)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/162343.diff

3 Files Affected:

  • (modified) llvm/include/llvm/IR/IntrinsicsAMDGPU.td (+3-3)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+29-1)
  • (modified) mlir/test/Target/LLVMIR/rocdl.mlir (+41-2)
diff --git a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
index ded00b1274670..04ce1aedfdb4d 100644
--- a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
+++ b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
@@ -4085,11 +4085,11 @@ class AMDGPUWmmaScaleF4IntrinsicModsC<LLVMType scale_ty> :
 
 defset list<Intrinsic> AMDGPUWMMAIntrinsicsGFX1250 = {
 def int_amdgcn_wmma_f32_16x16x4_f32       : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
-def int_amdgcn_wmma_f32_16x16x32_bf16     : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f32_16x16x32_bf16     : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyint_ty, llvm_anyfloat_ty>;
 def int_amdgcn_wmma_f32_16x16x32_f16      : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
 def int_amdgcn_wmma_f16_16x16x32_f16      : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
-def int_amdgcn_wmma_bf16_16x16x32_bf16    : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
-def int_amdgcn_wmma_bf16f32_16x16x32_bf16 : AMDGPUWmmaIntrinsicModsAllDiff<llvm_anyfloat_ty, llvm_anyfloat_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_bf16_16x16x32_bf16    : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyint_ty, llvm_anyint_ty>;
+def int_amdgcn_wmma_bf16f32_16x16x32_bf16 : AMDGPUWmmaIntrinsicModsAllDiff<llvm_anyint_ty, llvm_anyint_ty, llvm_anyfloat_ty>;
 def int_amdgcn_wmma_f32_16x16x64_fp8_fp8  : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
 def int_amdgcn_wmma_f32_16x16x64_fp8_bf8  : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
 def int_amdgcn_wmma_f32_16x16x64_bf8_fp8  : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index db1b7e3af62fd..3814a2dae0f3f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -471,7 +471,7 @@ def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.f
 
 //===---------------------------------------------------------------------===//
 // WMMA intrinsics
-class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands,
+class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands = [],
                         list<Trait> traits = []> :
   ROCDL_IntrOp<mnemonic, [0], overloadedOperands, traits, 1>,
   Arguments<(ins Variadic<LLVM_Type>:$args)> {
@@ -492,6 +492,34 @@ def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_b
 def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
 def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_fp8", [1]>;
 def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x32.iu4", [1]>;
+// Available from gfx1250
+def ROCDL_wmma_f32_16x16x4_f32 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x4.f32", [1]>;
+def ROCDL_wmma_f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.bf16", [1]>;
+def ROCDL_wmma_f32_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.f16", [1]>;
+def ROCDL_wmma_f16_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x32.f16", [1]>;
+def ROCDL_wmma_bf16_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x32.bf16", [1]>;
+def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16f32.16x16x32.bf16", [1,5]>;
+def ROCDL_wmma_f32_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_fp8", [0]>;
+def ROCDL_wmma_f32_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_bf8", [0]>;
+def ROCDL_wmma_f32_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_fp8", [0]>;
+def ROCDL_wmma_f32_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_bf8", [0]>;
+def ROCDL_wmma_f16_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_fp8", [0]>;
+def ROCDL_wmma_f16_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_bf8", [0]>;
+def ROCDL_wmma_f16_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_fp8", [0]>;
+def ROCDL_wmma_f16_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_bf8", [0]>;
+def ROCDL_wmma_f32_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_fp8", [0]>;
+def ROCDL_wmma_f32_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_bf8", [0]>;
+def ROCDL_wmma_f32_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_fp8", [0]>;
+def ROCDL_wmma_f32_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_bf8", [0]>;
+def ROCDL_wmma_f16_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_fp8", [0]>;
+def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_bf8", [0]>;
+def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_fp8", [0]>;
+def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_bf8", [0]>;
+def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x64.iu8", [1]>;
+def ROCDL_wmma_scale_f32_16x16x128_f8f6f4 : ROCDL_Wmma_IntrOp<"wmma.scale.f32.16x16x128.f8f6f4">;
+def ROCDL_wmma_scale16_f32_16x16x128_f8f6f4 : ROCDL_Wmma_IntrOp<"wmma.scale16.f32.16x16x128.f8f6f4">;
+def ROCDL_wmma_scale_f32_32x16x128_f4 : ROCDL_Wmma_IntrOp<"wmma.scale.f32.32x16x128.f4">;
+def ROCDL_wmma_scale16_f32_32x16x128_f4 : ROCDL_Wmma_IntrOp<"wmma.scale16.f32.32x16x128.f4">;
 
 //===---------------------------------------------------------------------===//
 // LDS transpose intrinsics (available in GFX950)
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 1c0c2eba002aa..e0400bf07b563 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -816,9 +816,12 @@ llvm.func @rocdl.mfma.scale.f32.16x16x128.f8f6f4(%arg0 : i32,
 }
 
 llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : vector<16 x i16>, %arg3 : vector<8 x i32>,
-                      %arg4 : vector<2xi32>, %arg5 : vector<4xi32>, %arg6 : vector<4xf32>, %arg7 : vector<8xf16>, %arg8 : vector<8xi16>) -> vector<8xf32> {
+                      %arg4 : vector<2xi32>, %arg5 : vector<4xi32>, %arg6 : vector<4xf32>, %arg7 : vector<8xf16>, %arg8 : vector<8xi16>,
+                      %arg9 : vector<32xf16>, %arg10 : vector<16xf32>, %arg11 : vector<4xf32>, %arg12 : vector<32xf32>,
+                      %arg13 : vector<16xi32>, %arg14 : vector<64xf32>, %arg15 : vector<64xi32>, %arg16 : i32) -> vector<8xf32> {
   %zero = llvm.mlir.constant(false) : i1
-
+  %zero_i16 = llvm.mlir.constant(0 : i16) : i16
+  %zero_i32 = llvm.mlir.constant(0 : i32) : i32
   // ---- Wave32 -----
 
   // f16 -> f32
@@ -849,6 +852,42 @@ llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : v
   // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x32.iu4.v8i32.v2i32(i1 {{.*}}, <2 x i32> %{{.*}}, i1 {{.*}}, <2 x i32> %{{.*}}, <8 x i32> %{{.*}}, i1 {{.*}})
   %r6.gfx12 = rocdl.wmma.i32.16x16x32.iu4 %zero, %arg4, %zero, %arg4, %arg3, %zero : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
 
+  // f32 -> f32
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x4.f32.v4f32.v16f32(i1 false, <16 x float> %10, i1 false, <16 x float> %10, i16 0, <4 x float> %11, i1 false, i1 false)
+  %r1.gfx1250 = rocdl.wmma.f32.16x16x4.f32 %zero, %arg10, %zero, %arg10, %zero_i16, %arg11, %zero, %zero : (i1, vector<16xf32>, i1, vector<16xf32>, i16, vector<4xf32>, i1, i1) -> vector<4xf32>
+
+  // bf16 -> f32
+  // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.bf16.v32f32.v16i16(i1 false, <16 x i16> %2, i1 false, <16 x i16> %2, i16 0, <32 x float> %12, i1 false, i1 false)
+  %r2.gfx1250 = rocdl.wmma.f32.16x16x32.bf16 %zero, %arg2, %zero, %arg2, %zero_i16, %arg12, %zero, %zero : (i1, vector<16xi16>, i1, vector<16xi16>, i16, vector<32xf32>, i1, i1) -> vector<32xf32>
+
+  // f16 -> f32
+  // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v32f32.v16f16(i1 false, <16 x half> %1, i1 false, <16 x half> %1, i16 0, <32 x float> %12, i1 false, i1 false)
+  %r3.gfx1250 = rocdl.wmma.f32.16x16x32.f16 %zero, %arg1, %zero, %arg1, %zero_i16, %arg12, %zero, %zero : (i1, vector<16xf16>, i1, vector<16xf16>, i16, vector<32xf32>, i1, i1) -> vector<32xf32>
+
+  // f16 -> f16
+  // CHECK: call <32 x half> @llvm.amdgcn.wmma.f16.16x16x32.f16.v32f16.v16f16(i1 false, <16 x half> %1, i1 false, <16 x half> %1, i16 0, <32 x half> %9, i1 false, i1 false)
+  %r4.gfx1250 = rocdl.wmma.f16.16x16x32.f16 %zero, %arg1, %zero, %arg1, %zero_i16, %arg9, %zero, %zero : (i1, vector<16xf16>, i1, vector<16xf16>, i16, vector<32xf16>, i1, i1) -> vector<32xf16>
+
+  // bf16 -> bf16
+  // CHECK: call <16 x i32> @llvm.amdgcn.wmma.bf16.16x16x32.bf16.v16i32.v16i16(i1 false, <16 x i16> %2, i1 false, <16 x i16> %2, i16 0, <16 x i32> %13, i1 false, i1 false)
+  %r5.gfx1250 = rocdl.wmma.bf16.16x16x32.bf16 %zero, %arg2, %zero, %arg2, %zero_i16, %arg13, %zero, %zero : (i1, vector<16xi16>, i1, vector<16xi16>, i16, vector<16xi32>, i1, i1) -> vector<16xi32>
+
+  // bf16 -> bf16 / f32
+  // CHECK: call <16 x i32> @llvm.amdgcn.wmma.bf16f32.16x16x32.bf16.v16i32.v16i16.v32f32(i1 false, <16 x i16> %2, i1 false, <16 x i16> %2, i16 0, <32 x float> %12, i1 false, i1 false)
+  %r6.gfx1250 = rocdl.wmma.bf16f32.16x16x32.bf16 %zero, %arg2, %zero, %arg2, %zero_i16, %arg12, %zero, %zero : (i1, vector<16xi16>, i1, vector<16xi16>, i16, vector<32xf32>, i1, i1) -> vector<16xi32>
+
+  // f8 -> f32
+  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.fp8.v64f32.v4i32(<4 x i32> %5, <4 x i32> %5, i16 0, <64 x float> %14, i1 false, i1 false)
+  %r7.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_fp8 %arg5, %arg5, %zero_i16, %arg14, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
+
+  // iu8 -> i32
+  // CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 false, <4 x i32> %5, i1 false, <4 x i32> %5, <64 x i32> %15, i1 false, i1 false)
+  %r8.gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %zero, %arg5, %zero, %arg5, %arg15, %zero, %zero : (i1, vector<4xi32>, i1, vector<4xi32>, vector<64xi32>, i1, i1) -> vector<64xi32>
+
+  %r9.gfx1250 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %zero_i32, %arg5, %zero_i32, %arg5, %zero_i16, %arg11, %zero_i32, %zero_i32, %arg16, %zero_i32, %zero_i32, %arg16, %zero, %zero : (i32, vector<4xi32>, i32, vector<4xi32>, i16, vector<4xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<4xf32>
+  // %r7.gfx1250 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4
+  // %r7.gfx1250 = rocdl.wmma.scale.f32.32x16x128.f4
+  // %r7.gfx1250 = rocdl.wmma.scale16.f32.32x16x128.f4
   // ---- Wave64 -----
 
   // f16 -> f32

@Muzammiluddin-Syed-ECE Muzammiluddin-Syed-ECE marked this pull request as draft October 7, 2025 18:54
@Muzammiluddin-Syed-ECE
Copy link
Contributor Author

Labelled this as draft while I do a sanity check on this line of suspicious code:

def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16f32.16x16x32.bf16", [1,5]>;

@krzysz00
Copy link
Contributor

krzysz00 commented Oct 7, 2025

@Muzammiluddin-Syed-ECE Yes, that code looks correct

That being said, While You're Here (tm) ... would it be possible to actually list out the WMMA signatures instead of just using a variadic set of arguments? (Yes, that'll mean more than one template - you can probably borrow names off of LLVM. It'll also let us encode the AllTypesMatch<["x", "y"]> that the LLVM definitions imply). You might not need to change the syntax, but it'll make it harder to construct invalid IR. It'll also let us use Attributes for immargs like we should.

Signed-off-by: Muzammiluddin Syed <[email protected]>
Signed-off-by: Muzammiluddin Syed <[email protected]>
Signed-off-by: Muzammiluddin Syed <[email protected]>
Copy link
Contributor

@amd-eochoalo amd-eochoalo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Muzammiluddin-Syed-ECE Muzammiluddin-Syed-ECE merged commit 2f70482 into llvm:main Oct 17, 2025
10 checks passed
@llvm-ci
Copy link
Collaborator

llvm-ci commented Oct 17, 2025

LLVM Buildbot has detected a new failure on builder ppc64le-mlir-rhel-clang running on ppc64le-mlir-rhel-test while building mlir at step 6 "test-build-check-mlir-build-only-check-mlir".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/129/builds/31700

Here is the relevant piece of the build log for the reference
Step 6 (test-build-check-mlir-build-only-check-mlir) failure: 1200 seconds without output running [b'ninja', b'check-mlir'], attempting to kill
...
PASS: MLIR :: mlir-runner/utils.mlir (3565 of 3577)
PASS: MLIR :: mlir-tblgen/op-error.td (3566 of 3577)
PASS: MLIR-Unit :: IR/./MLIRIRTests/0/130 (3567 of 3577)
PASS: MLIR-Unit :: IR/./MLIRIRTests/38/130 (3568 of 3577)
PASS: MLIR-Unit :: IR/./MLIRIRTests/39/130 (3569 of 3577)
PASS: MLIR-Unit :: Interfaces/./MLIRInterfacesTests/12/22 (3570 of 3577)
PASS: MLIR-Unit :: Pass/./MLIRPassTests/10/13 (3571 of 3577)
PASS: MLIR-Unit :: Interfaces/./MLIRInterfacesTests/13/22 (3572 of 3577)
PASS: MLIR-Unit :: Interfaces/./MLIRInterfacesTests/11/22 (3573 of 3577)
PASS: MLIR :: mlir-tblgen/llvm-intrinsics.td (3574 of 3577)
command timed out: 1200 seconds without output running [b'ninja', b'check-mlir'], attempting to kill
process killed by signal 9
program finished with exit code -1
elapsedTime=2145.035424

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants