@@ -26,6 +26,25 @@ class SPIRV_ArithmeticBinaryOp<string mnemonic, Type type,
2626 [Pure, AllTypesMatch<["operand1", "operand2", "result"]>])> {
2727 // In addition to normal types arithmetic instructions can support cooperative
2828 // matrix.
29+ let arguments = (ins
30+ SPIRV_ScalarOrVectorOf<type>:$operand1,
31+ SPIRV_ScalarOrVectorOf<type>:$operand2
32+ );
33+
34+ let results = (outs
35+ SPIRV_ScalarOrVectorOf<type>:$result
36+ );
37+ let assemblyFormat = "operands attr-dict `:` type($result)";
38+ }
39+
40+ class SPIRV_ArithmeticBinaryOpWithCoopMatrix<string mnemonic, Type type,
41+ list<Trait> traits = []> :
42+ // Operands type same as result type.
43+ SPIRV_BinaryOp<mnemonic, type, type,
44+ !listconcat(traits,
45+ [Pure, AllTypesMatch<["operand1", "operand2", "result"]>])> {
46+ // In addition to normal types these arithmetic instructions can support
47+ // cooperative matrix.
2948 let arguments = (ins
3049 SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$operand1,
3150 SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$operand2
@@ -82,7 +101,7 @@ class SPIRV_ArithmeticExtendedBinaryOp<string mnemonic,
82101
83102// -----
84103
85- def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOp <"FAdd", SPIRV_Float, [Commutative]> {
104+ def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix <"FAdd", SPIRV_Float, [Commutative]> {
86105 let summary = "Floating-point addition of Operand 1 and Operand 2.";
87106
88107 let description = [{
@@ -104,7 +123,7 @@ def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOp<"FAdd", SPIRV_Float, [Commutative]>
104123
105124// -----
106125
107- def SPIRV_FDivOp : SPIRV_ArithmeticBinaryOp <"FDiv", SPIRV_Float, []> {
126+ def SPIRV_FDivOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix <"FDiv", SPIRV_Float, []> {
108127 let summary = "Floating-point division of Operand 1 divided by Operand 2.";
109128
110129 let description = [{
@@ -154,7 +173,7 @@ def SPIRV_FModOp : SPIRV_ArithmeticBinaryOp<"FMod", SPIRV_Float, []> {
154173
155174// -----
156175
157- def SPIRV_FMulOp : SPIRV_ArithmeticBinaryOp <"FMul", SPIRV_Float, [Commutative]> {
176+ def SPIRV_FMulOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix <"FMul", SPIRV_Float, [Commutative]> {
158177 let summary = "Floating-point multiplication of Operand 1 and Operand 2.";
159178
160179 let description = [{
@@ -229,7 +248,7 @@ def SPIRV_FRemOp : SPIRV_ArithmeticBinaryOp<"FRem", SPIRV_Float, []> {
229248
230249// -----
231250
232- def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOp <"FSub", SPIRV_Float, []> {
251+ def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix <"FSub", SPIRV_Float, []> {
233252 let summary = "Floating-point subtraction of Operand 2 from Operand 1.";
234253
235254 let description = [{
@@ -251,9 +270,9 @@ def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOp<"FSub", SPIRV_Float, []> {
251270
252271// -----
253272
254- def SPIRV_IAddOp : SPIRV_ArithmeticBinaryOp <"IAdd",
255- SPIRV_Integer,
256- [Commutative, UsableInSpecConstantOp]> {
273+ def SPIRV_IAddOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix <"IAdd",
274+ SPIRV_Integer,
275+ [Commutative, UsableInSpecConstantOp]> {
257276 let summary = "Integer addition of Operand 1 and Operand 2.";
258277
259278 let description = [{
@@ -322,9 +341,9 @@ def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry",
322341
323342// -----
324343
325- def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOp <"IMul",
326- SPIRV_Integer,
327- [Commutative, UsableInSpecConstantOp]> {
344+ def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix <"IMul",
345+ SPIRV_Integer,
346+ [Commutative, UsableInSpecConstantOp]> {
328347 let summary = "Integer multiplication of Operand 1 and Operand 2.";
329348
330349 let description = [{
@@ -354,9 +373,9 @@ def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOp<"IMul",
354373
355374// -----
356375
357- def SPIRV_ISubOp : SPIRV_ArithmeticBinaryOp <"ISub",
358- SPIRV_Integer,
359- [UsableInSpecConstantOp]> {
376+ def SPIRV_ISubOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix <"ISub",
377+ SPIRV_Integer,
378+ [UsableInSpecConstantOp]> {
360379 let summary = "Integer subtraction of Operand 2 from Operand 1.";
361380
362381 let description = [{
@@ -460,9 +479,9 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",
460479
461480// -----
462481
463- def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOp <"SDiv",
464- SPIRV_Integer,
465- [UsableInSpecConstantOp]> {
482+ def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix <"SDiv",
483+ SPIRV_Integer,
484+ [UsableInSpecConstantOp]> {
466485 let summary = "Signed-integer division of Operand 1 divided by Operand 2.";
467486
468487 let description = [{
@@ -622,9 +641,9 @@ def SPIRV_SRemOp : SPIRV_ArithmeticBinaryOp<"SRem",
622641
623642// -----
624643
625- def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOp <"UDiv",
626- SPIRV_Integer,
627- [UnsignedOp, UsableInSpecConstantOp]> {
644+ def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix <"UDiv",
645+ SPIRV_Integer,
646+ [UnsignedOp, UsableInSpecConstantOp]> {
628647 let summary = "Unsigned-integer division of Operand 1 divided by Operand 2.";
629648
630649 let description = [{
0 commit comments