Skip to content

Commit b06e1e9

Browse files
committed
[mlir][spirv] Reject cooperative matrix operands on unsupported arithmetic ops
1 parent cd46354 commit b06e1e9

File tree

2 files changed

+81
-19
lines changed

2 files changed

+81
-19
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,25 @@ class SPIRV_ArithmeticBinaryOp<string mnemonic, Type type,
2626
[Pure, AllTypesMatch<["operand1", "operand2", "result"]>])> {
2727
// In addition to normal types arithmetic instructions can support cooperative
2828
// matrix.
29+
let arguments = (ins
30+
SPIRV_ScalarOrVectorOf<type>:$operand1,
31+
SPIRV_ScalarOrVectorOf<type>:$operand2
32+
);
33+
34+
let results = (outs
35+
SPIRV_ScalarOrVectorOf<type>:$result
36+
);
37+
let assemblyFormat = "operands attr-dict `:` type($result)";
38+
}
39+
40+
class SPIRV_ArithmeticBinaryOpWithCoopMatrix<string mnemonic, Type type,
41+
list<Trait> traits = []> :
42+
// Operands type same as result type.
43+
SPIRV_BinaryOp<mnemonic, type, type,
44+
!listconcat(traits,
45+
[Pure, AllTypesMatch<["operand1", "operand2", "result"]>])> {
46+
// In addition to normal types these arithmetic instructions can support
47+
// cooperative matrix.
2948
let arguments = (ins
3049
SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$operand1,
3150
SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$operand2
@@ -82,7 +101,7 @@ class SPIRV_ArithmeticExtendedBinaryOp<string mnemonic,
82101

83102
// -----
84103

85-
def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOp<"FAdd", SPIRV_Float, [Commutative]> {
104+
def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FAdd", SPIRV_Float, [Commutative]> {
86105
let summary = "Floating-point addition of Operand 1 and Operand 2.";
87106

88107
let description = [{
@@ -104,7 +123,7 @@ def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOp<"FAdd", SPIRV_Float, [Commutative]>
104123

105124
// -----
106125

107-
def SPIRV_FDivOp : SPIRV_ArithmeticBinaryOp<"FDiv", SPIRV_Float, []> {
126+
def SPIRV_FDivOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FDiv", SPIRV_Float, []> {
108127
let summary = "Floating-point division of Operand 1 divided by Operand 2.";
109128

110129
let description = [{
@@ -154,7 +173,7 @@ def SPIRV_FModOp : SPIRV_ArithmeticBinaryOp<"FMod", SPIRV_Float, []> {
154173

155174
// -----
156175

157-
def SPIRV_FMulOp : SPIRV_ArithmeticBinaryOp<"FMul", SPIRV_Float, [Commutative]> {
176+
def SPIRV_FMulOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FMul", SPIRV_Float, [Commutative]> {
158177
let summary = "Floating-point multiplication of Operand 1 and Operand 2.";
159178

160179
let description = [{
@@ -229,7 +248,7 @@ def SPIRV_FRemOp : SPIRV_ArithmeticBinaryOp<"FRem", SPIRV_Float, []> {
229248

230249
// -----
231250

232-
def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOp<"FSub", SPIRV_Float, []> {
251+
def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FSub", SPIRV_Float, []> {
233252
let summary = "Floating-point subtraction of Operand 2 from Operand 1.";
234253

235254
let description = [{
@@ -251,9 +270,9 @@ def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOp<"FSub", SPIRV_Float, []> {
251270

252271
// -----
253272

254-
def SPIRV_IAddOp : SPIRV_ArithmeticBinaryOp<"IAdd",
255-
SPIRV_Integer,
256-
[Commutative, UsableInSpecConstantOp]> {
273+
def SPIRV_IAddOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"IAdd",
274+
SPIRV_Integer,
275+
[Commutative, UsableInSpecConstantOp]> {
257276
let summary = "Integer addition of Operand 1 and Operand 2.";
258277

259278
let description = [{
@@ -322,9 +341,9 @@ def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry",
322341

323342
// -----
324343

325-
def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOp<"IMul",
326-
SPIRV_Integer,
327-
[Commutative, UsableInSpecConstantOp]> {
344+
def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"IMul",
345+
SPIRV_Integer,
346+
[Commutative, UsableInSpecConstantOp]> {
328347
let summary = "Integer multiplication of Operand 1 and Operand 2.";
329348

330349
let description = [{
@@ -354,9 +373,9 @@ def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOp<"IMul",
354373

355374
// -----
356375

357-
def SPIRV_ISubOp : SPIRV_ArithmeticBinaryOp<"ISub",
358-
SPIRV_Integer,
359-
[UsableInSpecConstantOp]> {
376+
def SPIRV_ISubOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"ISub",
377+
SPIRV_Integer,
378+
[UsableInSpecConstantOp]> {
360379
let summary = "Integer subtraction of Operand 2 from Operand 1.";
361380

362381
let description = [{
@@ -460,9 +479,9 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",
460479

461480
// -----
462481

463-
def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOp<"SDiv",
464-
SPIRV_Integer,
465-
[UsableInSpecConstantOp]> {
482+
def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"SDiv",
483+
SPIRV_Integer,
484+
[UsableInSpecConstantOp]> {
466485
let summary = "Signed-integer division of Operand 1 divided by Operand 2.";
467486

468487
let description = [{
@@ -622,9 +641,9 @@ def SPIRV_SRemOp : SPIRV_ArithmeticBinaryOp<"SRem",
622641

623642
// -----
624643

625-
def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOp<"UDiv",
626-
SPIRV_Integer,
627-
[UnsignedOp, UsableInSpecConstantOp]> {
644+
def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"UDiv",
645+
SPIRV_Integer,
646+
[UnsignedOp, UsableInSpecConstantOp]> {
628647
let summary = "Unsigned-integer division of Operand 1 divided by Operand 2.";
629648

630649
let description = [{

mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,3 +549,46 @@ spirv.func @matrix_times_scalar(%a: !spirv.coopmatrix<2x2xf32, Workgroup, Matrix
549549
%p = spirv.MatrixTimesScalar %a, %b : !spirv.coopmatrix<2x2xf32, Workgroup, MatrixA>, f16
550550
spirv.Return
551551
}
552+
553+
// -----
554+
555+
// These binary arithmetic instructions do not support coop matrix operands.
556+
557+
spirv.func @fmod(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>) "None" {
558+
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}}
559+
%p = spirv.FMod %a, %b : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>
560+
spirv.Return
561+
}
562+
563+
// -----
564+
565+
spirv.func @frem(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>) "None" {
566+
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}}
567+
%p = spirv.FRem %a, %b : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>
568+
spirv.Return
569+
}
570+
571+
// -----
572+
spirv.func @smod(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" {
573+
// expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
574+
%p = spirv.SMod %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
575+
spirv.Return
576+
}
577+
578+
// -----
579+
580+
spirv.func @srem(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" {
581+
// expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
582+
%p = spirv.SRem %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
583+
spirv.Return
584+
}
585+
586+
// -----
587+
588+
spirv.func @umod(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" {
589+
// expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
590+
%p = spirv.UMod %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
591+
spirv.Return
592+
}
593+
594+
// -----

0 commit comments

Comments
 (0)