Skip to content

Commit 8036edb

Browse files
[mlir][spirv] Add support for SPV_ARM_graph extension - part 3 (#156845)
This is the third patch to add support for the `SPV_ARM_graph` SPIR-V extension to MLIR’s SPIR-V dialect. The extension introduces a new `Graph` abstraction for expressing dataflow computations over full resources. The part 3 implementation includes: - ABI lowering support for graph entry points via `LowerABIAttributesPass`. - Tests covering ABI handling. Graphs currently support only `SPV_ARM_tensors`, but are designed to generalize to other resource types, such as images. Spec: KhronosGroup/SPIRV-Registry#346 RFC: https://discourse.llvm.org/t/rfc-add-support-for-spv-arm-graph-extension-in-mlir-spir-v-dialect/86947 --------- Signed-off-by: Davide Grohmann <[email protected]>
1 parent 248ad71 commit 8036edb

File tree

4 files changed

+150
-4
lines changed

4 files changed

+150
-4
lines changed

mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,7 +1066,12 @@ LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
10661066
}
10671067

10681068
LogicalResult SPIRVDialect::verifyRegionResultAttribute(
1069-
Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
1069+
Operation *op, unsigned /*regionIndex*/, unsigned resultIndex,
10701070
NamedAttribute attribute) {
1071-
return op->emitError("cannot attach SPIR-V attributes to region result");
1071+
if (auto graphOp = dyn_cast<spirv::GraphARMOp>(op))
1072+
return verifyRegionAttribute(
1073+
op->getLoc(), graphOp.getResultTypes()[resultIndex], attribute);
1074+
return op->emitError(
1075+
"cannot attach SPIR-V attributes to region result which is "
1076+
"not part of a spirv::GraphARMOp type");
10721077
}

mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp

Lines changed: 113 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
2323
#include "mlir/IR/BuiltinAttributes.h"
2424
#include "mlir/Transforms/DialectConversion.h"
25+
#include "llvm/Support/FormatVariadic.h"
2526

2627
namespace mlir {
2728
namespace spirv {
@@ -85,10 +86,36 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
8586
abiInfo.getBinding());
8687
}
8788

