22
22
#include " mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
23
23
#include " mlir/IR/BuiltinAttributes.h"
24
24
#include " mlir/Transforms/DialectConversion.h"
25
+ #include " llvm/Support/FormatVariadic.h"
25
26
26
27
namespace mlir {
27
28
namespace spirv {
@@ -85,10 +86,36 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
85
86
abiInfo.getBinding ());
86
87
}
87
88
89
+ // / Creates a global variable for an argument or result based on the ABI info.
90
+ static spirv::GlobalVariableOp
91
+ createGlobalVarForGraphEntryPoint (OpBuilder &builder, spirv::GraphARMOp graphOp,
92
+ unsigned index, bool isArg,
93
+ spirv::InterfaceVarABIAttr abiInfo) {
94
+ auto spirvModule = graphOp->getParentOfType <spirv::ModuleOp>();
95
+ if (!spirvModule)
96
+ return nullptr ;
97
+
98
+ OpBuilder::InsertionGuard moduleInsertionGuard (builder);
99
+ builder.setInsertionPoint (graphOp.getOperation ());
100
+ std::string varName = llvm::formatv (" {}_{}_{}" , graphOp.getName (),
101
+ isArg ? " arg" : " res" , index);
102
+
103
+ Type varType = isArg ? graphOp.getFunctionType ().getInput (index)
104
+ : graphOp.getFunctionType ().getResult (index);
105
+
106
+ auto pointerType = spirv::PointerType::get (
107
+ varType,
108
+ abiInfo.getStorageClass ().value_or (spirv::StorageClass::UniformConstant));
109
+
110
+ return spirv::GlobalVariableOp::create (builder, graphOp.getLoc (), pointerType,
111
+ varName, abiInfo.getDescriptorSet (),
112
+ abiInfo.getBinding ());
113
+ }
114
+
88
115
// / Gets the global variables that need to be specified as interface variable
89
116
// / with an spirv.EntryPointOp. Traverses the body of a entry function to do so.
90
117
static LogicalResult
91
- getInterfaceVariables (spirv::FuncOp funcOp,
118
+ getInterfaceVariables (mlir::FunctionOpInterface funcOp,
92
119
SmallVectorImpl<Attribute> &interfaceVars) {
93
120
auto module = funcOp->getParentOfType <spirv::ModuleOp>();
94
121
if (!module ) {
@@ -224,6 +251,21 @@ class ProcessInterfaceVarABI final : public OpConversionPattern<spirv::FuncOp> {
224
251
ConversionPatternRewriter &rewriter) const override ;
225
252
};
226
253
254
+ // / A pattern to convert graph signature according to interface variable ABI
255
+ // / attributes.
256
+ // /
257
+ // / Specifically, this pattern creates global variables according to interface
258
+ // / variable ABI attributes attached to graph arguments and results.
259
+ class ProcessGraphInterfaceVarABI final
260
+ : public OpConversionPattern<spirv::GraphARMOp> {
261
+ public:
262
+ using OpConversionPattern::OpConversionPattern;
263
+
264
+ LogicalResult
265
+ matchAndRewrite (spirv::GraphARMOp graphOp, OpAdaptor adaptor,
266
+ ConversionPatternRewriter &rewriter) const override ;
267
+ };
268
+
227
269
// / Pass to implement the ABI information specified as attributes.
228
270
class LowerABIAttributesPass final
229
271
: public spirv::impl::SPIRVLowerABIAttributesPassBase<
@@ -297,6 +339,63 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
297
339
return success ();
298
340
}
299
341
342
+ LogicalResult ProcessGraphInterfaceVarABI::matchAndRewrite (
343
+ spirv::GraphARMOp graphOp, OpAdaptor adaptor,
344
+ ConversionPatternRewriter &rewriter) const {
345
+ // Non-entry point graphs are not handled.
346
+ if (!graphOp.getEntryPoint ().value_or (false ))
347
+ return failure ();
348
+
349
+ TypeConverter::SignatureConversion signatureConverter (
350
+ graphOp.getFunctionType ().getNumInputs ());
351
+
352
+ StringRef attrName = spirv::getInterfaceVarABIAttrName ();
353
+ SmallVector<Attribute, 4 > interfaceVars;
354
+
355
+ // Convert arguments.
356
+ unsigned numInputs = graphOp.getFunctionType ().getNumInputs ();
357
+ unsigned numResults = graphOp.getFunctionType ().getNumResults ();
358
+ for (unsigned index = 0 ; index < numInputs; ++index) {
359
+ auto abiInfo =
360
+ graphOp.getArgAttrOfType <spirv::InterfaceVarABIAttr>(index, attrName);
361
+ if (!abiInfo)
362
+ return failure ();
363
+ spirv::GlobalVariableOp var = createGlobalVarForGraphEntryPoint (
364
+ rewriter, graphOp, index, true , abiInfo);
365
+ if (!var)
366
+ return failure ();
367
+ interfaceVars.push_back (
368
+ SymbolRefAttr::get (rewriter.getContext (), var.getSymName ()));
369
+ }
370
+
371
+ for (unsigned index = 0 ; index < numResults; ++index) {
372
+ auto abiInfo = graphOp.getResultAttrOfType <spirv::InterfaceVarABIAttr>(
373
+ index, attrName);
374
+ if (!abiInfo)
375
+ return failure ();
376
+ spirv::GlobalVariableOp var = createGlobalVarForGraphEntryPoint (
377
+ rewriter, graphOp, index, false , abiInfo);
378
+ if (!var)
379
+ return failure ();
380
+ interfaceVars.push_back (
381
+ SymbolRefAttr::get (rewriter.getContext (), var.getSymName ()));
382
+ }
383
+
384
+ // Update graph signature.
385
+ rewriter.modifyOpInPlace (graphOp, [&] {
386
+ for (unsigned index = 0 ; index < numInputs; ++index) {
387
+ graphOp.removeArgAttr (index, attrName);
388
+ }
389
+ for (unsigned index = 0 ; index < numResults; ++index) {
390
+ graphOp.removeResultAttr (index, rewriter.getStringAttr (attrName));
391
+ }
392
+ });
393
+
394
+ spirv::GraphEntryPointARMOp::create (rewriter, graphOp.getLoc (), graphOp,
395
+ interfaceVars);
396
+ return success ();
397
+ }
398
+
300
399
void LowerABIAttributesPass::runOnOperation () {
301
400
// Uses the signature conversion methodology of the dialect conversion
302
401
// framework to implement the conversion.
@@ -322,7 +421,8 @@ void LowerABIAttributesPass::runOnOperation() {
322
421
});
323
422
324
423
RewritePatternSet patterns (context);
325
- patterns.add <ProcessInterfaceVarABI>(typeConverter, context);
424
+ patterns.add <ProcessInterfaceVarABI, ProcessGraphInterfaceVarABI>(
425
+ typeConverter, context);
326
426
327
427
ConversionTarget target (*context);
328
428
// "Legal" function ops should have no interface variable ABI attributes.
@@ -333,6 +433,17 @@ void LowerABIAttributesPass::runOnOperation() {
333
433
return false ;
334
434
return true ;
335
435
});
436
+ target.addDynamicallyLegalOp <spirv::GraphARMOp>([&](spirv::GraphARMOp op) {
437
+ StringRef attrName = spirv::getInterfaceVarABIAttrName ();
438
+ for (unsigned i = 0 , e = op.getNumArguments (); i < e; ++i)
439
+ if (op.getArgAttr (i, attrName))
440
+ return false ;
441
+ for (unsigned i = 0 , e = op.getNumResults (); i < e; ++i)
442
+ if (op.getResultAttr (i, attrName))
443
+ return false ;
444
+ return true ;
445
+ });
446
+
336
447
// All other SPIR-V ops are legal.
337
448
target.markUnknownOpDynamicallyLegal ([](Operation *op) {
338
449
return op->getDialect ()->getNamespace () ==
0 commit comments