2222#include " mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
2323#include " mlir/IR/BuiltinAttributes.h"
2424#include " mlir/Transforms/DialectConversion.h"
25+ #include " llvm/Support/FormatVariadic.h"
2526
2627namespace mlir {
2728namespace spirv {
@@ -85,10 +86,36 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
8586 abiInfo.getBinding ());
8687}
8788
89+ // / Creates a global variable for an argument or result based on the ABI info.
90+ static spirv::GlobalVariableOp
91+ createGlobalVarForGraphEntryPoint (OpBuilder &builder, spirv::GraphARMOp graphOp,
92+ unsigned index, bool isArg,
93+ spirv::InterfaceVarABIAttr abiInfo) {
94+ auto spirvModule = graphOp->getParentOfType <spirv::ModuleOp>();
95+ if (!spirvModule)
96+ return nullptr ;
97+
98+ OpBuilder::InsertionGuard moduleInsertionGuard (builder);
99+ builder.setInsertionPoint (graphOp.getOperation ());
100+ std::string varName = llvm::formatv (" {}_{}_{}" , graphOp.getName (),
101+ isArg ? " arg" : " res" , index);
102+
103+ Type varType = isArg ? graphOp.getFunctionType ().getInput (index)
104+ : graphOp.getFunctionType ().getResult (index);
105+
106+ auto pointerType = spirv::PointerType::get (
107+ varType,
108+ abiInfo.getStorageClass ().value_or (spirv::StorageClass::UniformConstant));
109+
110+ return spirv::GlobalVariableOp::create (builder, graphOp.getLoc (), pointerType,
111+ varName, abiInfo.getDescriptorSet (),
112+ abiInfo.getBinding ());
113+ }
114+
88115// / Gets the global variables that need to be specified as interface variable
89116// / with an spirv.EntryPointOp. Traverses the body of a entry function to do so.
90117static LogicalResult
91- getInterfaceVariables (spirv::FuncOp funcOp,
118+ getInterfaceVariables (mlir::FunctionOpInterface funcOp,
92119 SmallVectorImpl<Attribute> &interfaceVars) {
93120 auto module = funcOp->getParentOfType <spirv::ModuleOp>();
94121 if (!module ) {
@@ -224,6 +251,21 @@ class ProcessInterfaceVarABI final : public OpConversionPattern<spirv::FuncOp> {
224251 ConversionPatternRewriter &rewriter) const override ;
225252};
226253
254+ // / A pattern to convert graph signature according to interface variable ABI
255+ // / attributes.
256+ // /
257+ // / Specifically, this pattern creates global variables according to interface
258+ // / variable ABI attributes attached to graph arguments and results.
259+ class ProcessGraphInterfaceVarABI final
260+ : public OpConversionPattern<spirv::GraphARMOp> {
261+ public:
262+ using OpConversionPattern::OpConversionPattern;
263+
264+ LogicalResult
265+ matchAndRewrite (spirv::GraphARMOp graphOp, OpAdaptor adaptor,
266+ ConversionPatternRewriter &rewriter) const override ;
267+ };
268+
227269// / Pass to implement the ABI information specified as attributes.
228270class LowerABIAttributesPass final
229271 : public spirv::impl::SPIRVLowerABIAttributesPassBase<
@@ -297,6 +339,63 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
297339 return success ();
298340}
299341
342+ LogicalResult ProcessGraphInterfaceVarABI::matchAndRewrite (
343+ spirv::GraphARMOp graphOp, OpAdaptor adaptor,
344+ ConversionPatternRewriter &rewriter) const {
345+ // Non-entry point graphs are not handled.
346+ if (!graphOp.getEntryPoint ().value_or (false ))
347+ return failure ();
348+
349+ TypeConverter::SignatureConversion signatureConverter (
350+ graphOp.getFunctionType ().getNumInputs ());
351+
352+ StringRef attrName = spirv::getInterfaceVarABIAttrName ();
353+ SmallVector<Attribute, 4 > interfaceVars;
354+
355+ // Convert arguments.
356+ unsigned numInputs = graphOp.getFunctionType ().getNumInputs ();
357+ unsigned numResults = graphOp.getFunctionType ().getNumResults ();
358+ for (unsigned index = 0 ; index < numInputs; ++index) {
359+ auto abiInfo =
360+ graphOp.getArgAttrOfType <spirv::InterfaceVarABIAttr>(index, attrName);
361+ if (!abiInfo)
362+ return failure ();
363+ spirv::GlobalVariableOp var = createGlobalVarForGraphEntryPoint (
364+ rewriter, graphOp, index, true , abiInfo);
365+ if (!var)
366+ return failure ();
367+ interfaceVars.push_back (
368+ SymbolRefAttr::get (rewriter.getContext (), var.getSymName ()));
369+ }
370+
371+ for (unsigned index = 0 ; index < numResults; ++index) {
372+ auto abiInfo = graphOp.getResultAttrOfType <spirv::InterfaceVarABIAttr>(
373+ index, attrName);
374+ if (!abiInfo)
375+ return failure ();
376+ spirv::GlobalVariableOp var = createGlobalVarForGraphEntryPoint (
377+ rewriter, graphOp, index, false , abiInfo);
378+ if (!var)
379+ return failure ();
380+ interfaceVars.push_back (
381+ SymbolRefAttr::get (rewriter.getContext (), var.getSymName ()));
382+ }
383+
384+ // Update graph signature.
385+ rewriter.modifyOpInPlace (graphOp, [&] {
386+ for (unsigned index = 0 ; index < numInputs; ++index) {
387+ graphOp.removeArgAttr (index, attrName);
388+ }
389+ for (unsigned index = 0 ; index < numResults; ++index) {
390+ graphOp.removeResultAttr (index, rewriter.getStringAttr (attrName));
391+ }
392+ });
393+
394+ spirv::GraphEntryPointARMOp::create (rewriter, graphOp.getLoc (), graphOp,
395+ interfaceVars);
396+ return success ();
397+ }
398+
300399void LowerABIAttributesPass::runOnOperation () {
301400 // Uses the signature conversion methodology of the dialect conversion
302401 // framework to implement the conversion.
@@ -322,7 +421,8 @@ void LowerABIAttributesPass::runOnOperation() {
322421 });
323422
324423 RewritePatternSet patterns (context);
325- patterns.add <ProcessInterfaceVarABI>(typeConverter, context);
424+ patterns.add <ProcessInterfaceVarABI, ProcessGraphInterfaceVarABI>(
425+ typeConverter, context);
326426
327427 ConversionTarget target (*context);
328428 // "Legal" function ops should have no interface variable ABI attributes.
@@ -333,6 +433,17 @@ void LowerABIAttributesPass::runOnOperation() {
333433 return false ;
334434 return true ;
335435 });
436+ target.addDynamicallyLegalOp <spirv::GraphARMOp>([&](spirv::GraphARMOp op) {
437+ StringRef attrName = spirv::getInterfaceVarABIAttrName ();
438+ for (unsigned i = 0 , e = op.getNumArguments (); i < e; ++i)
439+ if (op.getArgAttr (i, attrName))
440+ return false ;
441+ for (unsigned i = 0 , e = op.getNumResults (); i < e; ++i)
442+ if (op.getResultAttr (i, attrName))
443+ return false ;
444+ return true ;
445+ });
446+
336447 // All other SPIR-V ops are legal.
337448 target.markUnknownOpDynamicallyLegal ([](Operation *op) {
338449 return op->getDialect ()->getNamespace () ==
0 commit comments