diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp index fcf1526491971..44c86bc8777e4 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -1066,7 +1066,12 @@ LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op, } LogicalResult SPIRVDialect::verifyRegionResultAttribute( - Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/, + Operation *op, unsigned /*regionIndex*/, unsigned resultIndex, NamedAttribute attribute) { - return op->emitError("cannot attach SPIR-V attributes to region result"); + if (auto graphOp = dyn_cast(op)) + return verifyRegionAttribute( + op->getLoc(), graphOp.getResultTypes()[resultIndex], attribute); + return op->emitError( + "cannot attach SPIR-V attributes to region result which is " + "not part of a spirv::GraphARMOp type"); } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp index 3911ec08fcc27..5607a3cd3660f 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -22,6 +22,7 @@ #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/FormatVariadic.h" namespace mlir { namespace spirv { @@ -85,10 +86,36 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp, abiInfo.getBinding()); } +/// Creates a global variable for an argument or result based on the ABI info. +static spirv::GlobalVariableOp +createGlobalVarForGraphEntryPoint(OpBuilder &builder, spirv::GraphARMOp graphOp, + unsigned index, bool isArg, + spirv::InterfaceVarABIAttr abiInfo) { + auto spirvModule = graphOp->getParentOfType(); + if (!spirvModule) + return nullptr; + + OpBuilder::InsertionGuard moduleInsertionGuard(builder); + builder.setInsertionPoint(graphOp.getOperation()); + std::string varName = llvm::formatv("{}_{}_{}", graphOp.getName(), + isArg ? "arg" : "res", index); + + Type varType = isArg ? graphOp.getFunctionType().getInput(index) + : graphOp.getFunctionType().getResult(index); + + auto pointerType = spirv::PointerType::get( + varType, + abiInfo.getStorageClass().value_or(spirv::StorageClass::UniformConstant)); + + return spirv::GlobalVariableOp::create(builder, graphOp.getLoc(), pointerType, + varName, abiInfo.getDescriptorSet(), + abiInfo.getBinding()); +} + /// Gets the global variables that need to be specified as interface variable /// with an spirv.EntryPointOp. Traverses the body of a entry function to do so. static LogicalResult -getInterfaceVariables(spirv::FuncOp funcOp, +getInterfaceVariables(mlir::FunctionOpInterface funcOp, SmallVectorImpl &interfaceVars) { auto module = funcOp->getParentOfType(); if (!module) { @@ -224,6 +251,21 @@ class ProcessInterfaceVarABI final : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override; }; +/// A pattern to convert graph signature according to interface variable ABI +/// attributes. +/// +/// Specifically, this pattern creates global variables according to interface +/// variable ABI attributes attached to graph arguments and results. +class ProcessGraphInterfaceVarABI final + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(spirv::GraphARMOp graphOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + /// Pass to implement the ABI information specified as attributes. class LowerABIAttributesPass final : public spirv::impl::SPIRVLowerABIAttributesPassBase< @@ -297,6 +339,63 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite( return success(); } +LogicalResult ProcessGraphInterfaceVarABI::matchAndRewrite( + spirv::GraphARMOp graphOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Non-entry point graphs are not handled. + if (!graphOp.getEntryPoint().value_or(false)) + return failure(); + + TypeConverter::SignatureConversion signatureConverter( + graphOp.getFunctionType().getNumInputs()); + + StringRef attrName = spirv::getInterfaceVarABIAttrName(); + SmallVector interfaceVars; + + // Convert arguments. + unsigned numInputs = graphOp.getFunctionType().getNumInputs(); + unsigned numResults = graphOp.getFunctionType().getNumResults(); + for (unsigned index = 0; index < numInputs; ++index) { + auto abiInfo = + graphOp.getArgAttrOfType(index, attrName); + if (!abiInfo) + return failure(); + spirv::GlobalVariableOp var = createGlobalVarForGraphEntryPoint( + rewriter, graphOp, index, true, abiInfo); + if (!var) + return failure(); + interfaceVars.push_back( + SymbolRefAttr::get(rewriter.getContext(), var.getSymName())); + } + + for (unsigned index = 0; index < numResults; ++index) { + auto abiInfo = graphOp.getResultAttrOfType( + index, attrName); + if (!abiInfo) + return failure(); + spirv::GlobalVariableOp var = createGlobalVarForGraphEntryPoint( + rewriter, graphOp, index, false, abiInfo); + if (!var) + return failure(); + interfaceVars.push_back( + SymbolRefAttr::get(rewriter.getContext(), var.getSymName())); + } + + // Update graph signature. + rewriter.modifyOpInPlace(graphOp, [&] { + for (unsigned index = 0; index < numInputs; ++index) { + graphOp.removeArgAttr(index, attrName); + } + for (unsigned index = 0; index < numResults; ++index) { + graphOp.removeResultAttr(index, rewriter.getStringAttr(attrName)); + } + }); + + spirv::GraphEntryPointARMOp::create(rewriter, graphOp.getLoc(), graphOp, + interfaceVars); + return success(); +} + void LowerABIAttributesPass::runOnOperation() { // Uses the signature conversion methodology of the dialect conversion // framework to implement the conversion. @@ -322,7 +421,8 @@ void LowerABIAttributesPass::runOnOperation() { }); RewritePatternSet patterns(context); - patterns.add(typeConverter, context); + patterns.add( + typeConverter, context); ConversionTarget target(*context); // "Legal" function ops should have no interface variable ABI attributes. @@ -333,6 +433,17 @@ void LowerABIAttributesPass::runOnOperation() { return false; return true; }); + target.addDynamicallyLegalOp([&](spirv::GraphARMOp op) { + StringRef attrName = spirv::getInterfaceVarABIAttrName(); + for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) + if (op.getArgAttr(i, attrName)) + return false; + for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) + if (op.getResultAttr(i, attrName)) + return false; + return true; + }); + // All other SPIR-V ops are legal. target.markUnknownOpDynamicallyLegal([](Operation *op) { return op->getDialect()->getNamespace() == diff --git a/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir b/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir index 10fbcf06eb052..63dea6af83556 100644 --- a/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir +++ b/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir @@ -101,6 +101,14 @@ func.func @interface_var( // ----- +// CHECK: {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>} +// CHECK: {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>} +spirv.ARM.Graph @interface_var(%arg: !spirv.arm.tensor<1xf32> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) -> ( + !spirv.arm.tensor<1xf32> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>} +) { spirv.ARM.GraphOutputs %arg : !spirv.arm.tensor<1xf32> } + +// ----- + //===----------------------------------------------------------------------===// // spirv.resource_limits //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir index f3a3218e5aec0..04667c828bbd1 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir @@ -35,6 +35,28 @@ spirv.module Logical GLSL450 { // ----- +module attributes { + spirv.target_env = #spirv.target_env< + #spirv.vce, #spirv.resource_limits<>> +} { + +// CHECK-LABEL: spirv.module +spirv.module Logical Vulkan { + // CHECK-DAG: spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr, UniformConstant> + // CHECK-DAG: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr, UniformConstant> + + // CHECK: spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]] + // CHECK: spirv.ARM.Graph [[GN]]([[ARG0:%.*]]: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> attributes {entry_point = true} + spirv.ARM.Graph @main(%arg0: !spirv.arm.tensor<1x16x16x16xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) + -> (!spirv.arm.tensor<1x16x16x16xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) attributes {entry_point = true} { + spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8> + } +} // end spirv.module + +} // end module + +// ----- + module { // expected-error@+1 {{'spirv.module' op missing SPIR-V target env attribute}} spirv.module Logical GLSL450 {}