Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions tools/clang/include/clang/AST/HlslTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,10 @@ bool IsHLSLObjectWithImplicitMemberAccess(clang::QualType type);
bool IsHLSLObjectWithImplicitROMemberAccess(clang::QualType type);
bool IsHLSLRWNodeInputRecordType(clang::QualType type);
bool IsHLSLRONodeInputRecordType(clang::QualType type);
bool IsHLSLDispatchNodeInputRecordType(clang::QualType type);
bool IsHLSLNodeRecordArrayType(clang::QualType type);
bool IsHLSLNodeOutputType(clang::QualType type);
bool IsHLSLEmptyNodeRecordType(clang::QualType type);

DXIL::NodeIOKind GetNodeIOType(clang::QualType type);

Expand All @@ -498,6 +501,8 @@ bool IsHLSLCopyableAnnotatableRecord(clang::QualType QT);
bool IsHLSLBuiltinRayAttributeStruct(clang::QualType QT);
bool IsHLSLAggregateType(clang::QualType type);
clang::QualType GetHLSLResourceResultType(clang::QualType type);
clang::QualType GetHLSLNodeIOResultType(clang::ASTContext &astContext,
clang::QualType type);
unsigned GetHLSLResourceTemplateUInt(clang::QualType type);
bool IsIncompleteHLSLResourceArrayType(clang::ASTContext &context,
clang::QualType type);
Expand Down
2 changes: 2 additions & 0 deletions tools/clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -8063,6 +8063,8 @@ def err_hlsl_vk_pointer_cast_alignment: Error<
"Vulkan buffer pointer cannot be cast to greater alignment">;
def err_hlsl_vk_static_pointer_cast_type: Error<
"vk::static_pointer_cast() content type must be base class of argument's content type">;
def warn_spirv_node_shaders_experimental : Warning<
"SPIR-V implementation of node shaders is experimental and subject to change">;
// SPIRV Change Ends

let CategoryName = "OpenMP Issue" in {
Expand Down
1 change: 1 addition & 0 deletions tools/clang/include/clang/SPIRV/FeatureManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ enum class Extension {
KHR_ray_query,
EXT_shader_image_int64,
KHR_physical_storage_buffer,
AMD_shader_enqueue,
KHR_vulkan_memory_model,
NV_compute_shader_derivatives,
KHR_compute_shader_derivatives,
Expand Down
20 changes: 20 additions & 0 deletions tools/clang/include/clang/SPIRV/SpirvBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,25 @@ class SpirvBuilder {
QualType resultType, NonSemanticDebugPrintfInstructions instId,
llvm::ArrayRef<SpirvInstruction *> operands, SourceLocation);

SpirvInstruction *createIsNodePayloadValid(SpirvInstruction *payloadArray,
SpirvInstruction *nodeIndex,
SourceLocation);

SpirvInstruction *createNodePayloadArrayLength(SpirvInstruction *payloadArray,
SourceLocation);

SpirvInstruction *createAllocateNodePayloads(QualType resultType,
spv::Scope allocationScope,
SpirvInstruction *shaderIndex,
SpirvInstruction *recordCount,
SourceLocation);

void createEnqueueOutputNodePayloads(SpirvInstruction *payload,
SourceLocation);

SpirvInstruction *createFinishWritingNodePayload(SpirvInstruction *payload,
SourceLocation);

/// \brief Creates an OpMemoryBarrier or OpControlBarrier instruction with the
/// given flags. If execution scope (exec) is provided, an OpControlBarrier
/// is created; otherwise an OpMemoryBarrier is created.
Expand Down Expand Up @@ -766,6 +785,7 @@ class SpirvBuilder {
llvm::ArrayRef<SpirvConstant *> constituents,
bool specConst = false);
SpirvConstant *getConstantNull(QualType);
SpirvConstant *getConstantString(llvm::StringRef str, bool specConst = false);
SpirvUndef *getUndef(QualType);

SpirvString *createString(llvm::StringRef str);
Expand Down
53 changes: 53 additions & 0 deletions tools/clang/include/clang/SPIRV/SpirvContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,21 @@ struct RuntimeArrayTypeMapInfo {
}
};

// Provides DenseMapInfo for NodePayloadArrayType so we can create a DenseSet of
// node payload array types.
struct NodePayloadArrayTypeMapInfo {
static inline NodePayloadArrayType *getEmptyKey() { return nullptr; }
static inline NodePayloadArrayType *getTombstoneKey() { return nullptr; }
static unsigned getHashValue(const NodePayloadArrayType *Val) {
return llvm::hash_combine(Val->getElementType(), Val->getNodeDecl());
}
static bool isEqual(const NodePayloadArrayType *LHS,
const NodePayloadArrayType *RHS) {
// Either both are null, or both should have the same underlying type.
return (LHS == RHS) || (LHS && RHS && *LHS == *RHS);
}
};

// Provides DenseMapInfo for ImageType so we can create a DenseSet of
// image types.
struct ImageTypeMapInfo {
Expand Down Expand Up @@ -273,6 +288,9 @@ class SpirvContext {
const RuntimeArrayType *
getRuntimeArrayType(const SpirvType *elemType,
llvm::Optional<uint32_t> arrayStride);
const NodePayloadArrayType *
getNodePayloadArrayType(const SpirvType *elemType,
const ParmVarDecl *nodeDecl);

const StructType *getStructType(
llvm::ArrayRef<StructType::FieldInfo> fields, llvm::StringRef name,
Expand Down Expand Up @@ -349,6 +367,7 @@ class SpirvContext {
bool isDS() const { return curShaderModelKind == ShaderModelKind::Domain; }
bool isCS() const { return curShaderModelKind == ShaderModelKind::Compute; }
bool isLib() const { return curShaderModelKind == ShaderModelKind::Library; }
bool isNode() const { return curShaderModelKind == ShaderModelKind::Node; }
bool isRay() const {
return curShaderModelKind >= ShaderModelKind::RayGeneration &&
curShaderModelKind <= ShaderModelKind::Callable;
Expand Down Expand Up @@ -440,6 +459,31 @@ class SpirvContext {
instructionsWithLoweredType.end();
}

void registerDispatchGridIndex(const RecordDecl *decl, unsigned index) {
auto iter = dispatchGridIndices.find(decl);
if (iter == dispatchGridIndices.end()) {
dispatchGridIndices[decl] = index;
}
}

llvm::Optional<unsigned> getDispatchGridIndex(const RecordDecl *decl) {
auto iter = dispatchGridIndices.find(decl);
if (iter != dispatchGridIndices.end()) {
return iter->second;
}
return llvm::None;
}

void registerNodeDeclPayloadType(const NodePayloadArrayType *type,
const ParmVarDecl *decl) {
nodeDecls[decl] = type;
}

const NodePayloadArrayType *getNodeDeclPayloadType(const ParmVarDecl *decl) {
auto iter = nodeDecls.find(decl);
return iter == nodeDecls.end() ? nullptr : iter->second;
}

private:
/// \brief The allocator used to create SPIR-V entity objects.
///
Expand Down Expand Up @@ -484,6 +528,8 @@ class SpirvContext {
llvm::DenseSet<const ArrayType *, ArrayTypeMapInfo> arrayTypes;
llvm::DenseSet<const RuntimeArrayType *, RuntimeArrayTypeMapInfo>
runtimeArrayTypes;
llvm::DenseSet<const NodePayloadArrayType *, NodePayloadArrayTypeMapInfo>
nodePayloadArrayTypes;
llvm::SmallVector<const StructType *, 8> structTypes;
llvm::SmallVector<const HybridStructType *, 8> hybridStructTypes;
llvm::DenseMap<const SpirvType *, SCToPtrTyMap> pointerTypes;
Expand All @@ -510,6 +556,9 @@ class SpirvContext {
llvm::StringMap<RichDebugInfo> debugInfo;
SpirvDebugInstruction *currentLexicalScope;

// Mapping from graphics node input record types to member decoration maps.
llvm::MapVector<const RecordDecl *, unsigned> dispatchGridIndices;

// Mapping from SPIR-V type to debug type instruction.
// The purpose is not to generate several DebugType* instructions for the same
// type if the type is used for several variables.
Expand Down Expand Up @@ -541,6 +590,10 @@ class SpirvContext {

// Set of instructions that already have lowered SPIR-V types.
llvm::DenseSet<const SpirvInstruction *> instructionsWithLoweredType;

// Mapping from shader entry function parameter declaration to node payload
// array type.
llvm::MapVector<const ParmVarDecl *, const NodePayloadArrayType *> nodeDecls;
};

} // end namespace spirv
Expand Down
146 changes: 146 additions & 0 deletions tools/clang/include/clang/SPIRV/SpirvInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class SpirvInstruction {
IK_ConstantInteger,
IK_ConstantFloat,
IK_ConstantComposite,
IK_ConstantString,
IK_ConstantNull,

// Pointer <-> uint conversions.
Expand Down Expand Up @@ -168,6 +169,13 @@ class SpirvInstruction {
IK_DebugTypeMember,
IK_DebugTypeTemplate,
IK_DebugTypeTemplateParameter,

// For workgraph instructions
IK_IsNodePayloadValid,
IK_NodePayloadArrayLength,
IK_AllocateNodePayloads,
IK_EnqueueNodePayloads,
IK_FinishWritingNodePayload,
};

// All instruction classes should include a releaseMemory method.
Expand Down Expand Up @@ -443,9 +451,13 @@ class SpirvExecutionMode : public SpirvExecutionModeBase {

bool invokeVisitor(Visitor *v) override;

SpirvFunction *getEntryPoint() const { return entryPoint; }
spv::ExecutionMode getExecutionMode() const { return execMode; }
llvm::ArrayRef<uint32_t> getParams() const { return params; }

private:
SpirvFunction *entryPoint;
spv::ExecutionMode execMode;
llvm::SmallVector<uint32_t, 4> params;
};

Expand Down Expand Up @@ -1059,6 +1071,119 @@ class SpirvBarrier : public SpirvInstruction {
llvm::Optional<spv::Scope> executionScope;
};

/// \brief OpIsNodePayloadValidAMDX instruction
class SpirvIsNodePayloadValid : public SpirvInstruction {
public:
SpirvIsNodePayloadValid(QualType resultType, SourceLocation loc,
SpirvInstruction *payloadArray,
SpirvInstruction *nodeIndex);

DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvIsNodePayloadValid)

// For LLVM-style RTTI
static bool classof(const SpirvInstruction *inst) {
return inst->getKind() == IK_IsNodePayloadValid;
}

bool invokeVisitor(Visitor *v) override;

SpirvInstruction *getPayloadArray() { return payloadArray; }
SpirvInstruction *getNodeIndex() { return nodeIndex; }

private:
SpirvInstruction *payloadArray;
SpirvInstruction *nodeIndex;
};

/// \brief OpNodePayloadArrayLengthAMDX instruction
class SpirvNodePayloadArrayLength : public SpirvInstruction {
public:
SpirvNodePayloadArrayLength(QualType resultType, SourceLocation loc,
SpirvInstruction *payloadArray);

DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvNodePayloadArrayLength)

// For LLVM-style RTTI
static bool classof(const SpirvInstruction *inst) {
return inst->getKind() == IK_NodePayloadArrayLength;
}

bool invokeVisitor(Visitor *v) override;

SpirvInstruction *getPayloadArray() { return payloadArray; }

private:
SpirvInstruction *payloadArray;
};

/// \brief OpAllocateNodePayloadsAMDX instruction
class SpirvAllocateNodePayloads : public SpirvInstruction {
public:
SpirvAllocateNodePayloads(QualType resultType, SourceLocation loc,
spv::Scope allocationScope,
SpirvInstruction *shaderIndex,
SpirvInstruction *recordCount);

DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvAllocateNodePayloads)

// For LLVM-style RTTI
static bool classof(const SpirvInstruction *inst) {
return inst->getKind() == IK_AllocateNodePayloads;
}

bool invokeVisitor(Visitor *v) override;

spv::Scope getAllocationScope() { return allocationScope; }
SpirvInstruction *getShaderIndex() { return shaderIndex; }
SpirvInstruction *getRecordCount() { return recordCount; }

private:
spv::Scope allocationScope;
SpirvInstruction *shaderIndex;
SpirvInstruction *recordCount;
};

/// \brief OpReleaseOutputNodePayloadAMDX instruction
class SpirvEnqueueNodePayloads : public SpirvInstruction {
public:
SpirvEnqueueNodePayloads(SourceLocation loc, SpirvInstruction *payload);

DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvEnqueueNodePayloads)

// For LLVM-style RTTI
static bool classof(const SpirvInstruction *inst) {
return inst->getKind() == IK_EnqueueNodePayloads;
}

bool invokeVisitor(Visitor *v) override;

SpirvInstruction *getPayload() { return payload; }

private:
SpirvInstruction *payload;
};

/// \brief OpFinishWritingNodePayloadAMDX instruction
class SpirvFinishWritingNodePayload : public SpirvInstruction {
public:
SpirvFinishWritingNodePayload(QualType resultType, SourceLocation loc,
SpirvInstruction *payload);

DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvFinishWritingNodePayload)

// For LLVM-style RTTI
static bool classof(const SpirvInstruction *inst) {
return inst->getKind() == IK_FinishWritingNodePayload;
}

bool invokeVisitor(Visitor *v) override;

SpirvInstruction *getPayload() { return payload; }

private:
SpirvInstruction *payload;
};

/// \brief Represents SPIR-V binary operation instructions.
///
/// This class includes:
Expand Down Expand Up @@ -1355,6 +1480,27 @@ class SpirvConstantNull : public SpirvConstant {
bool operator==(const SpirvConstantNull &that) const;
};

class SpirvConstantString : public SpirvConstant {
public:
SpirvConstantString(llvm::StringRef stringLiteral, bool isSpecConst = false);

DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvConstantString)

// For LLVM-style RTTI
static bool classof(const SpirvInstruction *inst) {
return inst->getKind() == IK_ConstantString;
}

bool invokeVisitor(Visitor *v) override;

bool operator==(const SpirvConstantString &that) const;

llvm::StringRef getString() const { return str; }

private:
std::string str;
};

class SpirvConvertPtrToU : public SpirvInstruction {
public:
SpirvConvertPtrToU(SpirvInstruction *ptr, QualType type,
Expand Down
Loading