1- // ===- ArmGraphOps.cpp - MLIR SPIR-V SPV_ARM_graph operations
2- // ------------------------------===//
1+ // ===- ArmGraphOps.cpp - MLIR SPIR-V SPV_ARM_graph operations -------------===//
32//
43// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
54// See https://llvm.org/LICENSE.txt for license information.
2221#include " mlir/IR/BuiltinTypes.h"
2322#include " mlir/IR/Operation.h"
2423#include " mlir/Interfaces/FunctionImplementation.h"
24+ #include " llvm/Support/InterleavedRange.h"
2525
2626using namespace mlir ;
2727using namespace mlir ::spirv::AttrNames;
@@ -32,10 +32,7 @@ using namespace mlir::spirv::AttrNames;
3232
3333ParseResult spirv::GraphARMOp::parse (OpAsmParser &parser,
3434 OperationState &result) {
35- SmallVector<OpAsmParser::Argument> entryArgs;
36- SmallVector<DictionaryAttr> resultAttrs;
37- SmallVector<Type> resultTypes;
38- auto &builder = parser.getBuilder ();
35+ Builder &builder = parser.getBuilder ();
3936
4037 // Parse the name as a symbol.
4138 StringAttr nameAttr;
@@ -45,15 +42,18 @@ ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser,
4542
4643 // Parse the function signature.
4744 bool isVariadic = false ;
45+ SmallVector<OpAsmParser::Argument> entryArgs;
46+ SmallVector<Type> resultTypes;
47+ SmallVector<DictionaryAttr> resultAttrs;
4848 if (function_interface_impl::parseFunctionSignatureWithArguments (
4949 parser, /* allowVariadic=*/ false , entryArgs, isVariadic, resultTypes,
5050 resultAttrs))
5151 return failure ();
5252
5353 SmallVector<Type> argTypes;
54- for (auto &arg : entryArgs)
54+ for (OpAsmParser::Argument &arg : entryArgs)
5555 argTypes.push_back (arg.type );
56- auto grType = builder.getGraphType (argTypes, resultTypes);
56+ GraphType grType = builder.getGraphType (argTypes, resultTypes);
5757 result.addAttribute (getFunctionTypeAttrName (result.name ),
5858 TypeAttr::get (grType));
5959
@@ -136,26 +136,22 @@ LogicalResult spirv::GraphARMOp::verifyBody() {
136136 }
137137
138138 GraphType grType = getFunctionType ();
139- auto walkResult = walk ([grType](Operation *op) -> WalkResult {
140- if (auto graphOutputsARMOp = dyn_cast<spirv::GraphOutputsARMOp>(op)) {
141- if (grType.getNumResults () != graphOutputsARMOp.getNumOperands ())
142- return graphOutputsARMOp.emitOpError (" is returning " )
143- << graphOutputsARMOp.getNumOperands ()
144- << " value(s) but enclosing spirv.ARM.Graph requires "
145- << grType.getNumResults () << " result(s)" ;
146-
147- ValueTypeRange<OperandRange> graphOutputOperandTypes =
148- graphOutputsARMOp.getValue ().getType ();
149- for (unsigned i = 0 , size = graphOutputOperandTypes.size (); i < size;
150- ++i) {
151- Type graphOutputOperandType = graphOutputOperandTypes[i];
152- Type grResultType = grType.getResult (i);
153- if (graphOutputOperandType != grResultType)
154- return graphOutputsARMOp.emitError (" type of return operand " )
155- << i << " (" << graphOutputOperandType
156- << " ) doesn't match graph result type (" << grResultType
157- << " )" ;
158- }
139+ auto walkResult = walk ([grType](spirv::GraphOutputsARMOp op) -> WalkResult {
140+ if (grType.getNumResults () != op.getNumOperands ())
141+ return op.emitOpError (" is returning " )
142+ << op.getNumOperands ()
143+ << " value(s) but enclosing spirv.ARM.Graph requires "
144+ << grType.getNumResults () << " result(s)" ;
145+
146+ ValueTypeRange<OperandRange> graphOutputOperandTypes =
147+ op.getValue ().getType ();
148+ for (unsigned i = 0 , size = graphOutputOperandTypes.size (); i < size; ++i) {
149+ Type graphOutputOperandType = graphOutputOperandTypes[i];
150+ Type grResultType = grType.getResult (i);
151+ if (graphOutputOperandType != grResultType)
152+ return op.emitError (" type of return operand " )
153+ << i << " (" << graphOutputOperandType
154+ << " ) doesn't match graph result type (" << grResultType << " )" ;
159155 }
160156 return WalkResult::advance ();
161157 });
@@ -169,23 +165,20 @@ void spirv::GraphARMOp::build(OpBuilder &builder, OperationState &state,
169165 state.addAttribute (SymbolTable::getSymbolAttrName (),
170166 builder.getStringAttr (name));
171167 state.addAttribute (getFunctionTypeAttrName (state.name ), TypeAttr::get (type));
172- state.attributes .append (attrs. begin (), attrs. end () );
168+ state.attributes .append (attrs);
173169 state.addAttribute (getEntryPointAttrName (state.name ),
174170 builder.getBoolAttr (entryPoint));
175171 state.addRegion ();
176172}
177173
178- // Returns the argument types of this function.
179174ArrayRef<Type> spirv::GraphARMOp::getArgumentTypes () {
180175 return getFunctionType ().getInputs ();
181176}
182177
183- // Returns the result types of this function.
184178ArrayRef<Type> spirv::GraphARMOp::getResultTypes () {
185179 return getFunctionType ().getResults ();
186180}
187181
188- // CallableOpInterface
189182Region *spirv::GraphARMOp::getCallableRegion () {
190183 return isExternal () ? nullptr : &getBody ();
191184}
@@ -229,12 +222,11 @@ void spirv::GraphEntryPointARMOp::build(OpBuilder &builder,
229222
230223ParseResult spirv::GraphEntryPointARMOp::parse (OpAsmParser &parser,
231224 OperationState &result) {
232- SmallVector<Attribute, 4 > interfaceVars;
233-
234225 FlatSymbolRefAttr fn;
235226 if (parser.parseAttribute (fn, Type (), kFnNameAttrName , result.attributes ))
236227 return failure ();
237228
229+ SmallVector<Attribute, 4 > interfaceVars;
238230 if (!parser.parseOptionalComma ()) {
239231 // Parse the interface variables
240232 if (parser.parseCommaSeparatedList ([&]() -> ParseResult {
@@ -258,7 +250,6 @@ void spirv::GraphEntryPointARMOp::print(OpAsmPrinter &printer) {
258250 printer.printSymbolName (getFn ());
259251 ArrayRef<Attribute> interfaceVars = getInterface ().getValue ();
260252 if (!interfaceVars.empty ()) {
261- printer << " , " ;
262- llvm::interleaveComma (interfaceVars, printer);
253+ printer << " , " << llvm::interleaved (interfaceVars);
263254 }
264255}
0 commit comments