Skip to content

Commit cf3a887

Browse files
[mlir][spirv] Add support for SPV_ARM_graph extension - part 1 (#151934)
This is the first patch to add support for the SPV_ARM_graph SPIR-V extension to MLIR’s SPIR-V dialect. The extension introduces a new Graph abstraction for expressing dataflow computations over full resources. The part 1 implementation includes: - A new `GraphType`, modeled similarly to `FunctionType`, for typed graph signatures. - New operations in the `spirv.arm` namespace: - `spirv.arm.Graph` - `spirv.arm.GraphEntryPoint` - `spirv.arm.GraphConstant` - `spirv.arm.GraphOutput` - Verifier and VCE updates to properly gate usage under SPV_ARM_graph. - Tests covering parsing and verification. Graphs currently support only SPV_ARM_tensors, but are designed to generalize to other resource types, such as images. Spec: KhronosGroup/SPIRV-Registry#346 RFC: https://discourse.llvm.org/t/rfc-add-support-for-spv-arm-graph-extension-in-mlir-spir-v-dialect/86947 --------- Signed-off-by: Davide Grohmann <[email protected]>
1 parent 0c0c55a commit cf3a887

File tree

17 files changed

+777
-15
lines changed

17 files changed

+777
-15
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,7 @@ def SPV_NV_ray_tracing_motion_blur : I32EnumAttrCase<"SPV_NV_ray_tracing_m
425425
def SPV_NVX_multiview_per_view_attributes : I32EnumAttrCase<"SPV_NVX_multiview_per_view_attributes", 5015>;
426426

427427
def SPV_ARM_tensors : I32EnumAttrCase<"SPV_ARM_tensors", 6000>;
428+
def SPV_ARM_graph : I32EnumAttrCase<"SPV_ARM_graph", 6001>;
428429

429430
def SPIRV_ExtensionAttr :
430431
SPIRV_I32EnumAttr<"Extension", "supported SPIR-V extensions", "ext", [
@@ -449,7 +450,7 @@ def SPIRV_ExtensionAttr :
449450
SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
450451
SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
451452
SPV_EXT_mesh_shader, SPV_EXT_replicated_composites,
452-
SPV_ARM_tensors,
453+
SPV_ARM_tensors, SPV_ARM_graph,
453454
SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
454455
SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
455456
SPV_AMD_shader_image_load_store_lod, SPV_AMD_texture_gather_bias_lod,
@@ -1341,6 +1342,12 @@ def SPIRV_C_StorageTensorArrayNonUniformIndexingEXT : I32EnumAttrCase<"Stora
13411342
Extension<[SPV_ARM_tensors]>
13421343
];
13431344
}
1345+
def SPIRV_C_GraphARM : I32EnumAttrCase<"GraphARM", 4191> {
1346+
list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM];
1347+
list<Availability> availability = [
1348+
Extension<[SPV_ARM_graph]>
1349+
];
1350+
}
13441351
def SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR : I32EnumAttrCase<"WorkgroupMemoryExplicitLayout8BitAccessKHR", 4429> {
13451352
list<I32EnumAttrCase> implies = [SPIRV_C_WorkgroupMemoryExplicitLayoutKHR];
13461353
list<Availability> availability = [
@@ -1560,7 +1567,7 @@ def SPIRV_CapabilityAttr :
15601567
SPIRV_C_GeometryPointSize, SPIRV_C_ImageCubeArray, SPIRV_C_ImageRect,
15611568
SPIRV_C_GeometryStreams, SPIRV_C_MultiViewport,
15621569
SPIRV_C_TensorsARM, SPIRV_C_StorageTensorArrayDynamicIndexingEXT,
1563-
SPIRV_C_StorageTensorArrayNonUniformIndexingEXT,
1570+
SPIRV_C_StorageTensorArrayNonUniformIndexingEXT, SPIRV_C_GraphARM,
15641571
SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR, SPIRV_C_VariablePointers,
15651572
SPIRV_C_RayTraversalPrimitiveCullingKHR, SPIRV_C_SampleMaskOverrideCoverageNV,
15661573
SPIRV_C_GeometryShaderPassthroughNV, SPIRV_C_PerViewAttributesNV,
@@ -4569,6 +4576,13 @@ def SPIRV_OC_OpGroupNonUniformLogicalAnd : I32EnumAttrCase<"OpGroupNonUnifo
45694576
def SPIRV_OC_OpGroupNonUniformLogicalOr : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>;
45704577
def SPIRV_OC_OpGroupNonUniformLogicalXor : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>;
45714578
def SPIRV_OC_OpTypeTensorARM : I32EnumAttrCase<"OpTypeTensorARM", 4163>;
4579+
def SPIRV_OC_OpGraphConstantARM : I32EnumAttrCase<"OpGraphConstantARM", 4181>;
4580+
def SPIRV_OC_OpGraphEntryPointARM : I32EnumAttrCase<"OpGraphEntryPointARM", 4182>;
4581+
def SPIRV_OC_OpGraphARM : I32EnumAttrCase<"OpGraphARM", 4183>;
4582+
def SPIRV_OC_OpGraphInputARM : I32EnumAttrCase<"OpGraphInputARM", 4184>;
4583+
def SPIRV_OC_OpGraphSetOutputARM : I32EnumAttrCase<"OpGraphSetOutputARM", 4185>;
4584+
def SPIRV_OC_OpGraphEndARM : I32EnumAttrCase<"OpGraphEndARM", 4186>;
4585+
def SPIRV_OC_OpTypeGraphARM : I32EnumAttrCase<"OpTypeGraphARM", 4190>;
45724586
def SPIRV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
45734587
def SPIRV_OC_OpGroupNonUniformRotateKHR : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>;
45744588
def SPIRV_OC_OpSDot : I32EnumAttrCase<"OpSDot", 4450>;
@@ -4689,6 +4703,9 @@ def SPIRV_OpcodeAttr :
46894703
SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr,
46904704
SPIRV_OC_OpGroupNonUniformLogicalXor,
46914705
SPIRV_OC_OpTypeTensorARM,
4706+
SPIRV_OC_OpGraphEntryPointARM, SPIRV_OC_OpGraphARM,
4707+
SPIRV_OC_OpGraphInputARM, SPIRV_OC_OpGraphSetOutputARM, SPIRV_OC_OpGraphEndARM,
4708+
SPIRV_OC_OpTypeGraphARM, SPIRV_OC_OpGraphConstantARM,
46924709
SPIRV_OC_OpSubgroupBallotKHR,
46934710
SPIRV_OC_OpGroupNonUniformRotateKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
46944711
SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
@@ -4862,6 +4879,10 @@ class SPIRV_NvVendorOp<string mnemonic, list<Trait> traits = []> :
48624879
SPIRV_VendorOp<mnemonic, "NV", traits> {
48634880
}
48644881

4882+
class SPIRV_ArmVendorOp<string mnemonic, list<Trait> traits = []> :
4883+
SPIRV_VendorOp<mnemonic, "ARM", traits> {
4884+
}
4885+
48654886
def SPIRV_FPFMM_None : I32BitEnumAttrCaseNone<"None">;
48664887
def SPIRV_FPFMM_NotNaN : I32BitEnumAttrCaseBit<"NotNaN", 0>;
48674888
def SPIRV_FPFMM_NotInf : I32BitEnumAttrCaseBit<"NotInf", 1>;
Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
//===- SPIRVGraphOps.td - Graph extended insts spec file -----*- tablegen -*-=//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This is the op definition spec of Graph extension ops.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_DIALECT_SPIRV_IR_GRAPH_OPS
14+
#define MLIR_DIALECT_SPIRV_IR_GRAPH_OPS
15+
16+
include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
17+
include "mlir/Interfaces/CallInterfaces.td"
18+
include "mlir/Interfaces/SideEffectInterfaces.td"
19+
include "mlir/Interfaces/FunctionInterfaces.td"
20+
21+
//===----------------------------------------------------------------------===//
22+
// SPIR-V Graph opcode specification.
23+
//===----------------------------------------------------------------------===//
24+
25+
// Base class for all Graph ops.
26+
class SPIRV_GraphARMOp<string mnemonic, list<Trait> traits = []> :
27+
SPIRV_ArmVendorOp<mnemonic, traits> {
28+
29+
let availability = [
30+
MinVersion<SPIRV_V_1_0>,
31+
MaxVersion<SPIRV_V_1_6>,
32+
Extension<[SPV_ARM_graph, SPV_ARM_tensors]>,
33+
Capability<[SPIRV_C_GraphARM]>
34+
];
35+
}
36+
37+
def SPIRV_GraphARMOp : SPIRV_GraphARMOp<"Graph", [
38+
AutomaticAllocationScope, DeclareOpInterfaceMethods<CallableOpInterface>,
39+
FunctionOpInterface, InModuleScope, IsolatedFromAbove
40+
]> {
41+
42+
let summary = "Declare or define a SPIR-V graph";
43+
44+
let description = [{
45+
This op declares or defines a SPIR-V graph using one region, which
46+
contains one or more blocks.
47+
48+
This op is not allowed to implicitly capture global values, and all external
49+
references must use function arguments or symbol references. This op itself
50+
defines a symbol that is unique in the enclosing module op.
51+
52+
Note that this op does not have a 1:1 mapping to the SPIR-V ops representing
53+
a graph. Indeed during serialization a single GraphARMOp is serialized into
54+
several different SPIR-V ops: OpGraphARM, OpGraphInputARM and OpGraphEndARM.
55+
There are as many occurences of OpGraphInputARM ops as many inputs in the
56+
graph. Deserialization maps that set of operations into a single GraphARMOp.
57+
58+
This op itself takes no operands and generates no results. Its region
59+
can take zero or more arguments and return one or more values.
60+
61+
```
62+
spv-graph-arm-op ::= `spirv.ARM.Graph` function-signature
63+
region
64+
```
65+
66+
#### Example:
67+
68+
```mlir
69+
spirv.ARM.Graph @graph(%arg0: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
70+
spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
71+
}
72+
```
73+
}];
74+
75+
let arguments = (ins
76+
TypeAttrOf<GraphType>:$function_type,
77+
OptionalAttr<DictArrayAttr>:$arg_attrs,
78+
OptionalAttr<DictArrayAttr>:$res_attrs,
79+
OptionalAttr<BoolAttr>:$entry_point,
80+
StrAttr:$sym_name
81+
);
82+
83+
let results = (outs);
84+
85+
let regions = (region AnyRegion:$body);
86+
87+
let hasVerifier = 0;
88+
89+
let builders = [
90+
OpBuilder<(ins "StringRef":$name, "GraphType":$type,
91+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs, CArg<"bool", "false">:$entry_point)>];
92+
93+
let hasOpcode = 0;
94+
95+
let autogenSerialization = 0;
96+
97+
let extraClassDeclaration = [{
98+
/// Hook for FunctionOpInterface, called after verifying that the 'type'
99+
/// attribute is present and checks if it holds a function type. Ensures
100+
/// getType, getNumArguments, and getNumResults can be called safely
101+
LogicalResult verifyType();
102+
103+
/// Hook for FunctionOpInterface, called after verifying the function
104+
/// type and the presence of the (potentially empty) function body.
105+
/// Ensures SPIR-V specific semantics.
106+
LogicalResult verifyBody();
107+
}];
108+
}
109+
110+
// -----
111+
112+
// Check that an op can only be used within the scope of a spirv.ARM.Graph op.
113+
def InGraphScope : PredOpTrait<
114+
"op must appear in a spirv.ARM.Graph op's block",
115+
CPred<"isNestedInGraphARMOpInterface($_op.getParentOp())">>;
116+
117+
// -----
118+
119+
def SPIRV_GraphConstantARMOp : SPIRV_GraphARMOp<"GraphConstant", [InGraphScope, Pure, ConstantLike]> {
120+
let summary = "Declare a graph constant.";
121+
122+
let description = [{
123+
Declare a graph constant.
124+
Result Type must be an OpTypeTensorARM.
125+
GraphConstantID must be a 32-bit integer literal.
126+
127+
#### Example:
128+
129+
```mlir
130+
%0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16>
131+
```
132+
133+
GraphConstantID is a unique identifier which is use to map the contants
134+
defined by GraphConstantARM in the SPIRV module with the one provided at
135+
shader creation time via the VkDataGraphPipelineShaderModuleCreateInfoARM.
136+
That Vulkan structure provides a list of VkDataGraphPipelineConstantARM
137+
which contains the bindings from id to data. (For more details see
138+
https://registry.khronos.org/vulkan/specs/latest/html/vkspec.html#graphs)
139+
}];
140+
141+
let arguments = (ins
142+
I32Attr: $graph_constant_id
143+
);
144+
145+
let results = (outs
146+
SPIRV_AnyTensorArm:$output
147+
);
148+
149+
let hasVerifier = 0;
150+
151+
let autogenSerialization = 0;
152+
153+
let assemblyFormat = [{
154+
attr-dict `:` type($output)
155+
}];
156+
}
157+
158+
// -----
159+
160+
def SPIRV_GraphEntryPointARMOp : SPIRV_GraphARMOp<"GraphEntryPoint", [InModuleScope]> {
161+
let summary = [{
162+
Declare a graph entry point and its interface.
163+
}];
164+
165+
let description = [{
166+
Graph Entry Point must be the Result <id> of an OpGraphARM instruction.
167+
168+
Name is a name string for the graphentry point. A module cannot have two
169+
OpGraphEntryPointARM instructions with the same Name string.
170+
171+
Interface is a list of symbol references to `spirv.GlobalVariable`
172+
operations. These declare the set of global variables from a
173+
module that form the interface of this entry point. The set of
174+
Interface symbols must be equal to or a superset of the
175+
`spirv.GlobalVariable`s referenced by the entry point’s static call
176+
tree, within the interface’s storage classes.
177+
178+
#### Example:
179+
180+
```mlir
181+
spirv.GlobalVariable @arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
182+
spirv.GlobalVariable @res_0 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
183+
spirv.ARM.GraphEntryPoint @graph, @arg_0, @res_0
184+
spirv.ARM.Graph @graph(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
185+
...
186+
}
187+
```
188+
}];
189+
190+
let arguments = (ins
191+
FlatSymbolRefAttr:$fn,
192+
SymbolRefArrayAttr:$interface
193+
);
194+
195+
let results = (outs);
196+
197+
// Checks for graph and interface symbol reference are done in spirv::ModuleOp verification.
198+
let hasVerifier = 0;
199+
200+
let autogenSerialization = 0;
201+
202+
let builders = [
203+
OpBuilder<(ins "spirv::GraphARMOp":$graph, "ArrayRef<Attribute>":$interfaceVars)>];
204+
}
205+
206+
// -----
207+
208+
def SPIRV_GraphOutputsARMOp : SPIRV_GraphARMOp<"GraphOutputs", [InGraphScope, Pure,
209+
Terminator]> {
210+
211+
let summary = "Define graph outputs.";
212+
213+
let description = [{
214+
Values are the graph outputs values and must match the GraphOutputs Type
215+
operand of the OpTypeGraphARM type of the OpGraphARM body this
216+
instruction is in.
217+
218+
This instruction must be the last instruction in a block.
219+
220+
#### Example:
221+
222+
```mlir
223+
spirv.ARM.Graph @graph(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
224+
spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
225+
}
226+
```
227+
}];
228+
229+
let arguments = (ins
230+
Variadic<SPIRV_AnyTensorArm>:$value
231+
);
232+
233+
let results = (outs);
234+
235+
let autogenSerialization = 0;
236+
237+
let hasOpcode = 0;
238+
239+
let assemblyFormat = "$value attr-dict `:` type($value)";
240+
}
241+
242+
#endif // MLIR_DIALECT_SPIRV_IR_GRAPH_OPS

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td"
3232
include "mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td"
3333
include "mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td"
3434
include "mlir/Dialect/SPIRV/IR/SPIRVGLOps.td"
35+
include "mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td"
3536
include "mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td"
3637
include "mlir/Dialect/SPIRV/IR/SPIRVImageOps.td"
3738
include "mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td"

mlir/include/mlir/IR/Builders.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class Type;
2424
class IntegerType;
2525
class FloatType;
2626
class FunctionType;
27+
class GraphType;
2728
class IndexType;
2829
class MemRefType;
2930
class VectorType;
@@ -81,6 +82,7 @@ class Builder {
8182
IntegerType getIntegerType(unsigned width);
8283
IntegerType getIntegerType(unsigned width, bool isSigned);
8384
FunctionType getFunctionType(TypeRange inputs, TypeRange results);
85+
GraphType getGraphType(TypeRange inputs, TypeRange results);
8486
TupleType getTupleType(TypeRange elementTypes);
8587
NoneType getNoneType();
8688

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ def Builtin_Float128 : Builtin_CachedFloatType<"Float128", "f128"> {
403403
// FunctionType
404404
//===----------------------------------------------------------------------===//
405405

406-
def Builtin_Function : Builtin_Type<"Function", "function"> {
406+
class Builtin_FunctionLike<string Name, string typeMnemonic> : Builtin_Type<Name, typeMnemonic> {
407407
let summary = "Map from a list of inputs to a list of results";
408408
let description = [{
409409
Syntax:
@@ -434,6 +434,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
434434
}]>
435435
];
436436
let skipDefaultBuilders = 1;
437+
let storageClass = "FunctionTypeStorage";
437438
let genStorageClass = 0;
438439
let extraClassDeclaration = [{
439440
/// Input types.
@@ -444,23 +445,26 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
444445
unsigned getNumResults() const;
445446
Type getResult(unsigned i) const { return getResults()[i]; }
446447

447-
/// Returns a clone of this function type with the given argument
448+
/// Returns a clone of this function-like type with the given argument
448449
/// and result types.
449-
FunctionType clone(TypeRange inputs, TypeRange results) const;
450+
}] # Name # "Type" # [{ clone(TypeRange inputs, TypeRange results) const;
450451

451-
/// Returns a new function type with the specified arguments and results
452+
/// Returns a new function-like type with the specified arguments and results
452453
/// inserted.
453-
FunctionType getWithArgsAndResults(ArrayRef<unsigned> argIndices,
454+
}] # Name # "Type" # [{ getWithArgsAndResults(ArrayRef<unsigned> argIndices,
454455
TypeRange argTypes,
455456
ArrayRef<unsigned> resultIndices,
456457
TypeRange resultTypes);
457458

458-
/// Returns a new function type without the specified arguments and results.
459-
FunctionType getWithoutArgsAndResults(const BitVector &argIndices,
459+
/// Returns a new function-like type without the specified arguments and results.
460+
}] # Name # "Type" # [{ getWithoutArgsAndResults(const BitVector &argIndices,
460461
const BitVector &resultIndices);
461462
}];
462463
}
463464

465+
def Builtin_Function : Builtin_FunctionLike<"Function", "function">;
466+
def Builtin_Graph : Builtin_FunctionLike<"Graph", "graph">;
467+
464468
//===----------------------------------------------------------------------===//
465469
// IndexType
466470
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)