-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir][spirv] Add support for SPV_ARM_graph extension - part 3 #156845
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This is the third patch to add support for the `SPV_ARM_graph` SPIR-V extension to MLIR’s SPIR-V dialect. The extension introduces a new `Graph` abstraction for expressing dataflow computations over full resources. The part 3 implementation includes: - ABI lowering support for graph entry points via `LowerABIAttributesPass`. - Tests covering ABI handling. Graphs currently support only `SPV_ARM_tensors`, but are designed to generalize to other resource types, such as images. Spec: KhronosGroup/SPIRV-Registry#346 RFC: https://discourse.llvm.org/t/rfc-add-support-for-spv-arm-graph-extension-in-mlir-spir-v-dialect/86947 Signed-off-by: Davide Grohmann <[email protected]> Change-Id: I31896806a3e3a856530149ffd919b8568d5b6208
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-spirv Author: Davide Grohmann (davidegrohmann) ChangesThis is the third patch to add support for the The part 3 implementation includes:
Graphs currently support only Spec: KhronosGroup/SPIRV-Registry#346 Full diff: https://github.com/llvm/llvm-project/pull/156845.diff 4 Files Affected:
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index fcf1526491971..44c86bc8777e4 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -1066,7 +1066,12 @@ LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
}
LogicalResult SPIRVDialect::verifyRegionResultAttribute(
- Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
+ Operation *op, unsigned /*regionIndex*/, unsigned resultIndex,
NamedAttribute attribute) {
- return op->emitError("cannot attach SPIR-V attributes to region result");
+ if (auto graphOp = dyn_cast<spirv::GraphARMOp>(op))
+ return verifyRegionAttribute(
+ op->getLoc(), graphOp.getResultTypes()[resultIndex], attribute);
+ return op->emitError(
+ "cannot attach SPIR-V attributes to region result which is "
+ "not part of a spirv::GraphARMOp type");
}
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 3911ec08fcc27..91aa0e3823a31 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -85,10 +85,36 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
abiInfo.getBinding());
}
+/// Creates a global variable for an argument or result based on the ABI info.
+static spirv::GlobalVariableOp
+createGlobalVarForGraphEntryPoint(OpBuilder &builder, spirv::GraphARMOp graphOp,
+ unsigned index, bool isArg,
+ spirv::InterfaceVarABIAttr abiInfo) {
+ auto spirvModule = graphOp->getParentOfType<spirv::ModuleOp>();
+ if (!spirvModule)
+ return nullptr;
+
+ OpBuilder::InsertionGuard moduleInsertionGuard(builder);
+ builder.setInsertionPoint(graphOp.getOperation());
+ std::string varName = graphOp.getName().str() + (isArg ? "_arg_" : "_res_") +
+ std::to_string(index);
+
+ Type varType = isArg ? graphOp.getFunctionType().getInput(index)
+ : graphOp.getFunctionType().getResult(index);
+
+ auto pointerType = spirv::PointerType::get(
+ varType,
+ abiInfo.getStorageClass().value_or(spirv::StorageClass::UniformConstant));
+
+ return builder.create<spirv::GlobalVariableOp>(
+ graphOp.getLoc(), pointerType, varName, abiInfo.getDescriptorSet(),
+ abiInfo.getBinding());
+}
+
/// Gets the global variables that need to be specified as interface variable
/// with an spirv.EntryPointOp. Traverses the body of a entry function to do so.
static LogicalResult
-getInterfaceVariables(spirv::FuncOp funcOp,
+getInterfaceVariables(mlir::FunctionOpInterface funcOp,
SmallVectorImpl<Attribute> &interfaceVars) {
auto module = funcOp->getParentOfType<spirv::ModuleOp>();
if (!module) {
@@ -224,6 +250,21 @@ class ProcessInterfaceVarABI final : public OpConversionPattern<spirv::FuncOp> {
ConversionPatternRewriter &rewriter) const override;
};
+/// A pattern to convert graph signature according to interface variable ABI
+/// attributes.
+///
+/// Specifically, this pattern creates global variables according to interface
+/// variable ABI attributes attached to graph arguments and results.
+class ProcessGraphInterfaceVarABI final
+ : public OpConversionPattern<spirv::GraphARMOp> {
+public:
+ using OpConversionPattern<spirv::GraphARMOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(spirv::GraphARMOp graphOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
/// Pass to implement the ABI information specified as attributes.
class LowerABIAttributesPass final
: public spirv::impl::SPIRVLowerABIAttributesPassBase<
@@ -297,6 +338,65 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
return success();
}
+LogicalResult ProcessGraphInterfaceVarABI::matchAndRewrite(
+ spirv::GraphARMOp graphOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ // Non-entry point graphs are not handled.
+ if (!graphOp.getEntryPoint().value_or(false))
+ return failure();
+
+ TypeConverter::SignatureConversion signatureConverter(
+ graphOp.getFunctionType().getNumInputs());
+
+ StringRef attrName = spirv::getInterfaceVarABIAttrName();
+ SmallVector<Attribute, 4> interfaceVars;
+
+ // Convert arguments.
+ unsigned numInputs = graphOp.getFunctionType().getNumInputs();
+ unsigned numResults = graphOp.getFunctionType().getNumResults();
+ for (unsigned index = 0; index < numInputs; ++index) {
+ auto abiInfo =
+ graphOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(index, attrName);
+ if (!abiInfo)
+ return failure();
+ spirv::GlobalVariableOp var = createGlobalVarForGraphEntryPoint(
+ rewriter, graphOp, index, true, abiInfo);
+ if (!var)
+ return failure();
+ interfaceVars.push_back(
+ SymbolRefAttr::get(rewriter.getContext(), var.getSymName()));
+ }
+
+ for (unsigned index = 0; index < numResults; ++index) {
+ auto abiInfo = graphOp.getResultAttrOfType<spirv::InterfaceVarABIAttr>(
+ index, attrName);
+ if (!abiInfo)
+ return failure();
+ spirv::GlobalVariableOp var = createGlobalVarForGraphEntryPoint(
+ rewriter, graphOp, index, false, abiInfo);
+ if (!var)
+ return failure();
+ interfaceVars.push_back(
+ SymbolRefAttr::get(rewriter.getContext(), var.getSymName()));
+ }
+
+ // Update signature.
+ rewriter.modifyOpInPlace(graphOp, [&] {
+ for (unsigned index = 0; index < numInputs; ++index) {
+ graphOp.removeArgAttr(index, attrName);
+ }
+ for (unsigned index = 0; index < numResults; ++index) {
+ graphOp.removeResultAttr(index, rewriter.getStringAttr(attrName));
+ }
+ });
+
+ OpBuilder::InsertionGuard insertionGuard(rewriter);
+ rewriter.setInsertionPoint(graphOp);
+ rewriter.create<spirv::GraphEntryPointARMOp>(graphOp.getLoc(), graphOp,
+ interfaceVars);
+ return success();
+}
+
void LowerABIAttributesPass::runOnOperation() {
// Uses the signature conversion methodology of the dialect conversion
// framework to implement the conversion.
@@ -323,6 +423,7 @@ void LowerABIAttributesPass::runOnOperation() {
RewritePatternSet patterns(context);
patterns.add<ProcessInterfaceVarABI>(typeConverter, context);
+ patterns.add<ProcessGraphInterfaceVarABI>(typeConverter, context);
ConversionTarget target(*context);
// "Legal" function ops should have no interface variable ABI attributes.
@@ -333,6 +434,17 @@ void LowerABIAttributesPass::runOnOperation() {
return false;
return true;
});
+ target.addDynamicallyLegalOp<spirv::GraphARMOp>([&](spirv::GraphARMOp op) {
+ StringRef attrName = spirv::getInterfaceVarABIAttrName();
+ for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i)
+ if (op.getArgAttr(i, attrName))
+ return false;
+ for (unsigned i = 0, e = op.getNumResults(); i < e; ++i)
+ if (op.getResultAttr(i, attrName))
+ return false;
+ return true;
+ });
+
// All other SPIR-V ops are legal.
target.markUnknownOpDynamicallyLegal([](Operation *op) {
return op->getDialect()->getNamespace() ==
diff --git a/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir b/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
index 10fbcf06eb052..63dea6af83556 100644
--- a/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
@@ -101,6 +101,14 @@ func.func @interface_var(
// -----
+// CHECK: {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
+// CHECK: {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
+spirv.ARM.Graph @interface_var(%arg: !spirv.arm.tensor<1xf32> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) -> (
+ !spirv.arm.tensor<1xf32> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
+) { spirv.ARM.GraphOutputs %arg : !spirv.arm.tensor<1xf32> }
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.resource_limits
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
index f3a3218e5aec0..04667c828bbd1 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
@@ -35,6 +35,28 @@ spirv.module Logical GLSL450 {
// -----
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [VulkanMemoryModel, Shader, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: spirv.module
+spirv.module Logical Vulkan {
+ // CHECK-DAG: spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x16x16x16xi8>, UniformConstant>
+ // CHECK-DAG: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<1x16x16x16xi8>, UniformConstant>
+
+ // CHECK: spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]]
+ // CHECK: spirv.ARM.Graph [[GN]]([[ARG0:%.*]]: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> attributes {entry_point = true}
+ spirv.ARM.Graph @main(%arg0: !spirv.arm.tensor<1x16x16x16xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>})
+ -> (!spirv.arm.tensor<1x16x16x16xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) attributes {entry_point = true} {
+ spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8>
+ }
+} // end spirv.module
+
+} // end module
+
+// -----
+
module {
// expected-error@+1 {{'spirv.module' op missing SPIR-V target env attribute}}
spirv.module Logical GLSL450 {}
|
|
@kuhar appreciate some feedback on this pr whenever you have time, thanks |
kuhar
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM % nits
Signed-off-by: Davide Grohmann <[email protected]> Change-Id: I83992517e77c9dc53fd5da2e839ba4fc22f9f6a7
b0ef295 to
1441999
Compare
This is the third patch to add support for the
SPV_ARM_graphSPIR-V extension to MLIR’s SPIR-V dialect. The extension introduces a newGraphabstraction for expressing dataflow computations over full resources.The part 3 implementation includes:
LowerABIAttributesPass.Graphs currently support only
SPV_ARM_tensors, but are designed to generalize to other resource types, such as images.Spec: KhronosGroup/SPIRV-Registry#346
RFC: https://discourse.llvm.org/t/rfc-add-support-for-spv-arm-graph-extension-in-mlir-spir-v-dialect/86947