Skip to content

Commit 2203b12

Browse files
Resolve small review comments
Signed-off-by: Davide Grohmann <[email protected]> Change-Id: I891e89bd0ac6ead942f6b9f9807ea271b3627ba1
1 parent 234c095 commit 2203b12

File tree

7 files changed

+22
-25
lines changed

7 files changed

+22
-25
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def SPIRV_GraphARMOp : SPIRV_GraphARMOp<"Graph", [
6666
#### Example:
6767

6868
```mlir
69-
spirv.ARM.Graph @graph(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
69+
spirv.ARM.Graph @graph(%arg0: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
7070
spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
7171
}
7272
```

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ class OpaqueType<string dialect, string name, string summary>
387387
def FunctionType : Type<CPred<"::llvm::isa<::mlir::FunctionType>($_self)">,
388388
"function type", "::mlir::FunctionType">;
389389

390-
// Graph Type
390+
// Graph Type.
391391

392392
// Any graph type.
393393
def GraphType : Type<CPred<"::llvm::isa<::mlir::GraphType>($_self)">,

mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,8 @@ ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser,
5050
resultAttrs))
5151
return failure();
5252

53-
SmallVector<Type> argTypes;
54-
for (OpAsmParser::Argument &arg : entryArgs)
55-
argTypes.push_back(arg.type);
53+
SmallVector<Type> argTypes = llvm::map_to_vector(
54+
entryArgs, [](const OpAsmParser::Argument &arg) { return arg.type; });
5655
GraphType grType = builder.getGraphType(argTypes, resultTypes);
5756
result.addAttribute(getFunctionTypeAttrName(result.name),
5857
TypeAttr::get(grType));
@@ -197,14 +196,13 @@ LogicalResult spirv::GraphOutputsARMOp::verify() {
197196
<< getNumOperands() << " operands, but enclosing spirv.ARM.Graph (@"
198197
<< graph.getName() << ") returns " << results.size();
199198

200-
for (unsigned i = 0, size = results.size(); i < size; ++i)
201-
if (getOperand(i).getType() != results[i])
202-
return emitError() << "type of return operand " << i << " ("
203-
<< getOperand(i).getType()
199+
for (const auto &result : llvm::enumerate(results))
200+
if (getOperand(result.index()).getType() != result.value())
201+
return emitError() << "type of return operand " << result.index() << " ("
202+
<< getOperand(result.index()).getType()
204203
<< ") doesn't match spirv.ARM.Graph result type ("
205-
<< results[i] << ")"
204+
<< result.value() << ")"
206205
<< " in graph @" << graph.getName();
207-
208206
return success();
209207
}
210208

