diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td index 309079e549846..46a705eefc262 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td @@ -24,8 +24,25 @@ class SPIRV_ArithmeticBinaryOp])> { - // In addition to normal types arithmetic instructions can support cooperative - // matrix. + let arguments = (ins + SPIRV_ScalarOrVectorOf:$operand1, + SPIRV_ScalarOrVectorOf:$operand2 + ); + + let results = (outs + SPIRV_ScalarOrVectorOf:$result + ); + let assemblyFormat = "operands attr-dict `:` type($result)"; +} + +class SPIRV_ArithmeticBinaryOpWithCoopMatrix traits = []> : + // Operands type same as result type. + SPIRV_BinaryOp])> { + // In addition to normal types these arithmetic instructions can support + // cooperative matrix. let arguments = (ins SPIRV_ScalarOrVectorOrCoopMatrixOf:$operand1, SPIRV_ScalarOrVectorOrCoopMatrixOf:$operand2 @@ -82,7 +99,7 @@ class SPIRV_ArithmeticExtendedBinaryOp { +def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FAdd", SPIRV_Float, [Commutative]> { let summary = "Floating-point addition of Operand 1 and Operand 2."; let description = [{ @@ -104,7 +121,7 @@ def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOp<"FAdd", SPIRV_Float, [Commutative]> // ----- -def SPIRV_FDivOp : SPIRV_ArithmeticBinaryOp<"FDiv", SPIRV_Float, []> { +def SPIRV_FDivOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FDiv", SPIRV_Float, []> { let summary = "Floating-point division of Operand 1 divided by Operand 2."; let description = [{ @@ -154,7 +171,7 @@ def SPIRV_FModOp : SPIRV_ArithmeticBinaryOp<"FMod", SPIRV_Float, []> { // ----- -def SPIRV_FMulOp : SPIRV_ArithmeticBinaryOp<"FMul", SPIRV_Float, [Commutative]> { +def SPIRV_FMulOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FMul", SPIRV_Float, [Commutative]> { let summary = "Floating-point multiplication of Operand 1 and Operand 2."; let description = [{ @@ -229,7 +246,7 @@ def SPIRV_FRemOp : SPIRV_ArithmeticBinaryOp<"FRem", SPIRV_Float, []> { // ----- -def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOp<"FSub", SPIRV_Float, []> { +def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FSub", SPIRV_Float, []> { let summary = "Floating-point subtraction of Operand 2 from Operand 1."; let description = [{ @@ -251,9 +268,9 @@ def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOp<"FSub", SPIRV_Float, []> { // ----- -def SPIRV_IAddOp : SPIRV_ArithmeticBinaryOp<"IAdd", - SPIRV_Integer, - [Commutative, UsableInSpecConstantOp]> { +def SPIRV_IAddOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"IAdd", + SPIRV_Integer, + [Commutative, UsableInSpecConstantOp]> { let summary = "Integer addition of Operand 1 and Operand 2."; let description = [{ @@ -322,9 +339,9 @@ def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry", // ----- -def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOp<"IMul", - SPIRV_Integer, - [Commutative, UsableInSpecConstantOp]> { +def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"IMul", + SPIRV_Integer, + [Commutative, UsableInSpecConstantOp]> { let summary = "Integer multiplication of Operand 1 and Operand 2."; let description = [{ @@ -354,9 +371,9 @@ def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOp<"IMul", // ----- -def SPIRV_ISubOp : SPIRV_ArithmeticBinaryOp<"ISub", - SPIRV_Integer, - [UsableInSpecConstantOp]> { +def SPIRV_ISubOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"ISub", + SPIRV_Integer, + [UsableInSpecConstantOp]> { let summary = "Integer subtraction of Operand 2 from Operand 1."; let description = [{ @@ -460,9 +477,9 @@ def SPIRV_DotOp : SPIRV_Op<"Dot", // ----- -def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOp<"SDiv", - SPIRV_Integer, - [UsableInSpecConstantOp]> { +def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"SDiv", + SPIRV_Integer, + [UsableInSpecConstantOp]> { let summary = "Signed-integer division of Operand 1 divided by Operand 2."; let description = [{ @@ -622,9 +639,9 @@ def SPIRV_SRemOp : SPIRV_ArithmeticBinaryOp<"SRem", // ----- -def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOp<"UDiv", - SPIRV_Integer, - [UnsignedOp, UsableInSpecConstantOp]> { +def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"UDiv", + SPIRV_Integer, + [UnsignedOp, UsableInSpecConstantOp]> { let summary = "Unsigned-integer division of Operand 1 divided by Operand 2."; let description = [{ diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir index 8733ff93768ab..6aff7b5039638 100644 --- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir @@ -549,3 +549,46 @@ spirv.func @matrix_times_scalar(%a: !spirv.coopmatrix<2x2xf32, Workgroup, Matrix %p = spirv.MatrixTimesScalar %a, %b : !spirv.coopmatrix<2x2xf32, Workgroup, MatrixA>, f16 spirv.Return } + +// ----- + +// These binary arithmetic instructions do not support coop matrix operands. + +spirv.func @fmod(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>) "None" { + // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}} + %p = spirv.FMod %a, %b : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA> + spirv.Return +} + +// ----- + +spirv.func @frem(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>) "None" { + // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}} + %p = spirv.FRem %a, %b : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA> + spirv.Return +} + +// ----- +spirv.func @smod(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" { + // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}} + %p = spirv.SMod %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA> + spirv.Return +} + +// ----- + +spirv.func @srem(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" { + // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}} + %p = spirv.SRem %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA> + spirv.Return +} + +// ----- + +spirv.func @umod(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" { + // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}} + %p = spirv.UMod %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA> + spirv.Return +} + +// -----