-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][spirv] Reject coop matrix operands on unsupported arithmetic ops #147230
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-spirv Author: Darren Wihandi (fairywreath) ChangesCooperative matrix operands are only supported for Full diff: https://github.com/llvm/llvm-project/pull/147230.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 309079e549846..2601debce3520 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -26,6 +26,25 @@ class SPIRV_ArithmeticBinaryOp<string mnemonic, Type type,
[Pure, AllTypesMatch<["operand1", "operand2", "result"]>])> {
// In addition to normal types arithmetic instructions can support cooperative
// matrix.
+ let arguments = (ins
+ SPIRV_ScalarOrVectorOf<type>:$operand1,
+ SPIRV_ScalarOrVectorOf<type>:$operand2
+ );
+
+ let results = (outs
+ SPIRV_ScalarOrVectorOf<type>:$result
+ );
+ let assemblyFormat = "operands attr-dict `:` type($result)";
+}
+
+class SPIRV_ArithmeticBinaryOpWithCoopMatrix<string mnemonic, Type type,
+ list<Trait> traits = []> :
+ // Operands type same as result type.
+ SPIRV_BinaryOp<mnemonic, type, type,
+ !listconcat(traits,
+ [Pure, AllTypesMatch<["operand1", "operand2", "result"]>])> {
+ // In addition to normal types these arithmetic instructions can support
+ // cooperative matrix.
let arguments = (ins
SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$operand1,
SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$operand2
@@ -82,7 +101,7 @@ class SPIRV_ArithmeticExtendedBinaryOp<string mnemonic,
// -----
-def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOp<"FAdd", SPIRV_Float, [Commutative]> {
+def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FAdd", SPIRV_Float, [Commutative]> {
let summary = "Floating-point addition of Operand 1 and Operand 2.";
let description = [{
@@ -104,7 +123,7 @@ def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOp<"FAdd", SPIRV_Float, [Commutative]>
// -----
-def SPIRV_FDivOp : SPIRV_ArithmeticBinaryOp<"FDiv", SPIRV_Float, []> {
+def SPIRV_FDivOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FDiv", SPIRV_Float, []> {
let summary = "Floating-point division of Operand 1 divided by Operand 2.";
let description = [{
@@ -154,7 +173,7 @@ def SPIRV_FModOp : SPIRV_ArithmeticBinaryOp<"FMod", SPIRV_Float, []> {
// -----
-def SPIRV_FMulOp : SPIRV_ArithmeticBinaryOp<"FMul", SPIRV_Float, [Commutative]> {
+def SPIRV_FMulOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FMul", SPIRV_Float, [Commutative]> {
let summary = "Floating-point multiplication of Operand 1 and Operand 2.";
let description = [{
@@ -229,7 +248,7 @@ def SPIRV_FRemOp : SPIRV_ArithmeticBinaryOp<"FRem", SPIRV_Float, []> {
// -----
-def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOp<"FSub", SPIRV_Float, []> {
+def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FSub", SPIRV_Float, []> {
let summary = "Floating-point subtraction of Operand 2 from Operand 1.";
let description = [{
@@ -251,9 +270,9 @@ def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOp<"FSub", SPIRV_Float, []> {
// -----
-def SPIRV_IAddOp : SPIRV_ArithmeticBinaryOp<"IAdd",
- SPIRV_Integer,
- [Commutative, UsableInSpecConstantOp]> {
+def SPIRV_IAddOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"IAdd",
+ SPIRV_Integer,
+ [Commutative, UsableInSpecConstantOp]> {
let summary = "Integer addition of Operand 1 and Operand 2.";
let description = [{
@@ -322,9 +341,9 @@ def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry",
// -----
-def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOp<"IMul",
- SPIRV_Integer,
- [Commutative, UsableInSpecConstantOp]> {
+def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"IMul",
+ SPIRV_Integer,
+ [Commutative, UsableInSpecConstantOp]> {
let summary = "Integer multiplication of Operand 1 and Operand 2.";
let description = [{
@@ -354,9 +373,9 @@ def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOp<"IMul",
// -----
-def SPIRV_ISubOp : SPIRV_ArithmeticBinaryOp<"ISub",
- SPIRV_Integer,
- [UsableInSpecConstantOp]> {
+def SPIRV_ISubOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"ISub",
+ SPIRV_Integer,
+ [UsableInSpecConstantOp]> {
let summary = "Integer subtraction of Operand 2 from Operand 1.";
let description = [{
@@ -460,9 +479,9 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",
// -----
-def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOp<"SDiv",
- SPIRV_Integer,
- [UsableInSpecConstantOp]> {
+def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"SDiv",
+ SPIRV_Integer,
+ [UsableInSpecConstantOp]> {
let summary = "Signed-integer division of Operand 1 divided by Operand 2.";
let description = [{
@@ -622,9 +641,9 @@ def SPIRV_SRemOp : SPIRV_ArithmeticBinaryOp<"SRem",
// -----
-def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOp<"UDiv",
- SPIRV_Integer,
- [UnsignedOp, UsableInSpecConstantOp]> {
+def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"UDiv",
+ SPIRV_Integer,
+ [UnsignedOp, UsableInSpecConstantOp]> {
let summary = "Unsigned-integer division of Operand 1 divided by Operand 2.";
let description = [{
diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
index 8733ff93768ab..6aff7b5039638 100644
--- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -549,3 +549,46 @@ spirv.func @matrix_times_scalar(%a: !spirv.coopmatrix<2x2xf32, Workgroup, Matrix
%p = spirv.MatrixTimesScalar %a, %b : !spirv.coopmatrix<2x2xf32, Workgroup, MatrixA>, f16
spirv.Return
}
+
+// -----
+
+// These binary arithmetic instructions do not support coop matrix operands.
+
+spirv.func @fmod(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>) "None" {
+ // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}}
+ %p = spirv.FMod %a, %b : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @frem(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>) "None" {
+ // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}}
+ %p = spirv.FRem %a, %b : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>
+ spirv.Return
+}
+
+// -----
+spirv.func @smod(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" {
+ // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
+ %p = spirv.SMod %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @srem(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" {
+ // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
+ %p = spirv.SRem %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @umod(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" {
+ // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
+ %p = spirv.UMod %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
+ spirv.Return
+}
+
+// -----
|
kuhar
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch, thanks for addressing this.
Cooperative matrix operands are only supported for
add/sub/mul/divbinary arithmetic ops, but currently all binary arithmetic ops accept cooperative matrix operands, includingmod/rem. This change fixes this behaviour.