Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1066,7 +1066,12 @@ LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
}

LogicalResult SPIRVDialect::verifyRegionResultAttribute(
Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
Operation *op, unsigned /*regionIndex*/, unsigned resultIndex,
NamedAttribute attribute) {
return op->emitError("cannot attach SPIR-V attributes to region result");
if (auto graphOp = dyn_cast<spirv::GraphARMOp>(op))
return verifyRegionAttribute(
op->getLoc(), graphOp.getResultTypes()[resultIndex], attribute);
return op->emitError(
"cannot attach SPIR-V attributes to region result which is "
"not part of a spirv::GraphARMOp type");
}
114 changes: 113 additions & 1 deletion mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,36 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
abiInfo.getBinding());
}

/// Creates a global variable for an argument or result based on the ABI info.
static spirv::GlobalVariableOp
createGlobalVarForGraphEntryPoint(OpBuilder &builder, spirv::GraphARMOp graphOp,
unsigned index, bool isArg,
spirv::InterfaceVarABIAttr abiInfo) {
auto spirvModule = graphOp->getParentOfType<spirv::ModuleOp>();
if (!spirvModule)
return nullptr;

OpBuilder::InsertionGuard moduleInsertionGuard(builder);
builder.setInsertionPoint(graphOp.getOperation());
std::string varName = graphOp.getName().str() + (isArg ? "_arg_" : "_res_") +
std::to_string(index);

Type varType = isArg ? graphOp.getFunctionType().getInput(index)
: graphOp.getFunctionType().getResult(index);

auto pointerType = spirv::PointerType::get(
varType,
abiInfo.getStorageClass().value_or(spirv::StorageClass::UniformConstant));

return builder.create<spirv::GlobalVariableOp>(
graphOp.getLoc(), pointerType, varName, abiInfo.getDescriptorSet(),
abiInfo.getBinding());
}

