Skip to content

Commit 5ae57fe

Browse files
Resolve more review comments and expand testing
In particular add negative testing. Signed-off-by: Davide Grohmann <[email protected]> Change-Id: Iee4ba17c74b451eda7f76c6f905ca12c734d39d6
1 parent 3cf2ee9 commit 5ae57fe

File tree

4 files changed

+151
-53
lines changed

4 files changed

+151
-53
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,14 @@ def SPIRV_GraphARMOp : SPIRV_GraphARMOp<"Graph", [
5757
spv-graph-arm-op ::= `spirv.ARM.Graph` function-signature
5858
region
5959
```
60+
61+
#### Example:
62+
63+
```mlir
64+
spirv.ARM.Graph @graph(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
65+
spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
66+
}
67+
```
6068
}];
6169

6270
let arguments = (ins
@@ -114,6 +122,12 @@ def SPIRV_GraphConstantARMOp : SPIRV_GraphARMOp<"GraphConstant", [InGraphScope,
114122
```
115123
spv-graph-constant-arm-op ::= `spirv.ARM.GraphConstant` { graph_constant_id = 42 : i32 }
116124
```
125+
126+
#### Example:
127+
128+
```mlir
129+
%0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16>
130+
```
117131
}];
118132

119133
let arguments = (ins
@@ -157,6 +171,17 @@ def SPIRV_GraphEntryPointARMOp : SPIRV_GraphARMOp<"GraphEntryPoint", [InModuleSc
157171
entry-point-op ::= ssa-id `=` `spirv.ARM.GraphEntryPoint`
158172
symbol-reference (`, ` symbol-reference)*
159173
```
174+
175+
#### Example:
176+
177+
```mlir
178+
spirv.GlobalVariable @arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
179+
spirv.GlobalVariable @res_0 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
180+
spirv.ARM.GraphEntryPoint @graph, @arg_0, @res_0
181+
spirv.ARM.Graph @graph(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
182+
...
183+
}
184+
```
160185
}];
161186

