Skip to content

Conversation

@fairywreath
Copy link
Contributor

Cooperative matrix operands are only supported for add/sub/mul/div binary arithmetic ops, but currently all binary arithmetic ops accept cooperative matrix operands, including mod/rem. This change fixes this behaviour.

@llvmbot
Copy link
Member

llvmbot commented Jul 7, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Darren Wihandi (fairywreath)

Changes

Cooperative matrix operands are only supported for add/sub/mul/div binary arithmetic ops, but currently all binary arithmetic ops accept cooperative matrix operands, including mod/rem. This change fixes this behaviour.


Full diff: https://github.com/llvm/llvm-project/pull/147230.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td (+38-19)
  • (modified) mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir (+43)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 309079e549846..2601debce3520 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -26,6 +26,25 @@ class SPIRV_ArithmeticBinaryOp<string mnemonic, Type type,
                                [Pure, AllTypesMatch<["operand1", "operand2", "result"]>])> {
   // In addition to normal types arithmetic instructions can support cooperative
   // matrix.
+  let arguments = (ins
+    SPIRV_ScalarOrVectorOf<type>:$operand1,
+    SPIRV_ScalarOrVectorOf<type>:$operand2
+  );
+
+  let results = (outs
+    SPIRV_ScalarOrVectorOf<type>:$result
+  );
+  let assemblyFormat = "operands attr-dict `:` type($result)";
+}
+
+class SPIRV_ArithmeticBinaryOpWithCoopMatrix<string mnemonic, Type type,
+                                             list<Trait> traits = []> :
+      // Operands type same as result type.
+      SPIRV_BinaryOp<mnemonic, type, type,
+                   !listconcat(traits,
+                               [Pure, AllTypesMatch<["operand1", "operand2", "result"]>])> {
+  // In addition to normal types these arithmetic instructions can support
+  // cooperative matrix.
   let arguments = (ins
     SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$operand1,
     SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$operand2
@@ -82,7 +101,7 @@ class SPIRV_ArithmeticExtendedBinaryOp<string mnemonic,
 
 // -----
 
-def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOp<"FAdd", SPIRV_Float, [Commutative]> {
+def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FAdd", SPIRV_Float, [Commutative]> {
   let summary = "Floating-point addition of Operand 1 and Operand 2.";
 
   let description = [{
@@ -104,7 +123,7 @@ def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOp<"FAdd", SPIRV_Float, [Commutative]>
 
 // -----
 
-def SPIRV_FDivOp : SPIRV_ArithmeticBinaryOp<"FDiv", SPIRV_Float, []> {
+def SPIRV_FDivOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FDiv", SPIRV_Float, []> {
   let summary = "Floating-point division of Operand 1 divided by Operand 2.";
 
   let description = [{
@@ -154,7 +173,7 @@ def SPIRV_FModOp : SPIRV_ArithmeticBinaryOp<"FMod", SPIRV_Float, []> {
 
 // -----
 
-def SPIRV_FMulOp : SPIRV_ArithmeticBinaryOp<"FMul", SPIRV_Float, [Commutative]> {
+def SPIRV_FMulOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FMul", SPIRV_Float, [Commutative]> {
   let summary = "Floating-point multiplication of Operand 1 and Operand 2.";
 
   let description = [{
@@ -229,7 +248,7 @@ def SPIRV_FRemOp : SPIRV_ArithmeticBinaryOp<"FRem", SPIRV_Float, []> {
 
 // -----
 
-def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOp<"FSub", SPIRV_Float, []> {
+def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FSub", SPIRV_Float, []> {
   let summary = "Floating-point subtraction of Operand 2 from Operand 1.";
 
   let description = [{
@@ -251,9 +270,9 @@ def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOp<"FSub", SPIRV_Float, []> {
 
 // -----
 
-def SPIRV_IAddOp : SPIRV_ArithmeticBinaryOp<"IAdd",
-                                        SPIRV_Integer,
-                                        [Commutative, UsableInSpecConstantOp]> {
+def SPIRV_IAddOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"IAdd",
+                                                          SPIRV_Integer,
+                                                          [Commutative, UsableInSpecConstantOp]> {
   let summary = "Integer addition of Operand 1 and Operand 2.";
 
   let description = [{
@@ -322,9 +341,9 @@ def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry",
 
 // -----
 
-def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOp<"IMul",
-                                        SPIRV_Integer,
-                                        [Commutative, UsableInSpecConstantOp]> {
+def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"IMul",
+                                                          SPIRV_Integer,
+                                                          [Commutative, UsableInSpecConstantOp]> {
   let summary = "Integer multiplication of Operand 1 and Operand 2.";
 
   let description = [{
@@ -354,9 +373,9 @@ def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOp<"IMul",
 
 // -----
 
-def SPIRV_ISubOp : SPIRV_ArithmeticBinaryOp<"ISub",
-                                        SPIRV_Integer,
-                                        [UsableInSpecConstantOp]> {
+def SPIRV_ISubOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"ISub",
+                                                          SPIRV_Integer,
+                                                          [UsableInSpecConstantOp]> {
   let summary = "Integer subtraction of Operand 2 from Operand 1.";
 
   let description = [{
@@ -460,9 +479,9 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",
 
 // -----
 
-def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOp<"SDiv",
-                                        SPIRV_Integer,
-                                        [UsableInSpecConstantOp]> {
+def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"SDiv",
+                                                          SPIRV_Integer,
+                                                          [UsableInSpecConstantOp]> {
   let summary = "Signed-integer division of Operand 1 divided by Operand 2.";
 
   let description = [{
@@ -622,9 +641,9 @@ def SPIRV_SRemOp : SPIRV_ArithmeticBinaryOp<"SRem",
 
 // -----
 
-def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOp<"UDiv",
-                                        SPIRV_Integer,
-                                        [UnsignedOp, UsableInSpecConstantOp]> {
+def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"UDiv",
+                                                          SPIRV_Integer,
+                                                          [UnsignedOp, UsableInSpecConstantOp]> {
   let summary = "Unsigned-integer division of Operand 1 divided by Operand 2.";
 
   let description = [{
diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
index 8733ff93768ab..6aff7b5039638 100644
--- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -549,3 +549,46 @@ spirv.func @matrix_times_scalar(%a: !spirv.coopmatrix<2x2xf32, Workgroup, Matrix
   %p = spirv.MatrixTimesScalar %a, %b : !spirv.coopmatrix<2x2xf32, Workgroup, MatrixA>, f16
   spirv.Return
 }
+
+// -----
+
+// These binary arithmetic instructions do not support coop matrix operands.
+
+spirv.func @fmod(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>) "None" {
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}}
+  %p = spirv.FMod %a, %b : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @frem(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>) "None" {
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}}
+  %p = spirv.FRem %a, %b : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+spirv.func @smod(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" {
+  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
+  %p = spirv.SMod %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @srem(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" {
+  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
+  %p = spirv.SRem %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @umod(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" {
+  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
+  %p = spirv.UMod %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, thanks for addressing this.

@kuhar kuhar merged commit 4a68562 into llvm:main Jul 8, 2025
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants