Skip to content

Commit 8a8b29f

Browse files
authored
[spirv] AMD work graphs extension (#7353)
Enables work graphs for SPIR-V target, based on AMD_shader_enqueue extension. Closes #5960.
1 parent 48d6e3c commit 8a8b29f

File tree

61 files changed

+3193
-112
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+3193
-112
lines changed

tools/clang/include/clang/AST/HlslTypes.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,10 @@ bool IsHLSLObjectWithImplicitMemberAccess(clang::QualType type);
485485
bool IsHLSLObjectWithImplicitROMemberAccess(clang::QualType type);
486486
bool IsHLSLRWNodeInputRecordType(clang::QualType type);
487487
bool IsHLSLRONodeInputRecordType(clang::QualType type);
488+
bool IsHLSLDispatchNodeInputRecordType(clang::QualType type);
489+
bool IsHLSLNodeRecordArrayType(clang::QualType type);
488490
bool IsHLSLNodeOutputType(clang::QualType type);
491+
bool IsHLSLEmptyNodeRecordType(clang::QualType type);
489492

490493
DXIL::NodeIOKind GetNodeIOType(clang::QualType type);
491494

@@ -495,6 +498,8 @@ bool IsHLSLCopyableAnnotatableRecord(clang::QualType QT);
495498
bool IsHLSLBuiltinRayAttributeStruct(clang::QualType QT);
496499
bool IsHLSLAggregateType(clang::QualType type);
497500
clang::QualType GetHLSLResourceResultType(clang::QualType type);
501+
clang::QualType GetHLSLNodeIOResultType(clang::ASTContext &astContext,
502+
clang::QualType type);
498503
unsigned GetHLSLResourceTemplateUInt(clang::QualType type);
499504
bool IsIncompleteHLSLResourceArrayType(clang::ASTContext &context,
500505
clang::QualType type);

tools/clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8060,6 +8060,8 @@ def err_hlsl_vk_pointer_cast_alignment: Error<
80608060
"Vulkan buffer pointer cannot be cast to greater alignment">;
80618061
def err_hlsl_vk_static_pointer_cast_type: Error<
80628062
"vk::static_pointer_cast() content type must be base class of argument's content type">;
8063+
def warn_spirv_node_shaders_experimental : Warning<
8064+
"SPIR-V implementation of node shaders is experimental and subject to change">;
80638065
// SPIRV Change Ends
80648066

80658067
let CategoryName = "OpenMP Issue" in {

tools/clang/include/clang/SPIRV/FeatureManager.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ enum class Extension {
5757
KHR_ray_query,
5858
EXT_shader_image_int64,
5959
KHR_physical_storage_buffer,
60+
AMD_shader_enqueue,
6061
KHR_vulkan_memory_model,
6162
NV_compute_shader_derivatives,
6263
KHR_compute_shader_derivatives,

tools/clang/include/clang/SPIRV/SpirvBuilder.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,25 @@ class SpirvBuilder {
434434
QualType resultType, NonSemanticDebugPrintfInstructions instId,
435435
llvm::ArrayRef<SpirvInstruction *> operands, SourceLocation);
436436

437+
SpirvInstruction *createIsNodePayloadValid(SpirvInstruction *payloadArray,
438+
SpirvInstruction *nodeIndex,
439+
SourceLocation);
440+
441+
SpirvInstruction *createNodePayloadArrayLength(SpirvInstruction *payloadArray,
442+
SourceLocation);
443+
444+
SpirvInstruction *createAllocateNodePayloads(QualType resultType,
445+
spv::Scope allocationScope,
446+
SpirvInstruction *shaderIndex,
447+
SpirvInstruction *recordCount,
448+
SourceLocation);
449+
450+
void createEnqueueOutputNodePayloads(SpirvInstruction *payload,
451+
SourceLocation);
452+
453+
SpirvInstruction *createFinishWritingNodePayload(SpirvInstruction *payload,
454+
SourceLocation);
455+
437456
/// \brief Creates an OpMemoryBarrier or OpControlBarrier instruction with the
438457
/// given flags. If execution scope (exec) is provided, an OpControlBarrier
439458
/// is created; otherwise an OpMemoryBarrier is created.
@@ -763,6 +782,7 @@ class SpirvBuilder {
763782
llvm::ArrayRef<SpirvConstant *> constituents,
764783
bool specConst = false);
765784
SpirvConstant *getConstantNull(QualType);
785+
SpirvConstant *getConstantString(llvm::StringRef str, bool specConst = false);
766786
SpirvUndef *getUndef(QualType);
767787

768788
SpirvString *createString(llvm::StringRef str);

tools/clang/include/clang/SPIRV/SpirvContext.h

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,21 @@ struct RuntimeArrayTypeMapInfo {
9898
}
9999
};
100100

101+
// Provides DenseMapInfo for NodePayloadArrayType so we can create a DenseSet of
102+
// node payload array types.
103+
struct NodePayloadArrayTypeMapInfo {
104+
static inline NodePayloadArrayType *getEmptyKey() { return nullptr; }
105+
static inline NodePayloadArrayType *getTombstoneKey() { return nullptr; }
106+
static unsigned getHashValue(const NodePayloadArrayType *Val) {
107+
return llvm::hash_combine(Val->getElementType(), Val->getNodeDecl());
108+
}
109+
static bool isEqual(const NodePayloadArrayType *LHS,
110+
const NodePayloadArrayType *RHS) {
111+
// Either both are null, or both should have the same underlying type.
112+
return (LHS == RHS) || (LHS && RHS && *LHS == *RHS);
113+
}
114+
};
115+
101116
// Provides DenseMapInfo for ImageType so we can create a DenseSet of
102117
// image types.
103118
struct ImageTypeMapInfo {
@@ -270,6 +285,9 @@ class SpirvContext {
270285
const RuntimeArrayType *
271286
getRuntimeArrayType(const SpirvType *elemType,
272287
llvm::Optional<uint32_t> arrayStride);
288+
const NodePayloadArrayType *
289+
getNodePayloadArrayType(const SpirvType *elemType,
290+
const ParmVarDecl *nodeDecl);
273291

274292
const StructType *getStructType(
275293
llvm::ArrayRef<StructType::FieldInfo> fields, llvm::StringRef name,
@@ -346,6 +364,7 @@ class SpirvContext {
346364
bool isDS() const { return curShaderModelKind == ShaderModelKind::Domain; }
347365
bool isCS() const { return curShaderModelKind == ShaderModelKind::Compute; }
348366
bool isLib() const { return curShaderModelKind == ShaderModelKind::Library; }
367+
bool isNode() const { return curShaderModelKind == ShaderModelKind::Node; }
349368
bool isRay() const {
350369
return curShaderModelKind >= ShaderModelKind::RayGeneration &&
351370
curShaderModelKind <= ShaderModelKind::Callable;
@@ -437,6 +456,31 @@ class SpirvContext {
437456
instructionsWithLoweredType.end();
438457
}
439458

459+
void registerDispatchGridIndex(const RecordDecl *decl, unsigned index) {
460+
auto iter = dispatchGridIndices.find(decl);
461+
if (iter == dispatchGridIndices.end()) {
462+
dispatchGridIndices[decl] = index;
463+
}
464+
}
465+
466+
llvm::Optional<unsigned> getDispatchGridIndex(const RecordDecl *decl) {
467+
auto iter = dispatchGridIndices.find(decl);
468+
if (iter != dispatchGridIndices.end()) {
469+
return iter->second;
470+
}
471+
return llvm::None;
472+
}
473+
474+
void registerNodeDeclPayloadType(const NodePayloadArrayType *type,
475+
const ParmVarDecl *decl) {
476+
nodeDecls[decl] = type;
477+
}
478+
479+
const NodePayloadArrayType *getNodeDeclPayloadType(const ParmVarDecl *decl) {
480+
auto iter = nodeDecls.find(decl);
481+
return iter == nodeDecls.end() ? nullptr : iter->second;
482+
}
483+
440484
private:
441485
/// \brief The allocator used to create SPIR-V entity objects.
442486
///
@@ -481,6 +525,8 @@ class SpirvContext {
481525
llvm::DenseSet<const ArrayType *, ArrayTypeMapInfo> arrayTypes;
482526
llvm::DenseSet<const RuntimeArrayType *, RuntimeArrayTypeMapInfo>
483527
runtimeArrayTypes;
528+
llvm::DenseSet<const NodePayloadArrayType *, NodePayloadArrayTypeMapInfo>
529+
nodePayloadArrayTypes;
484530
llvm::SmallVector<const StructType *, 8> structTypes;
485531
llvm::SmallVector<const HybridStructType *, 8> hybridStructTypes;
486532
llvm::DenseMap<const SpirvType *, SCToPtrTyMap> pointerTypes;
@@ -507,6 +553,9 @@ class SpirvContext {
507553
llvm::StringMap<RichDebugInfo> debugInfo;
508554
SpirvDebugInstruction *currentLexicalScope;
509555

556+
// Mapping from graphics node input record types to member decoration maps.
557+
llvm::MapVector<const RecordDecl *, unsigned> dispatchGridIndices;
558+
510559
// Mapping from SPIR-V type to debug type instruction.
511560
// The purpose is not to generate several DebugType* instructions for the same
512561
// type if the type is used for several variables.
@@ -538,6 +587,10 @@ class SpirvContext {
538587

539588
// Set of instructions that already have lowered SPIR-V types.
540589
llvm::DenseSet<const SpirvInstruction *> instructionsWithLoweredType;
590+
591+
// Mapping from shader entry function parameter declaration to node payload
592+
// array type.
593+
llvm::MapVector<const ParmVarDecl *, const NodePayloadArrayType *> nodeDecls;
541594
};
542595

543596
} // end namespace spirv

tools/clang/include/clang/SPIRV/SpirvInstruction.h

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class SpirvInstruction {
6767
IK_ConstantInteger,
6868
IK_ConstantFloat,
6969
IK_ConstantComposite,
70+
IK_ConstantString,
7071
IK_ConstantNull,
7172

7273
// Pointer <-> uint conversions.
@@ -165,6 +166,13 @@ class SpirvInstruction {
165166
IK_DebugTypeMember,
166167
IK_DebugTypeTemplate,
167168
IK_DebugTypeTemplateParameter,
169+
170+
// For workgraph instructions
171+
IK_IsNodePayloadValid,
172+
IK_NodePayloadArrayLength,
173+
IK_AllocateNodePayloads,
174+
IK_EnqueueNodePayloads,
175+
IK_FinishWritingNodePayload,
168176
};
169177

170178
// All instruction classes should include a releaseMemory method.
@@ -440,9 +448,13 @@ class SpirvExecutionMode : public SpirvExecutionModeBase {
440448

441449
bool invokeVisitor(Visitor *v) override;
442450

451+
SpirvFunction *getEntryPoint() const { return entryPoint; }
452+
spv::ExecutionMode getExecutionMode() const { return execMode; }
443453
llvm::ArrayRef<uint32_t> getParams() const { return params; }
444454

445455
private:
456+
SpirvFunction *entryPoint;
457+
spv::ExecutionMode execMode;
446458
llvm::SmallVector<uint32_t, 4> params;
447459
};
448460

@@ -1056,6 +1068,119 @@ class SpirvBarrier : public SpirvInstruction {
10561068
llvm::Optional<spv::Scope> executionScope;
10571069
};
10581070

1071+
/// \brief OpIsNodePayloadValidAMDX instruction
1072+
class SpirvIsNodePayloadValid : public SpirvInstruction {
1073+
public:
1074+
SpirvIsNodePayloadValid(QualType resultType, SourceLocation loc,
1075+
SpirvInstruction *payloadArray,
1076+
SpirvInstruction *nodeIndex);
1077+
1078+
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvIsNodePayloadValid)
1079+
1080+
// For LLVM-style RTTI
1081+
static bool classof(const SpirvInstruction *inst) {
1082+
return inst->getKind() == IK_IsNodePayloadValid;
1083+
}
1084+
1085+
bool invokeVisitor(Visitor *v) override;
1086+
1087+
SpirvInstruction *getPayloadArray() { return payloadArray; }
1088+
SpirvInstruction *getNodeIndex() { return nodeIndex; }
1089+
1090+
private:
1091+
SpirvInstruction *payloadArray;
1092+
SpirvInstruction *nodeIndex;
1093+
};
1094+
1095+
/// \brief OpNodePayloadArrayLengthAMDX instruction
1096+
class SpirvNodePayloadArrayLength : public SpirvInstruction {
1097+
public:
1098+
SpirvNodePayloadArrayLength(QualType resultType, SourceLocation loc,
1099+
SpirvInstruction *payloadArray);
1100+
1101+
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvNodePayloadArrayLength)
1102+
1103+
// For LLVM-style RTTI
1104+
static bool classof(const SpirvInstruction *inst) {
1105+
return inst->getKind() == IK_NodePayloadArrayLength;
1106+
}
1107+
1108+
bool invokeVisitor(Visitor *v) override;
1109+
1110+
SpirvInstruction *getPayloadArray() { return payloadArray; }
1111+
1112+
private:
1113+
SpirvInstruction *payloadArray;
1114+
};
1115+
1116+
/// \brief OpAllocateNodePayloadsAMDX instruction
1117+
class SpirvAllocateNodePayloads : public SpirvInstruction {
1118+
public:
1119+
SpirvAllocateNodePayloads(QualType resultType, SourceLocation loc,
1120+
spv::Scope allocationScope,
1121+
SpirvInstruction *shaderIndex,
1122+
SpirvInstruction *recordCount);
1123+
1124+
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvAllocateNodePayloads)
1125+
1126+
// For LLVM-style RTTI
1127+
static bool classof(const SpirvInstruction *inst) {
1128+
return inst->getKind() == IK_AllocateNodePayloads;
1129+
}
1130+
1131+
bool invokeVisitor(Visitor *v) override;
1132+
1133+
spv::Scope getAllocationScope() { return allocationScope; }
1134+
SpirvInstruction *getShaderIndex() { return shaderIndex; }
1135+
SpirvInstruction *getRecordCount() { return recordCount; }
1136+
1137+
private:
1138+
spv::Scope allocationScope;
1139+
SpirvInstruction *shaderIndex;
1140+
SpirvInstruction *recordCount;
1141+
};
1142+
1143+
/// \brief OpReleaseOutputNodePayloadAMDX instruction
1144+
class SpirvEnqueueNodePayloads : public SpirvInstruction {
1145+
public:
1146+
SpirvEnqueueNodePayloads(SourceLocation loc, SpirvInstruction *payload);
1147+
1148+
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvEnqueueNodePayloads)
1149+
1150+
// For LLVM-style RTTI
1151+
static bool classof(const SpirvInstruction *inst) {
1152+
return inst->getKind() == IK_EnqueueNodePayloads;
1153+
}
1154+
1155+
bool invokeVisitor(Visitor *v) override;
1156+
1157+
SpirvInstruction *getPayload() { return payload; }
1158+
1159+
private:
1160+
SpirvInstruction *payload;
1161+
};
1162+
1163+
/// \brief OpFinishWritingNodePayloadAMDX instruction
1164+
class SpirvFinishWritingNodePayload : public SpirvInstruction {
1165+
public:
1166+
SpirvFinishWritingNodePayload(QualType resultType, SourceLocation loc,
1167+
SpirvInstruction *payload);
1168+
1169+
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvFinishWritingNodePayload)
1170+
1171+
// For LLVM-style RTTI
1172+
static bool classof(const SpirvInstruction *inst) {
1173+
return inst->getKind() == IK_FinishWritingNodePayload;
1174+
}
1175+
1176+
bool invokeVisitor(Visitor *v) override;
1177+
1178+
SpirvInstruction *getPayload() { return payload; }
1179+
1180+
private:
1181+
SpirvInstruction *payload;
1182+
};
1183+
10591184
/// \brief Represents SPIR-V binary operation instructions.
10601185
///
10611186
/// This class includes:
@@ -1352,6 +1477,27 @@ class SpirvConstantNull : public SpirvConstant {
13521477
bool operator==(const SpirvConstantNull &that) const;
13531478
};
13541479

1480+
class SpirvConstantString : public SpirvConstant {
1481+
public:
1482+
SpirvConstantString(llvm::StringRef stringLiteral, bool isSpecConst = false);
1483+
1484+
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvConstantString)
1485+
1486+
// For LLVM-style RTTI
1487+
static bool classof(const SpirvInstruction *inst) {
1488+
return inst->getKind() == IK_ConstantString;
1489+
}
1490+
1491+
bool invokeVisitor(Visitor *v) override;
1492+
1493+
bool operator==(const SpirvConstantString &that) const;
1494+
1495+
llvm::StringRef getString() const { return str; }
1496+
1497+
private:
1498+
std::string str;
1499+
};
1500+
13551501
class SpirvConvertPtrToU : public SpirvInstruction {
13561502
public:
13571503
SpirvConvertPtrToU(SpirvInstruction *ptr, QualType type,

0 commit comments

Comments
 (0)