Skip to content

Commit 1b70587

Browse files
authored
[mlir][spirv] Update integer dot product op syntax (#73468)
Make the syntax more concise and aligned with the `spirv.Dot` syntax in #73466. Move some type verification from C++ to ODS. Regexes to update existing code and tests: `(\s*\{format\s+=\s+#spirv.packed_vector_format([^}]+)\})` ==> `, $2` `(spirv.[SU]+Dot[a-zA-Z]*[^:]+:)(\s*\(([^,]+),[^\)]+\))(.+)` ==> `$1 $3$4`
1 parent 42d669f commit 1b70587

File tree

5 files changed

+125
-143
lines changed

5 files changed

+125
-143
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,6 @@ class SPIRV_IntegerDotProductOp<string mnemonic,
2626
SPIRV_Integer:$result
2727
);
2828

29-
let assemblyFormat = [{
30-
operands attr-dict `:` `(` type(operands) `)` `->` type($result)
31-
}];
32-
3329
// These ops require dynamic availability specification based on operand and
3430
// result types.
3531
bit autogenAvailability = 0;
@@ -40,23 +36,36 @@ class SPIRV_IntegerDotProductOp<string mnemonic,
4036

4137
class SPIRV_IntegerDotProductBinaryOp<string mnemonic,
4238
list<Trait> traits = []> :
43-
SPIRV_IntegerDotProductOp<mnemonic, traits> {
39+
SPIRV_IntegerDotProductOp<mnemonic,
40+
!listconcat(traits, [AllTypesMatch<["vector1", "vector2"]>])> {
4441
let arguments = (ins
4542
SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$vector1,
4643
SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$vector2,
4744
OptionalAttr<SPIRV_PackedVectorFormatAttr>:$format
4845
);
46+
47+
let assemblyFormat = [{
48+
$vector1 `,` $vector2 ( `,` $format^ )? attr-dict `:`
49+
type($vector1) `->` type($result)
50+
}];
4951
}
5052

5153
class SPIRV_IntegerDotProductTernaryOp<string mnemonic,
5254
list<Trait> traits = []> :
53-
SPIRV_IntegerDotProductOp<mnemonic, traits> {
55+
SPIRV_IntegerDotProductOp<mnemonic,
56+
!listconcat(traits, [AllTypesMatch<["vector1", "vector2"]>,
57+
AllTypesMatch<["accumulator", "result"]>])> {
5458
let arguments = (ins
5559
SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$vector1,
5660
SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$vector2,
5761
SPIRV_Integer:$accumulator,
5862
OptionalAttr<SPIRV_PackedVectorFormatAttr>:$format
5963
);
64+
65+
let assemblyFormat = [{
66+
$vector1 `,` $vector2 `,` $accumulator ( `,` $format^ )? attr-dict `:`
67+
type($vector1) `->` type($result)
68+
}];
6069
}
6170

6271
// -----
@@ -92,9 +101,9 @@ def SPIRV_SDotOp : SPIRV_IntegerDotProductBinaryOp<"SDot",
92101
#### Example:
93102

94103
```mlir
95-
%r = spirv.SDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
96-
%r = spirv.SDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i64
97-
%r = spirv.SDot %a, %b : (vector<4xi8>, vector<4xi8>) -> i32
104+
%r = spirv.SDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i32
105+
%r = spirv.SDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i64
106+
%r = spirv.SDot %a, %b : vector<4xi8> -> i32
98107
```
99108
}];
100109
}
@@ -138,9 +147,9 @@ def SPIRV_SUDotOp : SPIRV_IntegerDotProductBinaryOp<"SUDot",
138147
#### Example:
139148

