Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1066,7 +1066,12 @@ LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
}

LogicalResult SPIRVDialect::verifyRegionResultAttribute(
Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
Operation *op, unsigned /*regionIndex*/, unsigned resultIndex,
NamedAttribute attribute) {
return op->emitError("cannot attach SPIR-V attributes to region result");
if (auto graphOp = dyn_cast<spirv::GraphARMOp>(op))
return verifyRegionAttribute(
op->getLoc(), graphOp.getResultTypes()[resultIndex], attribute);
return op->emitError(
"cannot attach SPIR-V attributes to region result which is "
"not part of a spirv::GraphARMOp type");
}
115 changes: 113 additions & 2 deletions mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"

namespace mlir {
namespace spirv {
Expand Down Expand Up @@ -85,10 +86,36 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
abiInfo.getBinding());
}

/// Creates a global variable for an argument or result based on the ABI info.
static spirv::GlobalVariableOp
createGlobalVarForGraphEntryPoint(OpBuilder &builder, spirv::GraphARMOp graphOp,
unsigned index, bool isArg,
spirv::InterfaceVarABIAttr abiInfo) {
auto spirvModule = graphOp->getParentOfType<spirv::ModuleOp>();
if (!spirvModule)
return nullptr;

OpBuilder::InsertionGuard moduleInsertionGuard(builder);
builder.setInsertionPoint(graphOp.getOperation());
std::string varName = llvm::formatv("{}_{}_{}", graphOp.getName(),
isArg ? "arg" : "res", index);

Type varType = isArg ? graphOp.getFunctionType().getInput(index)
: graphOp.getFunctionType().getResult(index);

auto pointerType = spirv::PointerType::get(
varType,
abiInfo.getStorageClass().value_or(spirv::StorageClass::UniformConstant));

return spirv::GlobalVariableOp::create(builder, graphOp.getLoc(), pointerType,
varName, abiInfo.getDescriptorSet(),
abiInfo.getBinding());
}