162187
let arguments = (ins
@@ -166,6 +191,9 @@ def SPIRV_GraphEntryPointARMOp : SPIRV_GraphARMOp<"GraphEntryPoint", [InModuleSc
166191

167192
let results = (outs);
168193

194+
// Checks for graph and interface symbol reference are done in spirv::ModuleOp verification.
195+
let hasVerifier = 0;
196+
169197
let autogenSerialization = 0;
170198

171199
let builders = [
@@ -189,6 +217,14 @@ def SPIRV_GraphOutputsARMOp : SPIRV_GraphARMOp<"GraphOutputs", [InGraphScope, Pu
189217
```
190218
graph-output-op ::= `spirv.ARM.GraphOutputs` ssa-use `:` type-list-no-parens
191219
```
220+
221+
#### Example:
222+
223+
```mlir
224+
spirv.ARM.Graph @graph(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
225+
spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
226+
}
227+
```
192228
}];
193229

194230
let arguments = (ins

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,6 @@ def FunctionType : Type<CPred<"::llvm::isa<::mlir::FunctionType>($_self)">,
393393
def GraphType : Type<CPred<"::llvm::isa<::mlir::GraphType>($_self)">,
394394
"graph type", "::mlir::GraphType">;
395395

396-
397396
// A container type is a type that has another type embedded within it.
398397
class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
399398
string descr, string cppType = "::mlir::Type"> :

mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,12 +1174,6 @@ void spirv::GraphEntryPointARMOp::print(OpAsmPrinter &printer) {
11741174
}
11751175
}
11761176

1177-
LogicalResult spirv::GraphEntryPointARMOp::verify() {
1178-
// Checks for fn and interface symbol reference are done in spirv::ModuleOp
1179-
// verification.
1180-
return success();
1181-
}
1182-
11831177
//===----------------------------------------------------------------------===//
11841178
// spirv.GraphARM
11851179
//===----------------------------------------------------------------------===//
@@ -1257,7 +1251,19 @@ LogicalResult spirv::GraphARMOp::verifyType() {
12571251
}
12581252

12591253
LogicalResult spirv::GraphARMOp::verifyBody() {
1260-
GraphType grType = getFunctionType();
1254+
for (auto [index, graphArgType] : llvm::enumerate(getArgumentTypes())) {
1255+
if (!isa<spirv::TensorArmType>(graphArgType)) {
1256+
return emitOpError("type of argument #")
1257+
<< index << " must be a TensorArmType, but got " << graphArgType;
1258+
}
1259+
}
1260+
for (auto [index, graphResType] : llvm::enumerate(getResultTypes())) {
1261+
if (!isa<spirv::TensorArmType>(graphResType)) {
1262+
return emitOpError("type of result #")
1263+
<< index << " must be a TensorArmType, but got " << graphResType;
1264+
}
1265+
}
1266+
12611267
if (!isExternal()) {
12621268
Block &entryBlock = front();
12631269

@@ -1277,15 +1283,17 @@ LogicalResult spirv::GraphARMOp::verifyBody() {
12771283
}
12781284
}
12791285

1286+
GraphType grType = getFunctionType();
12801287
auto walkResult = walk([grType](Operation *op) -> WalkResult {
12811288
if (auto graphOutputsARMOp = dyn_cast<spirv::GraphOutputsARMOp>(op)) {
12821289
if (grType.getNumResults() != graphOutputsARMOp.getNumOperands())
1283-
return graphOutputsARMOp.emitOpError("has GraphOutputsARM returning ")
1290+
return graphOutputsARMOp.emitOpError("is returning ")
12841291
<< graphOutputsARMOp.getNumOperands()
1285-
<< "value(s) but enclosing graph requires "
1286-
<< grType.getNumResults() << " results";
1292+
<< " value(s) but enclosing spirv.ARM.Graph requires "
1293+
<< grType.getNumResults() << " result(s)";
12871294

1288-
auto graphOutputOperandTypes = graphOutputsARMOp.getValue().getType();
1295+
ValueTypeRange<OperandRange> graphOutputOperandTypes =
1296+
graphOutputsARMOp.getValue().getType();
12891297
for (unsigned i = 0, size = graphOutputOperandTypes.size(); i < size;
12901298
++i) {
12911299
Type graphOutputOperandType = graphOutputOperandTypes[i];
@@ -1341,15 +1349,15 @@ LogicalResult spirv::GraphOutputsARMOp::verify() {
13411349
const ArrayRef<Type> &results = graph.getFunctionType().getResults();
13421350
if (getNumOperands() != results.size())
13431351
return emitOpError("has ")
1344-
<< getNumOperands() << " operands, but enclosing graph (@"
1352+
<< getNumOperands() << " operands, but enclosing spirv.ARM.Graph (@"
13451353
<< graph.getName() << ") returns " << results.size();
13461354

13471355
for (unsigned i = 0, size = results.size(); i < size; ++i)
13481356
if (getOperand(i).getType() != results[i])
13491357
return emitError() << "type of return operand " << i << " ("
13501358
<< getOperand(i).getType()
1351-
<< ") doesn't match graph result type (" << results[i]
1352-
<< ")"
1359+
<< ") doesn't match spirv.ARM.Graph result type ("
1360+
<< results[i] << ")"
13531361
<< " in graph @" << graph.getName();
13541362

13551363
return success();

mlir/test/Dialect/SPIRV/IR/graph-ops.mlir

Lines changed: 93 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,10 @@
44
// spirv.ARM.Graph and spirv.ARM.GraphOutputs
55
//===----------------------------------------------------------------------===//
66

7-
spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
8-
// CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
9-
spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
10-
// CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<14x19xi16>
11-
spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
12-
}
7+
// CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
8+
spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
9+
// CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<14x19xi16>
10+
spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
1311
}
1412

1513
// -----
@@ -18,52 +16,109 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Int8,
1816
// spirv.ARM.GraphConstant
1917
//===----------------------------------------------------------------------===//
2018

21-
spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
22-
// CHECK: spirv.ARM.Graph {{@.*}}() -> !spirv.arm.tensor<2x3xi16> {
23-
spirv.ARM.Graph @graphConstant() -> !spirv.arm.tensor<2x3xi16> {
24-
// CHECK: [[CONST:%.*]] = spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<2x3xi16>
25-
%0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16>
26-
// CHECK: spirv.ARM.GraphOutputs [[CONST:%.*]] : !spirv.arm.tensor<2x3xi16>
27-
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3xi16>
28-
}
19+
// CHECK: spirv.ARM.Graph {{@.*}}() -> !spirv.arm.tensor<2x3xi16> {
20+
spirv.ARM.Graph @graphConstant() -> !spirv.arm.tensor<2x3xi16> {
21+
// CHECK: [[CONST:%.*]] = spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<2x3xi16>
22+
%0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16>
23+
// CHECK: spirv.ARM.GraphOutputs [[CONST:%.*]] : !spirv.arm.tensor<2x3xi16>
24+
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3xi16>
2925
}
3026
// -----
3127

3228
//===----------------------------------------------------------------------===//
3329
// spirv.ARM.GraphEntryPoint
3430
//===----------------------------------------------------------------------===//
3531

32+
// CHECK: spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
33+
spirv.GlobalVariable @entrypoint_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
34+
// CHECK: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
35+
spirv.GlobalVariable @entrypoint_res_0 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
36+
// CHECK: spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]]
37+
spirv.ARM.GraphEntryPoint @entrypoint, @entrypoint_arg_0, @entrypoint_res_0
3638

37-
spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
38-
// CHECK: spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
39-
spirv.GlobalVariable @entrypoint_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
40-
// CHECK: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
41-
spirv.GlobalVariable @entrypoint_res_0 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
42-
// CHECK: spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]]
43-
spirv.ARM.GraphEntryPoint @entrypoint, @entrypoint_arg_0, @entrypoint_res_0
44-
// CHECK: spirv.ARM.Graph [[GN]]({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
45-
spirv.ARM.Graph @entrypoint(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
46-
// CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<14x19xi16>
47-
spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
48-
}
39+
// -----
40+
41+
//===----------------------------------------------------------------------===//
42+
// spirv.ARM.Graph with no terminator
43+
//===----------------------------------------------------------------------===//
44+
45+
// expected-error @+1 {{empty block: expect at least a terminator}}
46+
spirv.ARM.Graph @graphNoterminator(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
47+
}
48+
49+
// -----
50+
51+
//===----------------------------------------------------------------------===//
52+
// spirv.ARM.Graph with no result types
53+
//===----------------------------------------------------------------------===//
54+
55+
// expected-error @+1 {{'spirv.ARM.Graph' op there should be at least one result}}
56+
spirv.ARM.Graph @graphNoOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> () {
4957
}
5058

5159
// -----
5260

5361
//===----------------------------------------------------------------------===//
54-
// Multiple spirv.ARM.Graphs
62+
// spirv.ARM.GraphConstant outside graph scope
5563
//===----------------------------------------------------------------------===//
5664

57-
spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
58-
// CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
59-
spirv.ARM.Graph @graph1(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
60-
// CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<14x19xi16>
61-
spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
62-
}
65+
// expected-error @+1 {{'spirv.ARM.GraphConstant' op failed to verify that op must appear in a spirv.ARM.Graph op's block}}
66+
%0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16>
67+
// -----
68+
69+
//===----------------------------------------------------------------------===//
70+
// spirv.ARM.GraphOutputs outside graph scope
71+
//===----------------------------------------------------------------------===//
72+
73+
%0 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi16>
74+
// expected-error @+1 {{'spirv.ARM.GraphOutputs' op failed to verify that op must appear in a spirv.ARM.Graph op's block}}
75+
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<1xi16>
76+
77+
// -----
78+
79+
//===----------------------------------------------------------------------===//
80+
// spirv.ARM.Graph return type does not match spirv.ARM.GraphOutputs
81+
//===----------------------------------------------------------------------===//
82+
83+
spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<5x3xi16> {
84+
// expected-error @+1 {{type of return operand 0 ('!spirv.arm.tensor<14x19xi16>') doesn't match graph result type ('!spirv.arm.tensor<5x3xi16>')}}
85+
spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
86+
}
87+
88+
// -----
89+
90+
//===----------------------------------------------------------------------===//
91+
// spirv.ARM.Graph return type does not match number of results in spirv.ARM.GraphOutputs
92+
//===----------------------------------------------------------------------===//
93+
94+
spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> (!spirv.arm.tensor<14x19xi16>, !spirv.arm.tensor<14x19xi16>) {
95+
// expected-error @+1 {{'spirv.ARM.GraphOutputs' op is returning 1 value(s) but enclosing spirv.ARM.Graph requires 2 result(s)}}
96+
spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
97+
}
98+
99+
// -----
100+
101+
spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
102+
// expected-error @+1 {{'spirv.ARM.GraphOutputs' op is returning 2 value(s) but enclosing spirv.ARM.Graph requires 1 result(s)}}
103+
spirv.ARM.GraphOutputs %arg0, %arg0 : !spirv.arm.tensor<14x19xi16>, !spirv.arm.tensor<14x19xi16>
104+
}
105+
106+
// -----
107+
108+
//===----------------------------------------------------------------------===//
109+
// spirv.ARM.Graph using a non TensorArmType argument
110+
//===----------------------------------------------------------------------===//
111+
112+
// expected-error @+1 {{'spirv.ARM.Graph' op type of argument #0 must be a TensorArmType, but got 'i8'}}
113+
spirv.ARM.Graph @graphAndOutputs(%arg0 : i8) -> !spirv.arm.tensor<14x19xi16> {
114+
}
115+
116+
// -----
117+
118+
//===----------------------------------------------------------------------===//
119+
// spirv.ARM.Graph using a non TensorArmType result
120+
//===----------------------------------------------------------------------===//
63121

64-
// CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> {
65-
spirv.ARM.Graph @graph2(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> {
66-
// CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x16x16x16xi8>
67-
spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8>
68-
}
122+
// expected-error @+1 {{'spirv.ARM.Graph' op type of result #0 must be a TensorArmType, but got 'i8'}}
123+
spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> i8 {
69124
}

0 commit comments

Comments
 (0)