140149
```mlir
141-
%r = spirv.SUDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
142-
%r = spirv.SUDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i64
143-
%r = spirv.SUDot %a, %b : (vector<4xi8>, vector<4xi8>) -> i32
150+
%r = spirv.SUDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i32
151+
%r = spirv.SUDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i64
152+
%r = spirv.SUDot %a, %b : vector<4xi8> -> i32
144153
```
145154
}];
146155
}
@@ -180,9 +189,9 @@ def SPIRV_UDotOp : SPIRV_IntegerDotProductBinaryOp<"UDot",
180189
#### Example:
181190

182191
```mlir
183-
%r = spirv.UDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
184-
%r = spirv.UDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i64
185-
%r = spirv.UDot %a, %b : (vector<4xi8>, vector<4xi8>) -> i32
192+
%r = spirv.UDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i32
193+
%r = spirv.UDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i64
194+
%r = spirv.UDot %a, %b : vector<4xi8> -> i32
186195
```
187196
}];
188197
}
@@ -228,9 +237,9 @@ def SPIRV_SDotAccSatOp : SPIRV_IntegerDotProductTernaryOp<"SDotAccSat",
228237
#### Example:
229238

230239
```mlir
231-
%r = spirv.SDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i32
232-
%r = spirv.SDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i64) -> i64
233-
%r = spirv.SDotAccSat %a, %b, %acc : (vector<4xi8>, vector<4xi8>, i32) -> i32
240+
%r = spirv.SDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i32
241+
%r = spirv.SDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i64
242+
%r = spirv.SDotAccSat %a, %b, %acc : vector<4xi8> -> i32
234243
```
235244
}];
236245
}
@@ -280,9 +289,9 @@ def SPIRV_SUDotAccSatOp : SPIRV_IntegerDotProductTernaryOp<"SUDotAccSat",
280289
#### Example:
281290

282291
```mlir
283-
%r = spirv.SUDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i32
284-
%r = spirv.SUDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i64) -> i64
285-
%r = spirv.SUDotAccSat %a, %b, %acc : (vector<4xi8>, vector<4xi8>, i32) -> i32
292+
%r = spirv.SUDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i32
293+
%r = spirv.SUDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i64
294+
%r = spirv.SUDotAccSat %a, %b, %acc : vector<4xi8> -> i32
286295
```
287296
}];
288297
}
@@ -330,9 +339,9 @@ def SPIRV_UDotAccSatOp :
330339
#### Example:
331340

332341
```mlir
333-
%r = spirv.UDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i32) -> i32
334-
%r = spirv.UDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32, i64) -> i64
335-
%r = spirv.UDotAccSat %a, %b, %acc : (vector<4xi8>, vector<4xi8>, i32) -> i32
342+
%r = spirv.UDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i32
343+
%r = spirv.UDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i64
344+
%r = spirv.UDotAccSat %a, %b, %acc : vector<4xi8> -> i32
336345
```
337346
}];
338347
}

mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,10 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
3030
"Not an integer dot product op?");
3131
assert(op->getNumResults() == 1 && "Expected a single result");
3232

33+
// ODS enforces that vector 1 and vector 2, and result and the accumulator
34+
// have the same types.
3335
Type factorTy = op->getOperand(0).getType();
34-
if (op->getOperand(1).getType() != factorTy)
35-
return op->emitOpError("requires the same type for both vector operands");
36-
37-
unsigned expectedNumAttrs = 0;
3836
if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
39-
++expectedNumAttrs;
4037
auto packedVectorFormat =
4138
llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
4239
op->getAttr(kPackedVectorFormatAttrName));
@@ -59,16 +56,7 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
5956
factorTy));
6057
}
6158

62-
if (op->getAttrs().size() > expectedNumAttrs)
63-
return op->emitError(
64-
"op only supports the 'format' #spirv.packed_vector_format attribute");
65-
6659
Type resultTy = op->getResultTypes().front();
67-
bool hasAccumulator = op->getNumOperands() == 3;
68-
if (hasAccumulator && op->getOperand(2).getType() != resultTy)
69-
return op->emitOpError(
70-
"requires the same accumulator operand and result types");
71-
7260
unsigned factorBitWidth = getBitWidth(factorTy);
7361
unsigned resultBitWidth = getBitWidth(resultTy);
7462
if (factorBitWidth > resultBitWidth)