89+
/// Creates a global variable for an argument or result based on the ABI info.
90+
static spirv::GlobalVariableOp
91+
createGlobalVarForGraphEntryPoint(OpBuilder &builder, spirv::GraphARMOp graphOp,
92+
unsigned index, bool isArg,
93+
spirv::InterfaceVarABIAttr abiInfo) {
94+
auto spirvModule = graphOp->getParentOfType<spirv::ModuleOp>();
95+
if (!spirvModule)
96+
return nullptr;
97+
98+
OpBuilder::InsertionGuard moduleInsertionGuard(builder);
99+
builder.setInsertionPoint(graphOp.getOperation());
100+
std::string varName = llvm::formatv("{}_{}_{}", graphOp.getName(),
101+
isArg ? "arg" : "res", index);
102+
103+
Type varType = isArg ? graphOp.getFunctionType().getInput(index)
104+
: graphOp.getFunctionType().getResult(index);
105+
106+
auto pointerType = spirv::PointerType::get(
107+
varType,
108+
abiInfo.getStorageClass().value_or(spirv::StorageClass::UniformConstant));
109+
110+
return spirv::GlobalVariableOp::create(builder, graphOp.getLoc(), pointerType,
111+
varName, abiInfo.getDescriptorSet(),
112+
abiInfo.getBinding());
113+
}
114+
88115
/// Gets the global variables that need to be specified as interface variable
89116
/// with an spirv.EntryPointOp. Traverses the body of a entry function to do so.
90117
static LogicalResult
91-
getInterfaceVariables(spirv::FuncOp funcOp,
118+
getInterfaceVariables(mlir::FunctionOpInterface funcOp,
92119
SmallVectorImpl<Attribute> &interfaceVars) {
93120
auto module = funcOp->getParentOfType<spirv::ModuleOp>();
94121
if (!module) {
@@ -224,6 +251,21 @@ class ProcessInterfaceVarABI final : public OpConversionPattern<spirv::FuncOp> {
224251
ConversionPatternRewriter &rewriter) const override;
225252
};
226253

254+
/// A pattern to convert graph signature according to interface variable ABI
255+
/// attributes.
256+
///
257+
/// Specifically, this pattern creates global variables according to interface
258+
/// variable ABI attributes attached to graph arguments and results.
259+
class ProcessGraphInterfaceVarABI final
260+
: public OpConversionPattern<spirv::GraphARMOp> {
261+
public:
262+
using OpConversionPattern::OpConversionPattern;
263+
264+
LogicalResult
265+
matchAndRewrite(spirv::GraphARMOp graphOp, OpAdaptor adaptor,
266+
ConversionPatternRewriter &rewriter) const override;
267+
};
268+
227269
/// Pass to implement the ABI information specified as attributes.
228270
class LowerABIAttributesPass final
229271
: public spirv::impl::SPIRVLowerABIAttributesPassBase<
@@ -297,6 +339,63 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
297339
return success();
298340
}
299341

342+
LogicalResult ProcessGraphInterfaceVarABI::matchAndRewrite(
343+
spirv::GraphARMOp graphOp, OpAdaptor adaptor,
344+
ConversionPatternRewriter &rewriter) const {
345+
// Non-entry point graphs are not handled.
346+
if (!graphOp.getEntryPoint().value_or(false))
347+
return failure();
348+
349+
TypeConverter::SignatureConversion signatureConverter(
350+
graphOp.getFunctionType().getNumInputs());
351+
352+
StringRef attrName = spirv::getInterfaceVarABIAttrName();
353+
SmallVector<Attribute, 4> interfaceVars;
354+
355+
// Convert arguments.
356+
unsigned numInputs = graphOp.getFunctionType().getNumInputs();
357+
unsigned numResults = graphOp.getFunctionType().getNumResults();
358+
for (unsigned index = 0; index < numInputs; ++index) {
359+
auto abiInfo =
360+
graphOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(index, attrName);
361+
if (!abiInfo)
362+
return failure();
363+
spirv::GlobalVariableOp var = createGlobalVarForGraphEntryPoint(
364+
rewriter, graphOp, index, true, abiInfo);
365+
if (!var)
366+
return failure();
367+
interfaceVars.push_back(
368+
SymbolRefAttr::get(rewriter.getContext(), var.getSymName()));
369+
}
370+
371+
for (unsigned index = 0; index < numResults; ++index) {
372+
auto abiInfo = graphOp.getResultAttrOfType<spirv::InterfaceVarABIAttr>(
373+
index, attrName);
374+
if (!abiInfo)
375+
return failure();
376+
spirv::GlobalVariableOp var = createGlobalVarForGraphEntryPoint(
377+
rewriter, graphOp, index, false, abiInfo);
378+
if (!var)
379+
return failure();
380+
interfaceVars.push_back(
381+
SymbolRefAttr::get(rewriter.getContext(), var.getSymName()));
382+
}
383+
384+
// Update graph signature.
385+
rewriter.modifyOpInPlace(graphOp, [&] {
386+
for (unsigned index = 0; index < numInputs; ++index) {
387+
graphOp.removeArgAttr(index, attrName);
388+
}
389+
for (unsigned index = 0; index < numResults; ++index) {
390+
graphOp.removeResultAttr(index, rewriter.getStringAttr(attrName));
391+
}
392+
});
393+
394+
spirv::GraphEntryPointARMOp::create(rewriter, graphOp.getLoc(), graphOp,
395+
interfaceVars);
396+
return success();
397+
}
398+
300399
void LowerABIAttributesPass::runOnOperation() {
301400
// Uses the signature conversion methodology of the dialect conversion
302401
// framework to implement the conversion.
@@ -322,7 +421,8 @@ void LowerABIAttributesPass::runOnOperation() {
322421
});
323422

324423
RewritePatternSet patterns(context);
325-
patterns.add<ProcessInterfaceVarABI>(typeConverter, context);
424+
patterns.add<ProcessInterfaceVarABI, ProcessGraphInterfaceVarABI>(
425+
typeConverter, context);
326426

327427
ConversionTarget target(*context);
328428
// "Legal" function ops should have no interface variable ABI attributes.
@@ -333,6 +433,17 @@ void LowerABIAttributesPass::runOnOperation() {
333433
return false;
334434
return true;
335435
});
436+
target.addDynamicallyLegalOp<spirv::GraphARMOp>([&](spirv::GraphARMOp op) {
437+
StringRef attrName = spirv::getInterfaceVarABIAttrName();
438+
for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i)
439+
if (op.getArgAttr(i, attrName))
440+
return false;
441+
for (unsigned i = 0, e = op.getNumResults(); i < e; ++i)
442+
if (op.getResultAttr(i, attrName))
443+
return false;
444+
return true;
445+
});
446+
336447
// All other SPIR-V ops are legal.
337448
target.markUnknownOpDynamicallyLegal([](Operation *op) {
338449
return op->getDialect()->getNamespace() ==

mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,14 @@ func.func @interface_var(
101101

102102
// -----
103103

104+
// CHECK: {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
105+
// CHECK: {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
106+
spirv.ARM.Graph @interface_var(%arg: !spirv.arm.tensor<1xf32> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) -> (
107+
!spirv.arm.tensor<1xf32> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
108+
) { spirv.ARM.GraphOutputs %arg : !spirv.arm.tensor<1xf32> }
109+
110+
// -----
111+
104112
//===----------------------------------------------------------------------===//
105113
// spirv.resource_limits
106114
//===----------------------------------------------------------------------===//

mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,28 @@ spirv.module Logical GLSL450 {
3535

3636
// -----
3737

38+
module attributes {
39+
spirv.target_env = #spirv.target_env<
40+
#spirv.vce<v1.0, [VulkanMemoryModel, Shader, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>, #spirv.resource_limits<>>
41+
} {
42+
43+
// CHECK-LABEL: spirv.module
44+
spirv.module Logical Vulkan {
45+
// CHECK-DAG: spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x16x16x16xi8>, UniformConstant>
46+
// CHECK-DAG: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<1x16x16x16xi8>, UniformConstant>
47+
48+
// CHECK: spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]]
49+
// CHECK: spirv.ARM.Graph [[GN]]([[ARG0:%.*]]: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> attributes {entry_point = true}
50+
spirv.ARM.Graph @main(%arg0: !spirv.arm.tensor<1x16x16x16xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>})
51+
-> (!spirv.arm.tensor<1x16x16x16xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) attributes {entry_point = true} {
52+
spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8>
53+
}
54+
} // end spirv.module
55+
56+
} // end module
57+
58+
// -----
59+
3860
module {
3961
// expected-error@+1 {{'spirv.module' op missing SPIR-V target env attribute}}
4062
spirv.module Logical GLSL450 {}

0 commit comments

Comments
 (0)