From b06e1e90761273a3f5ea0f690ee3b819a677367a Mon Sep 17 00:00:00 2001 From: fairywreath Date: Sun, 6 Jul 2025 22:11:58 -0400 Subject: [PATCH 1/2] [mlir][spirv] Reject cooperative matrix operands on unsupported arithmetic ops --- .../Dialect/SPIRV/IR/SPIRVArithmeticOps.td | 57 ++++++++++++------- .../SPIRV/IR/khr-cooperative-matrix-ops.mlir | 43 ++++++++++++++ 2 files changed, 81 insertions(+), 19 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td index 309079e549846..2601debce3520 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td @@ -26,6 +26,25 @@ class SPIRV_ArithmeticBinaryOp])> { // In addition to normal types arithmetic instructions can support cooperative // matrix. + let arguments = (ins + SPIRV_ScalarOrVectorOf:$operand1, + SPIRV_ScalarOrVectorOf:$operand2 + ); + + let results = (outs + SPIRV_ScalarOrVectorOf:$result + ); + let assemblyFormat = "operands attr-dict `:` type($result)"; +} + +class SPIRV_ArithmeticBinaryOpWithCoopMatrix traits = []> : + // Operands type same as result type. + SPIRV_BinaryOp])> { + // In addition to normal types these arithmetic instructions can support + // cooperative matrix. let arguments = (ins SPIRV_ScalarOrVectorOrCoopMatrixOf:$operand1, SPIRV_ScalarOrVectorOrCoopMatrixOf:$operand2 @@ -82,7 +101,7 @@ class SPIRV_ArithmeticExtendedBinaryOp { +def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FAdd", SPIRV_Float, [Commutative]> { let summary = "Floating-point addition of Operand 1 and Operand 2."; let description = [{ @@ -104,7 +123,7 @@ def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOp<"FAdd", SPIRV_Float, [Commutative]> // ----- -def SPIRV_FDivOp : SPIRV_ArithmeticBinaryOp<"FDiv", SPIRV_Float, []> { +def SPIRV_FDivOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FDiv", SPIRV_Float, []> { let summary = "Floating-point division of Operand 1 divided by Operand 2."; let description = [{ @@ -154,7 +173,7 @@ def SPIRV_FModOp : SPIRV_ArithmeticBinaryOp<"FMod", SPIRV_Float, []> { // ----- -def SPIRV_FMulOp : SPIRV_ArithmeticBinaryOp<"FMul", SPIRV_Float, [Commutative]> { +def SPIRV_FMulOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FMul", SPIRV_Float, [Commutative]> { let summary = "Floating-point multiplication of Operand 1 and Operand 2."; let description = [{ @@ -229,7 +248,7 @@ def SPIRV_FRemOp : SPIRV_ArithmeticBinaryOp<"FRem", SPIRV_Float, []> { // ----- -def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOp<"FSub", SPIRV_Float, []> { +def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FSub", SPIRV_Float, []> { let summary = "Floating-point subtraction of Operand 2 from Operand 1."; let description = [{ @@ -251,9 +270,9 @@ def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOp<"FSub", SPIRV_Float, []> { // ----- -def SPIRV_IAddOp : SPIRV_ArithmeticBinaryOp<"IAdd", - SPIRV_Integer, - [Commutative, UsableInSpecConstantOp]> { +def SPIRV_IAddOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"IAdd", + SPIRV_Integer, + [Commutative, UsableInSpecConstantOp]> { let summary = "Integer addition of Operand 1 and Operand 2."; let description = [{ @@ -322,9 +341,9 @@ def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry", // ----- -def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOp<"IMul", - SPIRV_Integer, - [Commutative, UsableInSpecConstantOp]> { +def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"IMul", + SPIRV_Integer, + [Commutative, UsableInSpecConstantOp]> { let summary = "Integer multiplication of Operand 1 and Operand 2."; let description = [{ @@ -354,9 +373,9 @@ def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOp<"IMul", // ----- -def SPIRV_ISubOp : SPIRV_ArithmeticBinaryOp<"ISub", - SPIRV_Integer, - [UsableInSpecConstantOp]> { +def SPIRV_ISubOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"ISub", + SPIRV_Integer, + [UsableInSpecConstantOp]> { let summary = "Integer subtraction of Operand 2 from Operand 1."; let description = [{ @@ -460,9 +479,9 @@ def SPIRV_DotOp : SPIRV_Op<"Dot", // ----- -def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOp<"SDiv", - SPIRV_Integer, - [UsableInSpecConstantOp]> { +def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"SDiv", + SPIRV_Integer, + [UsableInSpecConstantOp]> { let summary = "Signed-integer division of Operand 1 divided by Operand 2."; let description = [{ @@ -622,9 +641,9 @@ def SPIRV_SRemOp : SPIRV_ArithmeticBinaryOp<"SRem", // ----- -def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOp<"UDiv", - SPIRV_Integer, - [UnsignedOp, UsableInSpecConstantOp]> { +def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"UDiv", + SPIRV_Integer, + [UnsignedOp, UsableInSpecConstantOp]> { let summary = "Unsigned-integer division of Operand 1 divided by Operand 2."; let description = [{ diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir index 8733ff93768ab..6aff7b5039638 100644 --- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir @@ -549,3 +549,46 @@ spirv.func @matrix_times_scalar(%a: !spirv.coopmatrix<2x2xf32, Workgroup, Matrix %p = spirv.MatrixTimesScalar %a, %b : !spirv.coopmatrix<2x2xf32, Workgroup, MatrixA>, f16 spirv.Return } + +// ----- + +// These binary arithmetic instructions do not support coop matrix operands. + +spirv.func @fmod(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>) "None" { + // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}} + %p = spirv.FMod %a, %b : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA> + spirv.Return +} + +// ----- + +spirv.func @frem(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>) "None" { + // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}} + %p = spirv.FRem %a, %b : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA> + spirv.Return +} + +// ----- +spirv.func @smod(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" { + // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}} + %p = spirv.SMod %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA> + spirv.Return +} + +// ----- + +spirv.func @srem(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" { + // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}} + %p = spirv.SRem %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA> + spirv.Return +} + +// ----- + +spirv.func @umod(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" { + // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}} + %p = spirv.UMod %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA> + spirv.Return +} + +// ----- From d6f6a66d2bccf17987a1db16d1c98d4233c66ff5 Mon Sep 17 00:00:00 2001 From: fairywreath Date: Mon, 7 Jul 2025 15:48:03 -0600 Subject: [PATCH 2/2] Remove stale comment --- mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td | 2 -- 1 file changed, 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td index 2601debce3520..46a705eefc262 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td @@ -24,8 +24,6 @@ class SPIRV_ArithmeticBinaryOp])> { - // In addition to normal types arithmetic instructions can support cooperative - // matrix. let arguments = (ins SPIRV_ScalarOrVectorOf:$operand1, SPIRV_ScalarOrVectorOf:$operand2