/// Gets the global variables that need to be specified as interface variable
/// with an spirv.EntryPointOp. Traverses the body of a entry function to do so.
static LogicalResult
getInterfaceVariables(spirv::FuncOp funcOp,
getInterfaceVariables(mlir::FunctionOpInterface funcOp,
SmallVectorImpl<Attribute> &interfaceVars) {
auto module = funcOp->getParentOfType<spirv::ModuleOp>();
if (!module) {
Expand Down Expand Up @@ -224,6 +250,21 @@ class ProcessInterfaceVarABI final : public OpConversionPattern<spirv::FuncOp> {
ConversionPatternRewriter &rewriter) const override;
};

/// A pattern to convert graph signature according to interface variable ABI
/// attributes.
///
/// Specifically, this pattern creates global variables according to interface
/// variable ABI attributes attached to graph arguments and results.
class ProcessGraphInterfaceVarABI final
: public OpConversionPattern<spirv::GraphARMOp> {
public:
using OpConversionPattern<spirv::GraphARMOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(spirv::GraphARMOp graphOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

/// Pass to implement the ABI information specified as attributes.
class LowerABIAttributesPass final
: public spirv::impl::SPIRVLowerABIAttributesPassBase<
Expand Down Expand Up @@ -297,6 +338,65 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
return success();
}

LogicalResult ProcessGraphInterfaceVarABI::matchAndRewrite(
spirv::GraphARMOp graphOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Non-entry point graphs are not handled.
if (!graphOp.getEntryPoint().value_or(false))
return failure();

TypeConverter::SignatureConversion signatureConverter(
graphOp.getFunctionType().getNumInputs());

StringRef attrName = spirv::getInterfaceVarABIAttrName();
SmallVector<Attribute, 4> interfaceVars;

// Convert arguments.
unsigned numInputs = graphOp.getFunctionType().getNumInputs();
unsigned numResults = graphOp.getFunctionType().getNumResults();
for (unsigned index = 0; index < numInputs; ++index) {
auto abiInfo =
graphOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(index, attrName);
if (!abiInfo)
return failure();
spirv::GlobalVariableOp var = createGlobalVarForGraphEntryPoint(
rewriter, graphOp, index, true, abiInfo);
if (!var)
return failure();
interfaceVars.push_back(
SymbolRefAttr::get(rewriter.getContext(), var.getSymName()));
}

for (unsigned index = 0; index < numResults; ++index) {
auto abiInfo = graphOp.getResultAttrOfType<spirv::InterfaceVarABIAttr>(
index, attrName);
if (!abiInfo)
return failure();
spirv::GlobalVariableOp var = createGlobalVarForGraphEntryPoint(
rewriter, graphOp, index, false, abiInfo);
if (!var)
return failure();
interfaceVars.push_back(
SymbolRefAttr::get(rewriter.getContext(), var.getSymName()));
}

// Update signature.
rewriter.modifyOpInPlace(graphOp, [&] {
for (unsigned index = 0; index < numInputs; ++index) {
graphOp.removeArgAttr(index, attrName);
}
for (unsigned index = 0; index < numResults; ++index) {
graphOp.removeResultAttr(index, rewriter.getStringAttr(attrName));
}
});

OpBuilder::InsertionGuard insertionGuard(rewriter);
rewriter.setInsertionPoint(graphOp);
rewriter.create<spirv::GraphEntryPointARMOp>(graphOp.getLoc(), graphOp,
interfaceVars);
return success();
}

void LowerABIAttributesPass::runOnOperation() {
// Uses the signature conversion methodology of the dialect conversion
// framework to implement the conversion.
Expand All @@ -323,6 +423,7 @@ void LowerABIAttributesPass::runOnOperation() {

RewritePatternSet patterns(context);
patterns.add<ProcessInterfaceVarABI>(typeConverter, context);
patterns.add<ProcessGraphInterfaceVarABI>(typeConverter, context);

ConversionTarget target(*context);
// "Legal" function ops should have no interface variable ABI attributes.
Expand All @@ -333,6 +434,17 @@ void LowerABIAttributesPass::runOnOperation() {
return false;
return true;
});
target.addDynamicallyLegalOp<spirv::GraphARMOp>([&](spirv::GraphARMOp op) {
StringRef attrName = spirv::getInterfaceVarABIAttrName();
for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i)
if (op.getArgAttr(i, attrName))
return false;
for (unsigned i = 0, e = op.getNumResults(); i < e; ++i)
if (op.getResultAttr(i, attrName))
return false;
return true;
});

// All other SPIR-V ops are legal.
target.markUnknownOpDynamicallyLegal([](Operation *op) {
return op->getDialect()->getNamespace() ==
Expand Down
8 changes: 8 additions & 0 deletions mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@ func.func @interface_var(

// -----

// CHECK: {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
// CHECK: {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
spirv.ARM.Graph @interface_var(%arg: !spirv.arm.tensor<1xf32> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) -> (
!spirv.arm.tensor<1xf32> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
) { spirv.ARM.GraphOutputs %arg : !spirv.arm.tensor<1xf32> }

// -----

//===----------------------------------------------------------------------===//
// spirv.resource_limits
//===----------------------------------------------------------------------===//
Expand Down
22 changes: 22 additions & 0 deletions mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,28 @@ spirv.module Logical GLSL450 {

// -----

module attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [VulkanMemoryModel, Shader, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>, #spirv.resource_limits<>>
} {

// CHECK-LABEL: spirv.module
spirv.module Logical Vulkan {
// CHECK-DAG: spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x16x16x16xi8>, UniformConstant>
// CHECK-DAG: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<1x16x16x16xi8>, UniformConstant>

// CHECK: spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]]
// CHECK: spirv.ARM.Graph [[GN]]([[ARG0:%.*]]: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> attributes {entry_point = true}
spirv.ARM.Graph @main(%arg0: !spirv.arm.tensor<1x16x16x16xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>})
-> (!spirv.arm.tensor<1x16x16x16xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) attributes {entry_point = true} {
spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8>
}
} // end spirv.module

} // end module

// -----

module {
// expected-error@+1 {{'spirv.module' op missing SPIR-V target env attribute}}
spirv.module Logical GLSL450 {}
Expand Down
Loading