/// Gets the global variables that need to be specified as interface variable
/// with an spirv.EntryPointOp. Traverses the body of a entry function to do so.
static LogicalResult
getInterfaceVariables(spirv::FuncOp funcOp,
getInterfaceVariables(mlir::FunctionOpInterface funcOp,
SmallVectorImpl<Attribute> &interfaceVars) {
auto module = funcOp->getParentOfType<spirv::ModuleOp>();
if (!module) {
Expand Down Expand Up @@ -224,6 +251,21 @@ class ProcessInterfaceVarABI final : public OpConversionPattern<spirv::FuncOp> {
ConversionPatternRewriter &rewriter) const override;
};

/// A pattern to convert graph signature according to interface variable ABI
/// attributes.
///
/// Specifically, this pattern creates global variables according to interface
/// variable ABI attributes attached to graph arguments and results.
class ProcessGraphInterfaceVarABI final
: public OpConversionPattern<spirv::GraphARMOp> {
public:
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(spirv::GraphARMOp graphOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

/// Pass to implement the ABI information specified as attributes.
class LowerABIAttributesPass final
: public spirv::impl::SPIRVLowerABIAttributesPassBase<
Expand Down Expand Up @@ -297,6 +339,63 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
return success();
}

LogicalResult ProcessGraphInterfaceVarABI::matchAndRewrite(
spirv::GraphARMOp graphOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Non-entry point graphs are not handled.
if (!graphOp.getEntryPoint().value_or(false))
return failure();

TypeConverter::SignatureConversion signatureConverter(
graphOp.getFunctionType().getNumInputs());

StringRef attrName = spirv::getInterfaceVarABIAttrName();
SmallVector<Attribute, 4> interfaceVars;

// Convert arguments.
unsigned numInputs = graphOp.getFunctionType().getNumInputs();
unsigned numResults = graphOp.getFunctionType().getNumResults();
for (unsigned index = 0; index < numInputs; ++index) {
auto abiInfo =
graphOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(index, attrName);
if (!abiInfo)
return failure();
spirv::GlobalVariableOp var = createGlobalVarForGraphEntryPoint(
rewriter, graphOp, index, true, abiInfo);
if (!var)
return failure();
interfaceVars.push_back(
SymbolRefAttr::get(rewriter.getContext(), var.getSymName()));
}

for (unsigned index = 0; index < numResults; ++index) {
auto abiInfo = graphOp.getResultAttrOfType<spirv::InterfaceVarABIAttr>(
index, attrName);
if (!abiInfo)
return failure();
spirv::GlobalVariableOp var = createGlobalVarForGraphEntryPoint(
rewriter, graphOp, index, false, abiInfo);
if (!var)
return failure();
interfaceVars.push_back(
SymbolRefAttr::get(rewriter.getContext(), var.getSymName()));
}

// Update graph signature.
rewriter.modifyOpInPlace(graphOp, [&] {
for (unsigned index = 0; index < numInputs; ++index) {
graphOp.removeArgAttr(index, attrName);
}
for (unsigned index = 0; index < numResults; ++index) {
graphOp.removeResultAttr(index, rewriter.getStringAttr(attrName));
}
});

spirv::GraphEntryPointARMOp::create(rewriter, graphOp.getLoc(), graphOp,
interfaceVars);
return success();
}

void LowerABIAttributesPass::runOnOperation() {
// Uses the signature conversion methodology of the dialect conversion
// framework to implement the conversion.
Expand All @@ -322,7 +421,8 @@ void LowerABIAttributesPass::runOnOperation() {
});

RewritePatternSet patterns(context);
patterns.add<ProcessInterfaceVarABI>(typeConverter, context);
patterns.add<ProcessInterfaceVarABI, ProcessGraphInterfaceVarABI>(
typeConverter, context);

ConversionTarget target(*context);
// "Legal" function ops should have no interface variable ABI attributes.
Expand All @@ -333,6 +433,17 @@ void LowerABIAttributesPass::runOnOperation() {
return false;
return true;
});
target.addDynamicallyLegalOp<spirv::GraphARMOp>([&](spirv::GraphARMOp op) {
StringRef attrName = spirv::getInterfaceVarABIAttrName();
for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i)
if (op.getArgAttr(i, attrName))
return false;
for (unsigned i = 0, e = op.getNumResults(); i < e; ++i)
if (op.getResultAttr(i, attrName))
return false;
return true;
});

// All other SPIR-V ops are legal.
target.markUnknownOpDynamicallyLegal([](Operation *op) {
return op->getDialect()->getNamespace() ==
Expand Down
8 changes: 8 additions & 0 deletions mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@ func.func @interface_var(

// -----

// CHECK: {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
// CHECK: {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
spirv.ARM.Graph @interface_var(%arg: !spirv.arm.tensor<1xf32> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) -> (
!spirv.arm.tensor<1xf32> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
) { spirv.ARM.GraphOutputs %arg : !spirv.arm.tensor<1xf32> }

// -----

//===----------------------------------------------------------------------===//
// spirv.resource_limits
//===----------------------------------------------------------------------===//
Expand Down
22 changes: 22 additions & 0 deletions mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,28 @@ spirv.module Logical GLSL450 {

// -----

module attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [VulkanMemoryModel, Shader, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>, #spirv.resource_limits<>>
} {

// CHECK-LABEL: spirv.module
spirv.module Logical Vulkan {
// CHECK-DAG: spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x16x16x16xi8>, UniformConstant>
// CHECK-DAG: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<1x16x16x16xi8>, UniformConstant>

// CHECK: spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]]
// CHECK: spirv.ARM.Graph [[GN]]([[ARG0:%.*]]: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> attributes {entry_point = true}
spirv.ARM.Graph @main(%arg0: !spirv.arm.tensor<1x16x16x16xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>})
-> (!spirv.arm.tensor<1x16x16x16xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) attributes {entry_point = true} {
spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8>
}
} // end spirv.module

} // end module

// -----

module {
// expected-error@+1 {{'spirv.module' op missing SPIR-V target env attribute}}
spirv.module Logical GLSL450 {}
Expand Down