@@ -228,9 +226,9 @@ ParseResult spirv::GraphEntryPointARMOp::parse(OpAsmParser &parser,
228226

229227
SmallVector<Attribute, 4> interfaceVars;
230228
if (!parser.parseOptionalComma()) {
231-
// Parse the interface variables
229+
// Parse the interface variables.
232230
if (parser.parseCommaSeparatedList([&]() -> ParseResult {
233-
// The name of the interface variable attribute isnt important
231+
// The name of the interface variable attribute is not important.
234232
FlatSymbolRefAttr var;
235233
NamedAttrList attrs;
236234
if (parser.parseAttribute(var, Type(), "var_symbol", attrs))

mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ void UpdateVCEPass::runOnOperation() {
158158
if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
159159
valueTypes.push_back(globalVar.getType());
160160

161-
// If the op is FunctionLike make sure to process input and result types
161+
// If the op is FunctionLike make sure to process input and result types.
162162
if (auto funcOpInterface = dyn_cast<FunctionOpInterface>(op)) {
163163
llvm::append_range(valueTypes, funcOpInterface.getArgumentTypes());
164164
llvm::append_range(valueTypes, funcOpInterface.getResultTypes());

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2841,8 +2841,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
28412841
interleaveComma(graphTy.getInputs(), [&](Type ty) { printType(ty); });
28422842
os << ") -> ";
28432843
ArrayRef<Type> results = graphTy.getResults();
2844-
if (results.size() == 1 &&
2845-
!(isa<FunctionType>(results[0]) || isa<GraphType>(results[0]))) {
2844+
if (results.size() == 1 && !isa<FunctionType, GraphType>(results[0])) {
28462845
printType(results[0]);
28472846
} else {
28482847
os << '(';

mlir/test/Dialect/SPIRV/IR/graph-ops.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
//===----------------------------------------------------------------------===//
66

77
// CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
8-
spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
8+
spirv.ARM.Graph @graphAndOutputs(%arg0: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
99
// CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<14x19xi16>
1010
spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
1111
}
@@ -43,7 +43,7 @@ spirv.ARM.GraphEntryPoint @entrypoint, @entrypoint_arg_0, @entrypoint_res_0
4343
//===----------------------------------------------------------------------===//
4444

4545
// expected-error @+1 {{empty block: expect at least a terminator}}
46-
spirv.ARM.Graph @graphNoterminator(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
46+
spirv.ARM.Graph @graphNoterminator(%arg0: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
4747
}
4848

4949
// -----
@@ -53,7 +53,7 @@ spirv.ARM.Graph @graphNoterminator(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spi
5353
//===----------------------------------------------------------------------===//
5454

5555
// expected-error @+1 {{'spirv.ARM.Graph' op there should be at least one result}}
56-
spirv.ARM.Graph @graphNoOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> () {
56+
spirv.ARM.Graph @graphNoOutputs(%arg0: !spirv.arm.tensor<14x19xi16>) -> () {
5757
}
5858

5959
// -----
@@ -80,7 +80,7 @@ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<1xi16>
8080
// spirv.ARM.Graph return type does not match spirv.ARM.GraphOutputs
8181
//===----------------------------------------------------------------------===//
8282

83-
spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<5x3xi16> {
83+
spirv.ARM.Graph @graphAndOutputs(%arg0: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<5x3xi16> {
8484
// expected-error @+1 {{type of return operand 0 ('!spirv.arm.tensor<14x19xi16>') doesn't match graph result type ('!spirv.arm.tensor<5x3xi16>')}}
8585
spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
8686
}
@@ -91,14 +91,14 @@ spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv
9191
// spirv.ARM.Graph return type does not match number of results in spirv.ARM.GraphOutputs
9292
//===----------------------------------------------------------------------===//
9393

94-
spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> (!spirv.arm.tensor<14x19xi16>, !spirv.arm.tensor<14x19xi16>) {
94+
spirv.ARM.Graph @graphAndOutputs(%arg0: !spirv.arm.tensor<14x19xi16>) -> (!spirv.arm.tensor<14x19xi16>, !spirv.arm.tensor<14x19xi16>) {
9595
// expected-error @+1 {{'spirv.ARM.GraphOutputs' op is returning 1 value(s) but enclosing spirv.ARM.Graph requires 2 result(s)}}
9696
spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
9797
}
9898

9999
// -----
100100

101-
spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
101+
spirv.ARM.Graph @graphAndOutputs(%arg0: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
102102
// expected-error @+1 {{'spirv.ARM.GraphOutputs' op is returning 2 value(s) but enclosing spirv.ARM.Graph requires 1 result(s)}}
103103
spirv.ARM.GraphOutputs %arg0, %arg0 : !spirv.arm.tensor<14x19xi16>, !spirv.arm.tensor<14x19xi16>
104104
}
@@ -110,7 +110,7 @@ spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv
110110
//===----------------------------------------------------------------------===//
111111

112112
// expected-error @+1 {{'spirv.ARM.Graph' op type of argument #0 must be a TensorArmType, but got 'i8'}}
113-
spirv.ARM.Graph @graphAndOutputs(%arg0 : i8) -> !spirv.arm.tensor<14x19xi16> {
113+
spirv.ARM.Graph @graphAndOutputs(%arg0: i8) -> !spirv.arm.tensor<14x19xi16> {
114114
}
115115

116116
// -----
@@ -120,5 +120,5 @@ spirv.ARM.Graph @graphAndOutputs(%arg0 : i8) -> !spirv.arm.tensor<14x19xi16> {
120120
//===----------------------------------------------------------------------===//
121121

122122
// expected-error @+1 {{'spirv.ARM.Graph' op type of result #0 must be a TensorArmType, but got 'i8'}}
123-
spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> i8 {
123+
spirv.ARM.Graph @graphAndOutputs(%arg0: !spirv.arm.tensor<14x19xi16>) -> i8 {
124124
}

mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ spirv.module Logical Vulkan attributes {
238238
#spirv.vce<v1.5, [VulkanMemoryModel, GraphARM, TensorsARM, Float16], [SPV_ARM_tensors, SPV_ARM_graph]>,
239239
#spirv.resource_limits<>>
240240
} {
241-
spirv.ARM.Graph @argmax(%arg0 : !spirv.arm.tensor<14x19xi8>, %arg1 : !spirv.arm.tensor<1xf16>) -> !spirv.arm.tensor<14x19xi8> {
241+
spirv.ARM.Graph @argmax(%arg0: !spirv.arm.tensor<14x19xi8>, %arg1 : !spirv.arm.tensor<1xf16>) -> !spirv.arm.tensor<14x19xi8> {
242242
spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi8>
243243
}
244244
}

0 commit comments

Comments
 (0)