Skip to content

Commit 1a4f0aa

Browse files
Extract SPV_ARM_graph operations in its own file
Signed-off-by: Davide Grohmann <[email protected]> Change-Id: Ia74db44157fb724c9f787387386338814413db30
1 parent 5ae57fe commit 1a4f0aa

File tree

3 files changed

+265
-237
lines changed

3 files changed

+265
-237
lines changed
Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
//===- ArmGraphOps.cpp - MLIR SPIR-V SPV_ARM_graph operations
2+
//------------------------------===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
//
10+
// This file defines the SPV_ARM_graph operations in the SPIR-V dialect.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
15+
16+
#include "SPIRVParsingUtils.h"
17+
18+
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
19+
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
20+
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
21+
#include "mlir/IR/Builders.h"
22+
#include "mlir/IR/BuiltinTypes.h"
23+
#include "mlir/IR/Operation.h"
24+
#include "mlir/Interfaces/FunctionImplementation.h"
25+
26+
using namespace mlir;
27+
using namespace mlir::spirv::AttrNames;
28+
29+
//===----------------------------------------------------------------------===//
30+
// spirv.GraphARM
31+
//===----------------------------------------------------------------------===//
32+
33+
ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser,
34+
OperationState &result) {
35+
SmallVector<OpAsmParser::Argument> entryArgs;
36+
SmallVector<DictionaryAttr> resultAttrs;
37+
SmallVector<Type> resultTypes;
38+
auto &builder = parser.getBuilder();
39+
40+
// Parse the name as a symbol.
41+
StringAttr nameAttr;
42+
if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
43+
result.attributes))
44+
return failure();
45+
46+
// Parse the function signature.
47+
bool isVariadic = false;
48+
if (function_interface_impl::parseFunctionSignatureWithArguments(
49+
parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
50+
resultAttrs))
51+
return failure();
52+
53+
SmallVector<Type> argTypes;
54+
for (auto &arg : entryArgs)
55+
argTypes.push_back(arg.type);
56+
auto grType = builder.getGraphType(argTypes, resultTypes);
57+
result.addAttribute(getFunctionTypeAttrName(result.name),
58+
TypeAttr::get(grType));
59+
60+
// If additional attributes are present, parse them.
61+
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
62+
return failure();
63+
64+
// Add the attributes to the function arguments.
65+
assert(resultAttrs.size() == resultTypes.size());
66+
call_interface_impl::addArgAndResultAttrs(
67+
builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
68+
getResAttrsAttrName(result.name));
69+
70+
// Parse the optional function body.
71+
Region *body = result.addRegion();
72+
OptionalParseResult parseResult =
73+
parser.parseOptionalRegion(*body, entryArgs);
74+
return failure(parseResult.has_value() && failed(*parseResult));
75+
}
76+
77+
void spirv::GraphARMOp::print(OpAsmPrinter &printer) {
78+
// Print graph name, signature, and control.
79+
printer << " ";
80+
printer.printSymbolName(getSymName());
81+
GraphType grType = getFunctionType();
82+
function_interface_impl::printFunctionSignature(
83+
printer, *this, grType.getInputs(),
84+
/*isVariadic=*/false, grType.getResults());
85+
function_interface_impl::printFunctionAttributes(printer, *this,
86+
{getFunctionTypeAttrName(),
87+
getArgAttrsAttrName(),
88+
getResAttrsAttrName()});
89+
90+
// Print the body.
91+
Region &body = this->getBody();
92+
if (!body.empty()) {
93+
printer << ' ';
94+
printer.printRegion(body, /*printEntryBlockArgs=*/false,
95+
/*printBlockTerminators=*/true);
96+
}
97+
}
98+
99+
LogicalResult spirv::GraphARMOp::verifyType() {
100+
if (getFunctionType().getNumResults() < 1)
101+
return emitOpError("there should be at least one result");
102+
return success();
103+
}
104+
105+
LogicalResult spirv::GraphARMOp::verifyBody() {
106+
for (auto [index, graphArgType] : llvm::enumerate(getArgumentTypes())) {
107+
if (!isa<spirv::TensorArmType>(graphArgType)) {
108+
return emitOpError("type of argument #")
109+
<< index << " must be a TensorArmType, but got " << graphArgType;
110+
}
111+
}
112+
for (auto [index, graphResType] : llvm::enumerate(getResultTypes())) {
113+
if (!isa<spirv::TensorArmType>(graphResType)) {
114+
return emitOpError("type of result #")
115+
<< index << " must be a TensorArmType, but got " << graphResType;
116+
}
117+
}
118+
119+
if (!isExternal()) {
120+
Block &entryBlock = front();
121+
122+
unsigned numArguments = this->getNumArguments();
123+
if (entryBlock.getNumArguments() != numArguments)
124+
return emitOpError("entry block must have ")
125+
<< numArguments << " arguments to match graph signature";
126+
127+
for (auto [index, grArgType, blockArgType] :
128+
llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) {
129+
if (blockArgType != grArgType) {
130+
return emitOpError("type of entry block argument #")
131+
<< index << '(' << blockArgType
132+
<< ") must match the type of the corresponding argument in "
133+
<< "graph signature(" << grArgType << ')';
134+
}
135+
}
136+
}
137+
138+
GraphType grType = getFunctionType();
139+
auto walkResult = walk([grType](Operation *op) -> WalkResult {
140+
if (auto graphOutputsARMOp = dyn_cast<spirv::GraphOutputsARMOp>(op)) {
141+
if (grType.getNumResults() != graphOutputsARMOp.getNumOperands())
142+
return graphOutputsARMOp.emitOpError("is returning ")
143+
<< graphOutputsARMOp.getNumOperands()
144+
<< " value(s) but enclosing spirv.ARM.Graph requires "
145+
<< grType.getNumResults() << " result(s)";
146+
147+
ValueTypeRange<OperandRange> graphOutputOperandTypes =
148+
graphOutputsARMOp.getValue().getType();
149+
for (unsigned i = 0, size = graphOutputOperandTypes.size(); i < size;
150+
++i) {
151+
Type graphOutputOperandType = graphOutputOperandTypes[i];
152+
Type grResultType = grType.getResult(i);
153+
if (graphOutputOperandType != grResultType)
154+
return graphOutputsARMOp.emitError("type of return operand ")
155+
<< i << " (" << graphOutputOperandType
156+
<< ") doesn't match graph result type (" << grResultType
157+
<< ")";
158+
}
159+
}
160+
return WalkResult::advance();
161+
});
162+
163+
return failure(walkResult.wasInterrupted());
164+
}
165+
166+
void spirv::GraphARMOp::build(OpBuilder &builder, OperationState &state,
167+
StringRef name, GraphType type,
168+
ArrayRef<NamedAttribute> attrs, bool entryPoint) {
169+
state.addAttribute(SymbolTable::getSymbolAttrName(),
170+
builder.getStringAttr(name));
171+
state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
172+
state.attributes.append(attrs.begin(), attrs.end());
173+
state.addAttribute(getEntryPointAttrName(state.name),
174+
builder.getBoolAttr(entryPoint));
175+
state.addRegion();
176+
}
177+
178+
// Returns the argument types of this function.
179+
ArrayRef<Type> spirv::GraphARMOp::getArgumentTypes() {
180+
return getFunctionType().getInputs();
181+
}
182+
183+
// Returns the result types of this function.
184+
ArrayRef<Type> spirv::GraphARMOp::getResultTypes() {
185+
return getFunctionType().getResults();
186+
}
187+
188+
// CallableOpInterface
189+
Region *spirv::GraphARMOp::getCallableRegion() {
190+
return isExternal() ? nullptr : &getBody();
191+
}
192+
193+
//===----------------------------------------------------------------------===//
194+
// spirv.GraphOutputsARM
195+
//===----------------------------------------------------------------------===//
196+
197+
LogicalResult spirv::GraphOutputsARMOp::verify() {
198+
auto graph = cast<GraphARMOp>((*this)->getParentOp());
199+
200+
// The operand number and types must match the graph signature.
201+
const ArrayRef<Type> &results = graph.getFunctionType().getResults();
202+
if (getNumOperands() != results.size())
203+
return emitOpError("has ")
204+
<< getNumOperands() << " operands, but enclosing spirv.ARM.Graph (@"
205+
<< graph.getName() << ") returns " << results.size();
206+
207+
for (unsigned i = 0, size = results.size(); i < size; ++i)
208+
if (getOperand(i).getType() != results[i])
209+
return emitError() << "type of return operand " << i << " ("
210+
<< getOperand(i).getType()
211+
<< ") doesn't match spirv.ARM.Graph result type ("
212+
<< results[i] << ")"
213+
<< " in graph @" << graph.getName();
214+
215+
return success();
216+
}
217+
218+
//===----------------------------------------------------------------------===//
219+
// spirv.GraphEntryPointARM
220+
//===----------------------------------------------------------------------===//
221+
222+
void spirv::GraphEntryPointARMOp::build(OpBuilder &builder,
223+
OperationState &state,
224+
spirv::GraphARMOp graph,
225+
ArrayRef<Attribute> interfaceVars) {
226+
build(builder, state, SymbolRefAttr::get(graph),
227+
builder.getArrayAttr(interfaceVars));
228+
}
229+
230+
ParseResult spirv::GraphEntryPointARMOp::parse(OpAsmParser &parser,
231+
OperationState &result) {
232+
SmallVector<Attribute, 4> interfaceVars;
233+
234+
FlatSymbolRefAttr fn;
235+
if (parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes))
236+
return failure();
237+
238+
if (!parser.parseOptionalComma()) {
239+
// Parse the interface variables
240+
if (parser.parseCommaSeparatedList([&]() -> ParseResult {
241+
// The name of the interface variable attribute isnt important
242+
FlatSymbolRefAttr var;
243+
NamedAttrList attrs;
244+
if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
245+
return failure();
246+
interfaceVars.push_back(var);
247+
return success();
248+
}))
249+
return failure();
250+
}
251+
result.addAttribute("interface",
252+
parser.getBuilder().getArrayAttr(interfaceVars));
253+
return success();
254+
}
255+
256+
void spirv::GraphEntryPointARMOp::print(OpAsmPrinter &printer) {
257+
printer << " ";
258+
printer.printSymbolName(getFn());
259+
ArrayRef<Attribute> interfaceVars = getInterface().getValue();
260+
if (!interfaceVars.empty()) {
261+
printer << ", ";
262+
llvm::interleaveComma(interfaceVars, printer);
263+
}
264+
}

mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ mlir_tablegen(SPIRVCanonicalization.inc -gen-rewriters)
33
add_public_tablegen_target(MLIRSPIRVCanonicalizationIncGen)
44

55
add_mlir_dialect_library(MLIRSPIRVDialect
6+
ArmGraphOps.cpp
67
AtomicOps.cpp
78
CastOps.cpp
89
ControlFlowOps.cpp

0 commit comments

Comments
 (0)