Skip to content

Commit f4c2c3a

Browse files
Fix review comments
Signed-off-by: Davide Grohmann <[email protected]> Change-Id: Ib592a4b99af01c0d3c88eaf63a61cb4c6cca8cbe
1 parent 3776f37 commit f4c2c3a

File tree

6 files changed

+35
-60
lines changed

6 files changed

+35
-60
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def SPIRV_GraphARMOp : SPIRV_GraphARMOp<"Graph", [
5151
that is unique in the enclosing module op.
5252

5353
This op itself takes no operands and generates no results. Its region
54-
can take zero or more arguments and return zero or more values.
54+
can take zero or more arguments and return one or more values.
5555

5656
```
5757
spv-graph-arm-op ::= `spirv.ARM.Graph` function-signature
@@ -119,10 +119,6 @@ def SPIRV_GraphConstantARMOp : SPIRV_GraphARMOp<"GraphConstant", [InGraphScope,
119119
Result Type must be an OpTypeTensorARM.
120120
GraphConstantID must be a 32-bit integer literal.
121121

122-
```
123-
spv-graph-constant-arm-op ::= `spirv.ARM.GraphConstant` { graph_constant_id = 42 : i32 }
124-
```
125-
126122
#### Example:
127123

128124
```mlir
@@ -167,11 +163,6 @@ def SPIRV_GraphEntryPointARMOp : SPIRV_GraphARMOp<"GraphEntryPoint", [InModuleSc
167163
`spirv.GlobalVariable`s referenced by the entry point’s static call
168164
tree, within the interface’s storage classes.
169165

170-
```
171-
entry-point-op ::= ssa-id `=` `spirv.ARM.GraphEntryPoint`
172-
symbol-reference (`, ` symbol-reference)*
173-
```
174-
175166
#### Example:
176167

177168
```mlir
@@ -214,10 +205,6 @@ def SPIRV_GraphOutputsARMOp : SPIRV_GraphARMOp<"GraphOutputs", [InGraphScope, Pu
214205

215206
This instruction must be the last instruction in a block.
216207

217-
```
218-
graph-output-op ::= `spirv.ARM.GraphOutputs` ssa-use `:` type-list-no-parens
219-
```
220-
221208
#### Example:
222209

223210
```mlir

mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp

Lines changed: 27 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
//===- ArmGraphOps.cpp - MLIR SPIR-V SPV_ARM_graph operations
2-
//------------------------------===//
1+
//===- ArmGraphOps.cpp - MLIR SPIR-V SPV_ARM_graph operations -------------===//
32
//
43
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
54
// See https://llvm.org/LICENSE.txt for license information.
@@ -22,6 +21,7 @@
2221
#include "mlir/IR/BuiltinTypes.h"
2322
#include "mlir/IR/Operation.h"
2423
#include "mlir/Interfaces/FunctionImplementation.h"
24+
#include "llvm/Support/InterleavedRange.h"
2525

2626
using namespace mlir;
2727
using namespace mlir::spirv::AttrNames;
@@ -32,10 +32,7 @@ using namespace mlir::spirv::AttrNames;
3232

3333
ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser,
3434
OperationState &result) {
35-
SmallVector<OpAsmParser::Argument> entryArgs;
36-
SmallVector<DictionaryAttr> resultAttrs;
37-
SmallVector<Type> resultTypes;
38-
auto &builder = parser.getBuilder();
35+
Builder &builder = parser.getBuilder();
3936

4037
// Parse the name as a symbol.
4138
StringAttr nameAttr;
@@ -45,15 +42,18 @@ ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser,
4542

4643
// Parse the function signature.
4744
bool isVariadic = false;
45+
SmallVector<OpAsmParser::Argument> entryArgs;
46+
SmallVector<Type> resultTypes;
47+
SmallVector<DictionaryAttr> resultAttrs;
4848
if (function_interface_impl::parseFunctionSignatureWithArguments(
4949
parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
5050
resultAttrs))
5151
return failure();
5252

5353
SmallVector<Type> argTypes;
54-
for (auto &arg : entryArgs)
54+
for (OpAsmParser::Argument &arg : entryArgs)
5555
argTypes.push_back(arg.type);
56-
auto grType = builder.getGraphType(argTypes, resultTypes);
56+
GraphType grType = builder.getGraphType(argTypes, resultTypes);
5757
result.addAttribute(getFunctionTypeAttrName(result.name),
5858
TypeAttr::get(grType));
5959

@@ -136,26 +136,22 @@ LogicalResult spirv::GraphARMOp::verifyBody() {
136136
}
137137

138138
GraphType grType = getFunctionType();
139-
auto walkResult = walk([grType](Operation *op) -> WalkResult {
140-
if (auto graphOutputsARMOp = dyn_cast<spirv::GraphOutputsARMOp>(op)) {
141-
if (grType.getNumResults() != graphOutputsARMOp.getNumOperands())
142-
return graphOutputsARMOp.emitOpError("is returning ")
143-
<< graphOutputsARMOp.getNumOperands()
144-
<< " value(s) but enclosing spirv.ARM.Graph requires "
145-
<< grType.getNumResults() << " result(s)";
146-
147-
ValueTypeRange<OperandRange> graphOutputOperandTypes =
148-
graphOutputsARMOp.getValue().getType();
149-
for (unsigned i = 0, size = graphOutputOperandTypes.size(); i < size;
150-
++i) {
151-
Type graphOutputOperandType = graphOutputOperandTypes[i];
152-
Type grResultType = grType.getResult(i);
153-
if (graphOutputOperandType != grResultType)
154-
return graphOutputsARMOp.emitError("type of return operand ")
155-
<< i << " (" << graphOutputOperandType
156-
<< ") doesn't match graph result type (" << grResultType
157-
<< ")";
158-
}
139+
auto walkResult = walk([grType](spirv::GraphOutputsARMOp op) -> WalkResult {
140+
if (grType.getNumResults() != op.getNumOperands())
141+
return op.emitOpError("is returning ")
142+
<< op.getNumOperands()
143+
<< " value(s) but enclosing spirv.ARM.Graph requires "
144+
<< grType.getNumResults() << " result(s)";
145+
146+
ValueTypeRange<OperandRange> graphOutputOperandTypes =
147+
op.getValue().getType();
148+
for (unsigned i = 0, size = graphOutputOperandTypes.size(); i < size; ++i) {
149+
Type graphOutputOperandType = graphOutputOperandTypes[i];
150+
Type grResultType = grType.getResult(i);
151+
if (graphOutputOperandType != grResultType)
152+
return op.emitError("type of return operand ")
153+
<< i << " (" << graphOutputOperandType
154+
<< ") doesn't match graph result type (" << grResultType << ")";
159155
}
160156
return WalkResult::advance();
161157
});
@@ -169,23 +165,20 @@ void spirv::GraphARMOp::build(OpBuilder &builder, OperationState &state,
169165
state.addAttribute(SymbolTable::getSymbolAttrName(),
170166
builder.getStringAttr(name));
171167
state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
172-
state.attributes.append(attrs.begin(), attrs.end());
168+
state.attributes.append(attrs);
173169
state.addAttribute(getEntryPointAttrName(state.name),
174170
builder.getBoolAttr(entryPoint));
175171
state.addRegion();
176172
}
177173

178-
// Returns the argument types of this function.
179174
ArrayRef<Type> spirv::GraphARMOp::getArgumentTypes() {
180175
return getFunctionType().getInputs();
181176
}
182177

183-
// Returns the result types of this function.
184178
ArrayRef<Type> spirv::GraphARMOp::getResultTypes() {
185179
return getFunctionType().getResults();
186180
}
187181

188-
// CallableOpInterface
189182
Region *spirv::GraphARMOp::getCallableRegion() {
190183
return isExternal() ? nullptr : &getBody();
191184
}
@@ -229,12 +222,11 @@ void spirv::GraphEntryPointARMOp::build(OpBuilder &builder,
229222

230223
ParseResult spirv::GraphEntryPointARMOp::parse(OpAsmParser &parser,
231224
OperationState &result) {
232-
SmallVector<Attribute, 4> interfaceVars;
233-
234225
FlatSymbolRefAttr fn;
235226
if (parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes))
236227
return failure();
237228

229+
SmallVector<Attribute, 4> interfaceVars;
238230
if (!parser.parseOptionalComma()) {
239231
// Parse the interface variables
240232
if (parser.parseCommaSeparatedList([&]() -> ParseResult {
@@ -258,7 +250,6 @@ void spirv::GraphEntryPointARMOp::print(OpAsmPrinter &printer) {
258250
printer.printSymbolName(getFn());
259251
ArrayRef<Attribute> interfaceVars = getInterface().getValue();
260252
if (!interfaceVars.empty()) {
261-
printer << ", ";
262-
llvm::interleaveComma(interfaceVars, printer);
253+
printer << ", " << llvm::interleaved(interfaceVars);
263254
}
264255
}

mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,8 @@ void UpdateVCEPass::runOnOperation() {
160160

161161
// If the op is FunctionLike make sure to process input and result types
162162
if (auto funcOpInterface = dyn_cast<FunctionOpInterface>(op)) {
163-
ArrayRef<Type> inputTypes = funcOpInterface.getArgumentTypes();
164-
ArrayRef<Type> resultTypes = funcOpInterface.getResultTypes();
165-
valueTypes.append(inputTypes.begin(), inputTypes.end());
166-
valueTypes.append(resultTypes.begin(), resultTypes.end());
163+
llvm::append_range(valueTypes, funcOpInterface.getArgumentTypes());
164+
llvm::append_range(valueTypes, funcOpInterface.getResultTypes());
167165
}
168166

169167
// Requirements from values' types

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,7 @@ void OpAsmPrinter::printFunctionalType(Operation *op) {
104104
// it is a function (avoiding a grammar ambiguity).
105105
bool wrapped = op->getNumResults() != 1;
106106
if (!wrapped && op->getResult(0).getType() &&
107-
(llvm::isa<FunctionType>(op->getResult(0).getType()) ||
108-
llvm::isa<GraphType>(op->getResult(0).getType())))
107+
isa<GraphType>(op->getResult(0).getType()))
109108
wrapped = true;
110109

111110
if (wrapped)
@@ -2842,8 +2841,8 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
28422841
interleaveComma(graphTy.getInputs(), [&](Type ty) { printType(ty); });
28432842
os << ") -> ";
28442843
ArrayRef<Type> results = graphTy.getResults();
2845-
if (results.size() == 1 && !(llvm::isa<FunctionType>(results[0]) ||
2846-
llvm::isa<GraphType>(results[0]))) {
2844+
if (results.size() == 1 &&
2845+
!(isa<FunctionType>(results[0]) || isa<GraphType>(results[0]))) {
28472846
printType(results[0]);
28482847
} else {
28492848
os << '(';

mlir/test/Dialect/SPIRV/IR/graph-ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
1+
// RUN: mlir-opt --split-input-file --verify-diagnostics %s | FileCheck %s
22

33
//===----------------------------------------------------------------------===//
44
// spirv.ARM.Graph and spirv.ARM.GraphOutputs

mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ struct PrintOpAvailability
3333
} // namespace
3434

3535
void PrintOpAvailability::runOnOperation() {
36-
auto moduleOp = getOperation();
36+
mlir::ModuleOp moduleOp = getOperation();
3737
Dialect *spirvDialect = getContext().getLoadedDialect("spirv");
3838

3939
auto opCallback = [&](Operation *op) {

0 commit comments

Comments
 (0)