mlir/test/Conversion/VectorToSPIRV/vector-reduction-to-spirv-dot-prod.mlir

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
// CHECK-LABEL: func.func @to_sdot
77
// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
8-
// CHECK-NEXT: [[DOT:%.+]] = spirv.SDot [[ARG0]], [[ARG1]] : (vector<4xi8>, vector<4xi8>) -> i32
8+
// CHECK-NEXT: [[DOT:%.+]] = spirv.SDot [[ARG0]], [[ARG1]] : vector<4xi8> -> i32
99
// CHECK-NEXT: return [[DOT]] : i32
1010
func.func @to_sdot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
1111
%lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
@@ -17,7 +17,7 @@ func.func @to_sdot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
1717

1818
// CHECK-LABEL: func.func @to_sdot_acc
1919
// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i32)
20-
// CHECK-NEXT: [[DOT:%.+]] = spirv.SDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : (vector<4xi8>, vector<4xi8>, i32) -> i32
20+
// CHECK-NEXT: [[DOT:%.+]] = spirv.SDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : vector<4xi8> -> i32
2121
// CHECK-NEXT: return [[DOT]] : i32
2222
func.func @to_sdot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
2323
%lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
@@ -29,7 +29,7 @@ func.func @to_sdot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i
2929

3030
// CHECK-LABEL: func.func @to_sdot_i64
3131
// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
32-
// CHECK-NEXT: [[DOT:%.+]] = spirv.SDot [[ARG0]], [[ARG1]] : (vector<4xi8>, vector<4xi8>) -> i64
32+
// CHECK-NEXT: [[DOT:%.+]] = spirv.SDot [[ARG0]], [[ARG1]] : vector<4xi8> -> i64
3333
// CHECK-NEXT: return [[DOT]] : i64
3434
func.func @to_sdot_i64(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i64 {
3535
%lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi64>
@@ -41,7 +41,7 @@ func.func @to_sdot_i64(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i64 {
4141

4242
// CHECK-LABEL: func.func @to_sdot_acc_i64
4343
// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i64)
44-
// CHECK-NEXT: [[DOT:%.+]] = spirv.SDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : (vector<4xi8>, vector<4xi8>, i64) -> i64
44+
// CHECK-NEXT: [[DOT:%.+]] = spirv.SDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : vector<4xi8> -> i64
4545
// CHECK-NEXT: return [[DOT]] : i64
4646
func.func @to_sdot_acc_i64(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i64) -> i64 {
4747
%lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi64>
@@ -53,7 +53,7 @@ func.func @to_sdot_acc_i64(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i64)
5353

5454
// CHECK-LABEL: func.func @to_udot
5555
// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
56-
// CHECK-NEXT: [[DOT:%.+]] = spirv.UDot [[ARG0]], [[ARG1]] : (vector<4xi8>, vector<4xi8>) -> i32
56+
// CHECK-NEXT: [[DOT:%.+]] = spirv.UDot [[ARG0]], [[ARG1]] : vector<4xi8> -> i32
5757
// CHECK-NEXT: return [[DOT]] : i32
5858
func.func @to_udot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
5959
%lhs = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
@@ -65,7 +65,7 @@ func.func @to_udot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
6565

6666
// CHECK-LABEL: func.func @to_udot_acc
6767
// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i32)
68-
// CHECK-NEXT: [[DOT:%.+]] = spirv.UDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : (vector<4xi8>, vector<4xi8>, i32) -> i32
68+
// CHECK-NEXT: [[DOT:%.+]] = spirv.UDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : vector<4xi8> -> i32
6969
// CHECK-NEXT: return [[DOT]] : i32
7070
func.func @to_udot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
7171
%lhs = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
@@ -77,7 +77,7 @@ func.func @to_udot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i
7777

7878
// CHECK-LABEL: func.func @to_signed_unsigned_dot
7979
// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
80-
// CHECK-NEXT: [[DOT:%.+]] = spirv.SUDot [[ARG0]], [[ARG1]] : (vector<4xi8>, vector<4xi8>) -> i32
80+
// CHECK-NEXT: [[DOT:%.+]] = spirv.SUDot [[ARG0]], [[ARG1]] : vector<4xi8> -> i32
8181
// CHECK-NEXT: return [[DOT]] : i32
8282
func.func @to_signed_unsigned_dot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
8383
%lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
@@ -89,7 +89,7 @@ func.func @to_signed_unsigned_dot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i
8989

9090
// CHECK-LABEL: func.func @to_signed_unsigned_dot_acc
9191
// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i32)
92-
// CHECK-NEXT: [[DOT:%.+]] = spirv.SUDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : (vector<4xi8>, vector<4xi8>, i32) -> i32
92+
// CHECK-NEXT: [[DOT:%.+]] = spirv.SUDotAccSat [[ARG0]], [[ARG1]], [[ACC]] : vector<4xi8> -> i32
9393
// CHECK-NEXT: return [[DOT]] : i32
9494
func.func @to_signed_unsigned_dot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
9595
%lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
@@ -101,7 +101,7 @@ func.func @to_signed_unsigned_dot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>,
101101

102102
// CHECK-LABEL: func.func @to_unsigned_signed_dot
103103
// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
104-
// CHECK-NEXT: [[DOT:%.+]] = spirv.SUDot [[ARG1]], [[ARG0]] : (vector<4xi8>, vector<4xi8>) -> i32
104+
// CHECK-NEXT: [[DOT:%.+]] = spirv.SUDot [[ARG1]], [[ARG0]] : vector<4xi8> -> i32
105105
// CHECK-NEXT: return [[DOT]] : i32
106106
func.func @to_unsigned_signed_dot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
107107
%lhs = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
@@ -113,7 +113,7 @@ func.func @to_unsigned_signed_dot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i
113113

114114
// CHECK-LABEL: func.func @to_unsigned_signed_dot_acc
115115
// CHECK-SAME: ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>, [[ACC:%.+]]: i32)
116-
// CHECK-NEXT: [[DOT:%.+]] = spirv.SUDotAccSat [[ARG1]], [[ARG0]], [[ACC]] : (vector<4xi8>, vector<4xi8>, i32) -> i32
116+
// CHECK-NEXT: [[DOT:%.+]] = spirv.SUDotAccSat [[ARG1]], [[ARG0]], [[ACC]] : vector<4xi8> -> i32
117117
// CHECK-NEXT: return [[DOT]] : i32
118118
func.func @to_unsigned_signed_dot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
119119
%lhs = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
@@ -128,7 +128,7 @@ func.func @to_unsigned_signed_dot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>,
128128
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i8
129129
// CHECK: %[[LHS:.+]] = spirv.CompositeConstruct %[[ARG0]], %[[ZERO]] : (vector<3xi8>, i8) -> vector<4xi8>
130130
// CHECK: %[[RHS:.+]] = spirv.CompositeConstruct %[[ARG1]], %[[ZERO]] : (vector<3xi8>, i8) -> vector<4xi8>
131-
// CHECK: %[[SDOT:.+]] = spirv.SDot %[[LHS]], %[[RHS]] : (vector<4xi8>, vector<4xi8>) -> i32
131+
// CHECK: %[[SDOT:.+]] = spirv.SDot %[[LHS]], %[[RHS]] : vector<4xi8> -> i32
132132
// CHECK: return %[[SDOT]]
133133
func.func @to_sdot_vector3(%arg0: vector<3xi8>, %arg1: vector<3xi8>) -> i32 {
134134
%lhs = arith.extsi %arg0 : vector<3xi8> to vector<3xi32>

0 commit comments

Comments
 (0)