From 78e75735e050eb29e13e84af5ff7bbc360814a95 Mon Sep 17 00:00:00 2001 From: Dan Brown Date: Tue, 23 Jul 2024 15:10:51 -0600 Subject: [PATCH 1/2] AMD SPIR-V work graphs extensions --- tools/clang/include/clang/AST/HlslTypes.h | 5 + .../include/clang/SPIRV/FeatureManager.h | 1 + .../clang/include/clang/SPIRV/SpirvBuilder.h | 20 + .../clang/include/clang/SPIRV/SpirvContext.h | 53 ++ .../include/clang/SPIRV/SpirvInstruction.h | 146 ++++++ tools/clang/include/clang/SPIRV/SpirvType.h | 23 + .../clang/include/clang/SPIRV/SpirvVisitor.h | 6 + tools/clang/lib/AST/HlslTypes.cpp | 40 ++ tools/clang/lib/SPIRV/CapabilityVisitor.cpp | 64 ++- tools/clang/lib/SPIRV/DebugTypeVisitor.cpp | 11 + tools/clang/lib/SPIRV/DeclResultIdMapper.cpp | 35 +- tools/clang/lib/SPIRV/EmitVisitor.cpp | 207 +++++++- tools/clang/lib/SPIRV/EmitVisitor.h | 15 +- tools/clang/lib/SPIRV/FeatureManager.cpp | 3 + tools/clang/lib/SPIRV/GlPerVertex.cpp | 3 + tools/clang/lib/SPIRV/LowerTypeVisitor.cpp | 10 + tools/clang/lib/SPIRV/PreciseVisitor.cpp | 3 + tools/clang/lib/SPIRV/SpirvBuilder.cpp | 72 ++- tools/clang/lib/SPIRV/SpirvContext.cpp | 16 + tools/clang/lib/SPIRV/SpirvEmitter.cpp | 452 +++++++++++++++++- tools/clang/lib/SPIRV/SpirvEmitter.h | 44 ++ tools/clang/lib/SPIRV/SpirvInstruction.cpp | 57 ++- tools/clang/lib/SPIRV/SpirvType.cpp | 4 + .../test/CodeGenSPIRV/bezier.hull.hlsl2spv | 144 +++--- .../hs.const.output-patch.out.hlsl | 6 +- .../CodeGenSPIRV/node.barrier.compute.hlsl | 15 + .../CodeGenSPIRV/node.barrier.memory-arg.hlsl | 60 +++ .../CodeGenSPIRV/node.barrier.object-arg.hlsl | 213 +++++++++ .../node.broadcasting.no-input.hlsl | 15 + .../node.coalescing.num-threads.hlsl | 16 + .../test/CodeGenSPIRV/node.dispatch-grid.hlsl | 28 ++ .../CodeGenSPIRV/node.empty-node-input.hlsl | 28 ++ .../node.finished-cross-group-sharing.hlsl | 32 ++ .../node.get-input-record-count.hlsl | 25 + .../node.get-node-output-record.multiple.hlsl | 72 +++ .../node.get-remaining-recursion-levels.hlsl | 26 + 
.../node.group-shared.barrier.hlsl | 18 + .../test/CodeGenSPIRV/node.group-shared.hlsl | 24 + .../node.increment-output-count.group.hlsl | 22 + .../node.increment-output-count.thread.hlsl | 22 + ...node.input-record.dispatch-grid.array.hlsl | 26 + ...ode.input-record.dispatch-grid.nested.hlsl | 32 ++ .../CodeGenSPIRV/node.max-dispatch-grid.hlsl | 30 ++ .../test/CodeGenSPIRV/node.max-records.hlsl | 45 ++ .../test/CodeGenSPIRV/node.member.read.hlsl | 150 ++++++ .../CodeGenSPIRV/node.member.read.types.hlsl | 193 ++++++++ .../test/CodeGenSPIRV/node.member.write.hlsl | 88 ++++ .../node.member.write.matrix.hlsl | 123 +++++ .../CodeGenSPIRV/node.member.write.types.hlsl | 150 ++++++ tools/clang/test/CodeGenSPIRV/node.mesh.hlsl | 88 ++++ .../CodeGenSPIRV/node.output-complete.hlsl | 33 ++ .../node.output.is-valid.empty.hlsl | 19 + .../CodeGenSPIRV/node.output.is-valid.hlsl | 24 + .../clang/test/CodeGenSPIRV/node.renamed.hlsl | 23 + .../test/CodeGenSPIRV/node.share-input.hlsl | 42 ++ .../test/CodeGenSPIRV/node.sparse-nodes.hlsl | 141 ++++++ .../CodeGenSPIRV/node.thread.num-threads.hlsl | 15 + .../node.thread.num-threads.none.hlsl | 15 + .../vk.attribute.image-format.hlsl | 6 - 59 files changed, 3187 insertions(+), 112 deletions(-) create mode 100644 tools/clang/test/CodeGenSPIRV/node.barrier.compute.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/node.barrier.memory-arg.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/node.barrier.object-arg.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/node.broadcasting.no-input.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/node.coalescing.num-threads.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/node.dispatch-grid.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/node.empty-node-input.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/node.finished-cross-group-sharing.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/node.get-input-record-count.hlsl create mode 100644 
tools/clang/test/CodeGenSPIRV/node.get-node-output-record.multiple.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/node.get-remaining-recursion-levels.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/node.group-shared.barrier.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/node.group-shared.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/node.increment-output-count.group.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/node.increment-output-count.thread.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/node.input-record.dispatch-grid.array.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/node.input-record.dispatch-grid.nested.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/node.max-dispatch-grid.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/node.max-records.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/node.member.read.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/node.member.read.types.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/node.member.write.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/node.member.write.matrix.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/node.member.write.types.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/node.mesh.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/node.output-complete.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/node.output.is-valid.empty.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/node.output.is-valid.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/node.renamed.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/node.share-input.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/node.sparse-nodes.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/node.thread.num-threads.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/node.thread.num-threads.none.hlsl diff --git a/tools/clang/include/clang/AST/HlslTypes.h b/tools/clang/include/clang/AST/HlslTypes.h index c14f562101..2362d2ded4 
100644 --- a/tools/clang/include/clang/AST/HlslTypes.h +++ b/tools/clang/include/clang/AST/HlslTypes.h @@ -488,7 +488,10 @@ bool IsHLSLObjectWithImplicitMemberAccess(clang::QualType type); bool IsHLSLObjectWithImplicitROMemberAccess(clang::QualType type); bool IsHLSLRWNodeInputRecordType(clang::QualType type); bool IsHLSLRONodeInputRecordType(clang::QualType type); +bool IsHLSLDispatchNodeInputRecordType(clang::QualType type); +bool IsHLSLNodeRecordArrayType(clang::QualType type); bool IsHLSLNodeOutputType(clang::QualType type); +bool IsHLSLEmptyNodeRecordType(clang::QualType type); DXIL::NodeIOKind GetNodeIOType(clang::QualType type); @@ -498,6 +501,8 @@ bool IsHLSLCopyableAnnotatableRecord(clang::QualType QT); bool IsHLSLBuiltinRayAttributeStruct(clang::QualType QT); bool IsHLSLAggregateType(clang::QualType type); clang::QualType GetHLSLResourceResultType(clang::QualType type); +clang::QualType GetHLSLNodeIOResultType(clang::ASTContext &astContext, + clang::QualType type); unsigned GetHLSLResourceTemplateUInt(clang::QualType type); bool IsIncompleteHLSLResourceArrayType(clang::ASTContext &context, clang::QualType type); diff --git a/tools/clang/include/clang/SPIRV/FeatureManager.h b/tools/clang/include/clang/SPIRV/FeatureManager.h index 3c1871df37..94dc5bf1ab 100644 --- a/tools/clang/include/clang/SPIRV/FeatureManager.h +++ b/tools/clang/include/clang/SPIRV/FeatureManager.h @@ -57,6 +57,7 @@ enum class Extension { KHR_ray_query, EXT_shader_image_int64, KHR_physical_storage_buffer, + AMD_shader_enqueue, KHR_vulkan_memory_model, NV_compute_shader_derivatives, KHR_compute_shader_derivatives, diff --git a/tools/clang/include/clang/SPIRV/SpirvBuilder.h b/tools/clang/include/clang/SPIRV/SpirvBuilder.h index e4e6ef308f..c17f671a4a 100644 --- a/tools/clang/include/clang/SPIRV/SpirvBuilder.h +++ b/tools/clang/include/clang/SPIRV/SpirvBuilder.h @@ -437,6 +437,25 @@ class SpirvBuilder { QualType resultType, NonSemanticDebugPrintfInstructions instId, llvm::ArrayRef operands, 
SourceLocation); + SpirvInstruction *createIsNodePayloadValid(SpirvInstruction *payloadArray, + SpirvInstruction *nodeIndex, + SourceLocation); + + SpirvInstruction *createNodePayloadArrayLength(SpirvInstruction *payloadArray, + SourceLocation); + + SpirvInstruction *createAllocateNodePayloads(QualType resultType, + spv::Scope allocationScope, + SpirvInstruction *shaderIndex, + SpirvInstruction *recordCount, + SourceLocation); + + void createEnqueueOutputNodePayloads(SpirvInstruction *payload, + SourceLocation); + + SpirvInstruction *createFinishWritingNodePayload(SpirvInstruction *payload, + SourceLocation); + /// \brief Creates an OpMemoryBarrier or OpControlBarrier instruction with the /// given flags. If execution scope (exec) is provided, an OpControlBarrier /// is created; otherwise an OpMemoryBarrier is created. @@ -766,6 +785,7 @@ class SpirvBuilder { llvm::ArrayRef constituents, bool specConst = false); SpirvConstant *getConstantNull(QualType); + SpirvConstant *getConstantString(llvm::StringRef str, bool specConst = false); SpirvUndef *getUndef(QualType); SpirvString *createString(llvm::StringRef str); diff --git a/tools/clang/include/clang/SPIRV/SpirvContext.h b/tools/clang/include/clang/SPIRV/SpirvContext.h index c18c139642..8b2773f1d5 100644 --- a/tools/clang/include/clang/SPIRV/SpirvContext.h +++ b/tools/clang/include/clang/SPIRV/SpirvContext.h @@ -101,6 +101,21 @@ struct RuntimeArrayTypeMapInfo { } }; +// Provides DenseMapInfo for NodePayloadArrayType so we can create a DenseSet of +// node payload array types. 
+struct NodePayloadArrayTypeMapInfo { + static inline NodePayloadArrayType *getEmptyKey() { return nullptr; } + static inline NodePayloadArrayType *getTombstoneKey() { return nullptr; } + static unsigned getHashValue(const NodePayloadArrayType *Val) { + return llvm::hash_combine(Val->getElementType(), Val->getNodeDecl()); + } + static bool isEqual(const NodePayloadArrayType *LHS, + const NodePayloadArrayType *RHS) { + // Either both are null, or both should have the same underlying type. + return (LHS == RHS) || (LHS && RHS && *LHS == *RHS); + } +}; + // Provides DenseMapInfo for ImageType so we can create a DenseSet of // image types. struct ImageTypeMapInfo { @@ -273,6 +288,9 @@ class SpirvContext { const RuntimeArrayType * getRuntimeArrayType(const SpirvType *elemType, llvm::Optional arrayStride); + const NodePayloadArrayType * + getNodePayloadArrayType(const SpirvType *elemType, + const ParmVarDecl *nodeDecl); const StructType *getStructType( llvm::ArrayRef fields, llvm::StringRef name, @@ -349,6 +367,7 @@ class SpirvContext { bool isDS() const { return curShaderModelKind == ShaderModelKind::Domain; } bool isCS() const { return curShaderModelKind == ShaderModelKind::Compute; } bool isLib() const { return curShaderModelKind == ShaderModelKind::Library; } + bool isNode() const { return curShaderModelKind == ShaderModelKind::Node; } bool isRay() const { return curShaderModelKind >= ShaderModelKind::RayGeneration && curShaderModelKind <= ShaderModelKind::Callable; @@ -440,6 +459,31 @@ class SpirvContext { instructionsWithLoweredType.end(); } + void registerDispatchGridIndex(const RecordDecl *decl, unsigned index) { + auto iter = dispatchGridIndices.find(decl); + if (iter == dispatchGridIndices.end()) { + dispatchGridIndices[decl] = index; + } + } + + llvm::Optional getDispatchGridIndex(const RecordDecl *decl) { + auto iter = dispatchGridIndices.find(decl); + if (iter != dispatchGridIndices.end()) { + return iter->second; + } + return llvm::None; + } + + void 
registerNodeDeclPayloadType(const NodePayloadArrayType *type, + const ParmVarDecl *decl) { + nodeDecls[decl] = type; + } + + const NodePayloadArrayType *getNodeDeclPayloadType(const ParmVarDecl *decl) { + auto iter = nodeDecls.find(decl); + return iter == nodeDecls.end() ? nullptr : iter->second; + } + private: /// \brief The allocator used to create SPIR-V entity objects. /// @@ -484,6 +528,8 @@ class SpirvContext { llvm::DenseSet arrayTypes; llvm::DenseSet runtimeArrayTypes; + llvm::DenseSet + nodePayloadArrayTypes; llvm::SmallVector structTypes; llvm::SmallVector hybridStructTypes; llvm::DenseMap pointerTypes; @@ -510,6 +556,9 @@ class SpirvContext { llvm::StringMap debugInfo; SpirvDebugInstruction *currentLexicalScope; + // Mapping from node input record declarations to their registered dispatch grid index. + llvm::MapVector dispatchGridIndices; + // Mapping from SPIR-V type to debug type instruction. // The purpose is not to generate several DebugType* instructions for the same // type if the type is used for several variables. @@ -541,6 +590,10 @@ class SpirvContext { // Set of instructions that already have lowered SPIR-V types. llvm::DenseSet instructionsWithLoweredType; + + // Mapping from shader entry function parameter declaration to node payload + // array type. + llvm::MapVector nodeDecls; }; } // end namespace spirv diff --git a/tools/clang/include/clang/SPIRV/SpirvInstruction.h b/tools/clang/include/clang/SPIRV/SpirvInstruction.h index 6d95459373..3c675a4fe2 100644 --- a/tools/clang/include/clang/SPIRV/SpirvInstruction.h +++ b/tools/clang/include/clang/SPIRV/SpirvInstruction.h @@ -70,6 +70,7 @@ class SpirvInstruction { IK_ConstantInteger, IK_ConstantFloat, IK_ConstantComposite, + IK_ConstantString, IK_ConstantNull, // Pointer <-> uint conversions. 
@@ -168,6 +169,13 @@ class SpirvInstruction { IK_DebugTypeMember, IK_DebugTypeTemplate, IK_DebugTypeTemplateParameter, + + // For workgraph instructions + IK_IsNodePayloadValid, + IK_NodePayloadArrayLength, + IK_AllocateNodePayloads, + IK_EnqueueNodePayloads, + IK_FinishWritingNodePayload, }; // All instruction classes should include a releaseMemory method. @@ -443,9 +451,13 @@ class SpirvExecutionMode : public SpirvExecutionModeBase { bool invokeVisitor(Visitor *v) override; + SpirvFunction *getEntryPoint() const { return entryPoint; } + spv::ExecutionMode getExecutionMode() const { return execMode; } llvm::ArrayRef getParams() const { return params; } private: + SpirvFunction *entryPoint; + spv::ExecutionMode execMode; llvm::SmallVector params; }; @@ -1059,6 +1071,119 @@ class SpirvBarrier : public SpirvInstruction { llvm::Optional executionScope; }; +/// \brief OpIsNodePayloadValidAMDX instruction +class SpirvIsNodePayloadValid : public SpirvInstruction { +public: + SpirvIsNodePayloadValid(QualType resultType, SourceLocation loc, + SpirvInstruction *payloadArray, + SpirvInstruction *nodeIndex); + + DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvIsNodePayloadValid) + + // For LLVM-style RTTI + static bool classof(const SpirvInstruction *inst) { + return inst->getKind() == IK_IsNodePayloadValid; + } + + bool invokeVisitor(Visitor *v) override; + + SpirvInstruction *getPayloadArray() { return payloadArray; } + SpirvInstruction *getNodeIndex() { return nodeIndex; } + +private: + SpirvInstruction *payloadArray; + SpirvInstruction *nodeIndex; +}; + +/// \brief OpNodePayloadArrayLengthAMDX instruction +class SpirvNodePayloadArrayLength : public SpirvInstruction { +public: + SpirvNodePayloadArrayLength(QualType resultType, SourceLocation loc, + SpirvInstruction *payloadArray); + + DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvNodePayloadArrayLength) + + // For LLVM-style RTTI + static bool classof(const SpirvInstruction *inst) { + return inst->getKind() == IK_NodePayloadArrayLength; + } + 
+ bool invokeVisitor(Visitor *v) override; + + SpirvInstruction *getPayloadArray() { return payloadArray; } + +private: + SpirvInstruction *payloadArray; +}; + +/// \brief OpAllocateNodePayloadsAMDX instruction +class SpirvAllocateNodePayloads : public SpirvInstruction { +public: + SpirvAllocateNodePayloads(QualType resultType, SourceLocation loc, + spv::Scope allocationScope, + SpirvInstruction *shaderIndex, + SpirvInstruction *recordCount); + + DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvAllocateNodePayloads) + + // For LLVM-style RTTI + static bool classof(const SpirvInstruction *inst) { + return inst->getKind() == IK_AllocateNodePayloads; + } + + bool invokeVisitor(Visitor *v) override; + + spv::Scope getAllocationScope() { return allocationScope; } + SpirvInstruction *getShaderIndex() { return shaderIndex; } + SpirvInstruction *getRecordCount() { return recordCount; } + +private: + spv::Scope allocationScope; + SpirvInstruction *shaderIndex; + SpirvInstruction *recordCount; +}; + +/// \brief OpEnqueueNodePayloadsAMDX instruction +class SpirvEnqueueNodePayloads : public SpirvInstruction { +public: + SpirvEnqueueNodePayloads(SourceLocation loc, SpirvInstruction *payload); + + DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvEnqueueNodePayloads) + + // For LLVM-style RTTI + static bool classof(const SpirvInstruction *inst) { + return inst->getKind() == IK_EnqueueNodePayloads; + } + + bool invokeVisitor(Visitor *v) override; + + SpirvInstruction *getPayload() { return payload; } + +private: + SpirvInstruction *payload; +}; + +/// \brief OpFinishWritingNodePayloadAMDX instruction +class SpirvFinishWritingNodePayload : public SpirvInstruction { +public: + SpirvFinishWritingNodePayload(QualType resultType, SourceLocation loc, + SpirvInstruction *payload); + + DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvFinishWritingNodePayload) + + // For LLVM-style RTTI + static bool classof(const SpirvInstruction *inst) { + return inst->getKind() == IK_FinishWritingNodePayload; + } + + bool 
invokeVisitor(Visitor *v) override; + + SpirvInstruction *getPayload() { return payload; } + +private: + SpirvInstruction *payload; +}; + /// \brief Represents SPIR-V binary operation instructions. /// /// This class includes: @@ -1355,6 +1480,27 @@ class SpirvConstantNull : public SpirvConstant { bool operator==(const SpirvConstantNull &that) const; }; +class SpirvConstantString : public SpirvConstant { +public: + SpirvConstantString(llvm::StringRef stringLiteral, bool isSpecConst = false); + + DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvConstantString) + + // For LLVM-style RTTI + static bool classof(const SpirvInstruction *inst) { + return inst->getKind() == IK_ConstantString; + } + + bool invokeVisitor(Visitor *v) override; + + bool operator==(const SpirvConstantString &that) const; + + llvm::StringRef getString() const { return str; } + +private: + std::string str; +}; + class SpirvConvertPtrToU : public SpirvInstruction { public: SpirvConvertPtrToU(SpirvInstruction *ptr, QualType type, diff --git a/tools/clang/include/clang/SPIRV/SpirvType.h b/tools/clang/include/clang/SPIRV/SpirvType.h index 00a00ef238..1bd772e295 100644 --- a/tools/clang/include/clang/SPIRV/SpirvType.h +++ b/tools/clang/include/clang/SPIRV/SpirvType.h @@ -54,6 +54,7 @@ class SpirvType { TK_SampledImage, TK_Array, TK_RuntimeArray, + TK_NodePayloadArrayAMD, TK_Struct, TK_Pointer, TK_ForwardPointer, @@ -294,6 +295,26 @@ class RuntimeArrayType : public SpirvType { llvm::Optional stride; }; +class NodePayloadArrayType : public SpirvType { +public: + NodePayloadArrayType(const SpirvType *elemType, const ParmVarDecl *decl) + : SpirvType(TK_NodePayloadArrayAMD), elementType(elemType), + nodeDecl(decl) {} + + static bool classof(const SpirvType *t) { + return t->getKind() == TK_NodePayloadArrayAMD; + } + + bool operator==(const NodePayloadArrayType &that) const; + + const SpirvType *getElementType() const { return elementType; } + const ParmVarDecl *getNodeDecl() const { return nodeDecl; } + +private: + 
const SpirvType *elementType; + const ParmVarDecl *nodeDecl; +}; + // The StructType is the lowered type that best represents what a structure type // is in SPIR-V. Contains all necessary information for properly emitting a // SPIR-V structure type. @@ -630,6 +651,8 @@ bool SpirvType::isOrContainsType(const SpirvType *type) { return isOrContainsType(pointerType->getPointeeType()); if (const auto *raType = dyn_cast(type)) return isOrContainsType(raType->getElementType()); + if (const auto *npaType = dyn_cast(type)) + return isOrContainsType(npaType->getElementType()); if (const auto *imgType = dyn_cast(type)) return isOrContainsType(imgType->getSampledType()); if (const auto *sampledImageType = dyn_cast(type)) diff --git a/tools/clang/include/clang/SPIRV/SpirvVisitor.h b/tools/clang/include/clang/SPIRV/SpirvVisitor.h index 95bc46aa5f..b6f6d6ce85 100644 --- a/tools/clang/include/clang/SPIRV/SpirvVisitor.h +++ b/tools/clang/include/clang/SPIRV/SpirvVisitor.h @@ -85,6 +85,11 @@ class Visitor { DEFINE_VISIT_METHOD(SpirvAccessChain) DEFINE_VISIT_METHOD(SpirvAtomic) DEFINE_VISIT_METHOD(SpirvBarrier) + DEFINE_VISIT_METHOD(SpirvIsNodePayloadValid) + DEFINE_VISIT_METHOD(SpirvNodePayloadArrayLength) + DEFINE_VISIT_METHOD(SpirvAllocateNodePayloads) + DEFINE_VISIT_METHOD(SpirvEnqueueNodePayloads) + DEFINE_VISIT_METHOD(SpirvFinishWritingNodePayload) DEFINE_VISIT_METHOD(SpirvBinaryOp) DEFINE_VISIT_METHOD(SpirvBitFieldExtract) DEFINE_VISIT_METHOD(SpirvBitFieldInsert) @@ -92,6 +97,7 @@ class Visitor { DEFINE_VISIT_METHOD(SpirvConstantInteger) DEFINE_VISIT_METHOD(SpirvConstantFloat) DEFINE_VISIT_METHOD(SpirvConstantComposite) + DEFINE_VISIT_METHOD(SpirvConstantString) DEFINE_VISIT_METHOD(SpirvConstantNull) DEFINE_VISIT_METHOD(SpirvConvertPtrToU) DEFINE_VISIT_METHOD(SpirvConvertUToPtr) diff --git a/tools/clang/lib/AST/HlslTypes.cpp b/tools/clang/lib/AST/HlslTypes.cpp index 05386ddaa5..0e04e2d117 100644 --- a/tools/clang/lib/AST/HlslTypes.cpp +++ b/tools/clang/lib/AST/HlslTypes.cpp @@ 
-590,6 +590,12 @@ bool IsHLSLRONodeInputRecordType(clang::QualType type) { static_cast(DXIL::NodeIOFlags::Input); } +bool IsHLSLDispatchNodeInputRecordType(clang::QualType type) { + return IsHLSLNodeInputType(type) && + (static_cast(GetNodeIOType(type)) & + static_cast(DXIL::NodeIOFlags::DispatchRecord)) != 0; +} + bool IsHLSLNodeOutputType(clang::QualType type) { return (static_cast(GetNodeIOType(type)) & (static_cast(DXIL::NodeIOFlags::Output) | @@ -597,6 +603,23 @@ bool IsHLSLNodeOutputType(clang::QualType type) { static_cast(DXIL::NodeIOFlags::Output); } +bool IsHLSLNodeRecordArrayType(clang::QualType type) { + if (const RecordType *RT = type->getAs()) { + StringRef name = RT->getDecl()->getName(); + if (name == "ThreadNodeOutputRecords" || name == "GroupNodeOutputRecords" || + name == "GroupNodeInputRecords" || name == "RWGroupNodeInputRecords" || + name == "EmptyNodeInput") + return true; + } + return false; +} + +bool IsHLSLEmptyNodeRecordType(clang::QualType type) { + return (static_cast(GetNodeIOType(type)) & + static_cast(DXIL::NodeIOFlags::EmptyRecord)) == + static_cast(DXIL::NodeIOFlags::EmptyRecord); +} + bool IsHLSLStructuredBufferType(clang::QualType type) { if (const HLSLResourceAttr *Attr = getAttr(type)) return Attr->getResKind() == DXIL::ResourceKind::StructuredBuffer; @@ -839,6 +862,23 @@ QualType GetHLSLResourceResultType(QualType type) { return HandleFieldDecl->getType(); } +QualType GetHLSLNodeIOResultType(ASTContext &astContext, QualType type) { + if (hlsl::IsHLSLEmptyNodeRecordType(type)) { + RecordDecl *RD = astContext.buildImplicitRecord(""); + RD->startDefinition(); + RD->completeDefinition(); + return astContext.getRecordType(RD); + } else if (hlsl::IsHLSLNodeType(type)) { + const RecordType *recordType = type->getAs(); + if (const auto *templateDecl = + dyn_cast(recordType->getDecl())) { + const auto &templateArgs = templateDecl->getTemplateArgs(); + return templateArgs[0].getAsType(); + } + } + return type; +} + unsigned 
GetHLSLResourceTemplateUInt(clang::QualType type) { const ClassTemplateSpecializationDecl *templateDecl = cast( diff --git a/tools/clang/lib/SPIRV/CapabilityVisitor.cpp b/tools/clang/lib/SPIRV/CapabilityVisitor.cpp index 9ca9cbc6cd..baee0b753d 100644 --- a/tools/clang/lib/SPIRV/CapabilityVisitor.cpp +++ b/tools/clang/lib/SPIRV/CapabilityVisitor.cpp @@ -125,6 +125,12 @@ void CapabilityVisitor::addCapabilityForType(const SpirvType *type, } addCapabilityForType(raType->getElementType(), loc, sc); } + // Node payload array also requires additional capability. + else if (const auto *npaType = dyn_cast(type)) { + addExtension(Extension::AMD_shader_enqueue, "Vulkan 1.3", loc); + addCapability(spv::Capability::ShaderEnqueueAMDX, loc); + addCapabilityForType(npaType->getElementType(), loc, sc); + } // Image types else if (const auto *imageType = dyn_cast(type)) { switch (imageType->getDimension()) { @@ -257,6 +263,19 @@ bool CapabilityVisitor::visit(SpirvDecoration *decor) { addCapability(spv::Capability::FragmentBarycentricKHR); break; } + case spv::Decoration::NodeSharesPayloadLimitsWithAMDX: + case spv::Decoration::NodeMaxPayloadsAMDX: + case spv::Decoration::TrackFinishWritingAMDX: + case spv::Decoration::PayloadNodeNameAMDX: + case spv::Decoration::PayloadNodeBaseIndexAMDX: + case spv::Decoration::PayloadNodeSparseArrayAMDX: + case spv::Decoration::PayloadNodeArraySizeAMDX: + case spv::Decoration::PayloadDispatchIndirectAMDX: { + featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_3, "WorkGraphs", loc); + addCapability(spv::Capability::ShaderEnqueueAMDX, loc); + addExtension(Extension::AMD_shader_enqueue, "Vulkan 1.3", loc); + break; + } // Capabilities needed for built-ins case spv::Decoration::BuiltIn: { AddVulkanMemoryModelForVolatile(decor, loc); @@ -535,8 +554,14 @@ bool CapabilityVisitor::visitInstruction(SpirvInstruction *instr) { addCapability(spv::Capability::GroupNonUniformQuad); break; case spv::Op::OpVariable: { - if (spvOptions.enableReflect && - 
!cast(instr)->getHlslUserType().empty()) { + auto var = cast(instr); + auto storage = var->getStorageClass(); + if (storage == spv::StorageClass::NodePayloadAMDX) { + featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_3, "WorkGraphs", loc); + addCapability(spv::Capability::ShaderEnqueueAMDX, loc); + addExtension(Extension::AMD_shader_enqueue, "Vulkan 1.3", loc); + } + if (spvOptions.enableReflect && !var->getHlslUserType().empty()) { addExtension(Extension::GOOGLE_user_type, "HLSL User Type", loc); addExtension(Extension::GOOGLE_hlsl_functionality1, "HLSL User Type", loc); @@ -580,6 +605,28 @@ bool CapabilityVisitor::visitInstruction(SpirvInstruction *instr) { } break; } + case spv::Op::OpConstantStringAMDX: + case spv::Op::OpSpecConstantStringAMDX: + case spv::Op::OpAllocateNodePayloadsAMDX: + case spv::Op::OpEnqueueNodePayloadsAMDX: + case spv::Op::OpIsNodePayloadValidAMDX: + case spv::Op::OpFinishWritingNodePayloadAMDX: { + featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_3, "WorkGraphs", loc); + addCapability(spv::Capability::ShaderEnqueueAMDX, loc); + addExtension(Extension::AMD_shader_enqueue, "Vulkan 1.3", loc); + break; + } + case spv::Op::OpControlBarrier: + case spv::Op::OpMemoryBarrier: { + auto barrier = cast(instr); + if ((bool)(barrier->getMemorySemantics() & + spv::MemorySemanticsMask::OutputMemoryKHR)) { + featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_3, "NODE_OUTPUT_MEMORY", + loc); + addCapability(spv::Capability::VulkanMemoryModel, loc); + } + break; + } default: break; @@ -645,6 +692,19 @@ bool CapabilityVisitor::visit(SpirvExecutionModeBase *execMode) { SourceLocation entryPointSourceLocation = execMode->getEntryPoint()->getSourceLocation(); switch (executionMode) { + case spv::ExecutionMode::CoalescingAMDX: + case spv::ExecutionMode::MaxNodeRecursionAMDX: + case spv::ExecutionMode::StaticNumWorkgroupsAMDX: + case spv::ExecutionMode::MaxNumWorkgroupsAMDX: + featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_3, "WorkGraphs", + 
execModeSourceLocation); + addCapability(spv::Capability::ShaderEnqueueAMDX, execModeSourceLocation); + addExtension(Extension::AMD_shader_enqueue, "Vulkan 1.3", + execModeSourceLocation); + break; + case spv::ExecutionMode::SubgroupSize: + addCapability(spv::Capability::SubgroupDispatch, execModeSourceLocation); + break; case spv::ExecutionMode::PostDepthCoverage: addCapability(spv::Capability::SampleMaskPostDepthCoverage, entryPointSourceLocation); diff --git a/tools/clang/lib/SPIRV/DebugTypeVisitor.cpp b/tools/clang/lib/SPIRV/DebugTypeVisitor.cpp index 058e7b6255..24fab092cc 100644 --- a/tools/clang/lib/SPIRV/DebugTypeVisitor.cpp +++ b/tools/clang/lib/SPIRV/DebugTypeVisitor.cpp @@ -356,6 +356,17 @@ SpirvDebugType *DebugTypeVisitor::lowerToDebugType(const SpirvType *spirvType) { debugType = spvContext.getDebugTypeArray(spirvType, elemDebugType, counts); break; } + case SpirvType::TK_NodePayloadArrayAMD: { + auto *arrType = dyn_cast(spirvType); + SpirvDebugInstruction *elemDebugType = + lowerToDebugType(arrType->getElementType()); + + llvm::SmallVector counts; + counts.push_back(0u); + + debugType = spvContext.getDebugTypeArray(spirvType, elemDebugType, counts); + break; + } case SpirvType::TK_Vector: { auto *vecType = dyn_cast(spirvType); SpirvDebugInstruction *elemDebugType = diff --git a/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp b/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp index de73d5e417..9d0d8f51a3 100644 --- a/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp +++ b/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp @@ -467,6 +467,10 @@ hlsl::DxilParamInputQual deduceParamQual(const DeclaratorDecl *decl, if (decl->hasAttr()) return hlsl::DxilParamInputQual::InPayload; + if (hlsl::IsHLSLNodeType(type)) { + return hlsl::DxilParamInputQual::NodeIO; + } + return asInput ? 
hlsl::DxilParamInputQual::In : hlsl::DxilParamInputQual::Out; } @@ -475,6 +479,9 @@ hlsl::DxilParamInputQual deduceParamQual(const DeclaratorDecl *decl, const hlsl::SigPoint *deduceSigPoint(const DeclaratorDecl *decl, bool asInput, const hlsl::ShaderModel::Kind kind, bool forPCF) { + if (kind == hlsl::ShaderModel::Kind::Node) { + return hlsl::SigPoint::GetSigPoint(hlsl::SigPoint::Kind::CSIn); + } return hlsl::SigPoint::GetSigPoint(hlsl::SigPointFromInputQual( deduceParamQual(decl, asInput), kind, forPCF)); } @@ -2158,6 +2165,8 @@ bool DeclResultIdMapper::assignLocations( llvm::DenseSet *stageVariableLocationInfo) { for (const auto *var : vars) { + if (hlsl::IsHLSLNodeType(var->getAstType())) + continue; auto locCount = var->getLocationCount(); uint32_t location = nextLocs(locCount); spvBuilder.decorateLocation(var->getSpirvInstr(), location); @@ -3489,7 +3498,9 @@ SpirvVariable *DeclResultIdMapper::createSpirvInterfaceVariable( StageVar stageVar( stageVarData.sigPoint, *stageVarData.semantic, builtinAttr, evalType, // For HS/DS/GS, we have already stripped the outmost arrayness on type. - getLocationAndComponentCount(astContext, stageVarData.type)); + hlsl::IsHLSLNodeInputType(stageVarData.type) + ? LocationAndComponent({0, 0, false}) + : getLocationAndComponentCount(astContext, stageVarData.type)); const auto name = stageVarData.namePrefix.str() + "." 
+ stageVar.getSemanticStr(); SpirvVariable *varInstr = createSpirvStageVar( @@ -3708,6 +3719,22 @@ bool DeclResultIdMapper::createStageVars(StageVarDataBundle &stageVarData, stageVarData.semantic = &thisSemantic; } + if (hlsl::IsHLSLNodeType(stageVarData.type)) { + // Hijack the notion of semantic to use createSpirvInterfaceVariable + StringRef str = stageVarData.decl->getName(); + stageVarData.semantic->str = stageVarData.semantic->name = str; + stageVarData.semantic->semantic = hlsl::Semantic::GetArbitrary(); + SpirvVariable *varInstr = createSpirvInterfaceVariable(stageVarData); + if (!varInstr) { + return false; + } + + *value = hlsl::IsHLSLNodeInputType(stageVarData.type) + ? varInstr + : loadShaderInputVariable(varInstr, stageVarData); + return true; + } + if (stageVarData.semantic->isValid() && // Structs with attached semantics will be handled later. !stageVarData.type->isStructureType()) { @@ -4161,6 +4188,8 @@ SpirvVariable *DeclResultIdMapper::getBuiltinVar(spv::BuiltIn builtIn, case spv::BuiltIn::GlobalInvocationId: case spv::BuiltIn::WorkgroupId: case spv::BuiltIn::LocalInvocationIndex: + case spv::BuiltIn::RemainingRecursionLevelsAMDX: + case spv::BuiltIn::ShaderIndexAMDX: sc = spv::StorageClass::Input; break; case spv::BuiltIn::TaskCountNV: @@ -4196,7 +4225,9 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar( const auto type = stageVar->getAstType(); const auto isPrecise = decl->hasAttr(); auto isNointerp = decl->hasAttr(); - spv::StorageClass sc = getStorageClassForSigPoint(sigPoint); + spv::StorageClass sc = hlsl::IsHLSLNodeInputType(stageVar->getAstType()) + ? 
spv::StorageClass::NodePayloadAMDX + : getStorageClassForSigPoint(sigPoint); if (sc == spv::StorageClass::Max) return 0; stageVar->setStorageClass(sc); diff --git a/tools/clang/lib/SPIRV/EmitVisitor.cpp b/tools/clang/lib/SPIRV/EmitVisitor.cpp index f58160254a..3f96063101 100644 --- a/tools/clang/lib/SPIRV/EmitVisitor.cpp +++ b/tools/clang/lib/SPIRV/EmitVisitor.cpp @@ -941,6 +941,73 @@ bool EmitVisitor::visit(SpirvBarrier *inst) { curInst.push_back(memoryScopeId); curInst.push_back(memorySemanticsId); finalizeInstruction(&mainBinary); + emitDebugNameForInstruction(getOrAssignResultId(inst), + inst->getDebugName()); + return true; +} + +bool EmitVisitor::visit(SpirvIsNodePayloadValid *inst) { + initInstruction(inst); + curInst.push_back(inst->getResultTypeId()); + curInst.push_back(getOrAssignResultId(inst)); + curInst.push_back( + getOrAssignResultId(inst->getPayloadArray())); + curInst.push_back( + getOrAssignResultId(inst->getNodeIndex())); + finalizeInstruction(&mainBinary); + emitDebugNameForInstruction(getOrAssignResultId(inst), + inst->getDebugName()); + return true; +} + +bool EmitVisitor::visit(SpirvNodePayloadArrayLength *inst) { + initInstruction(inst); + curInst.push_back(inst->getResultTypeId()); + curInst.push_back(getOrAssignResultId(inst)); + curInst.push_back( + getOrAssignResultId(inst->getPayloadArray())); + finalizeInstruction(&mainBinary); + emitDebugNameForInstruction(getOrAssignResultId(inst), + inst->getDebugName()); + return true; +} + +bool EmitVisitor::visit(SpirvAllocateNodePayloads *inst) { + const uint32_t allocationScopeId = typeHandler.getOrCreateConstantInt( + llvm::APInt(32, static_cast(inst->getAllocationScope())), + context.getUIntType(32), /*isSpecConst */ false); + + initInstruction(inst); + curInst.push_back(inst->getResultTypeId()); + curInst.push_back(getOrAssignResultId(inst)); + curInst.push_back(allocationScopeId); + curInst.push_back( + getOrAssignResultId(inst->getRecordCount())); + curInst.push_back( + 
getOrAssignResultId(inst->getShaderIndex())); + finalizeInstruction(&mainBinary); + emitDebugNameForInstruction(getOrAssignResultId(inst), + inst->getDebugName()); + return true; +} + +bool EmitVisitor::visit(SpirvEnqueueNodePayloads *inst) { + initInstruction(inst); + curInst.push_back(getOrAssignResultId(inst->getPayload())); + finalizeInstruction(&mainBinary); + emitDebugNameForInstruction(getOrAssignResultId(inst), + inst->getDebugName()); + return true; +} + +bool EmitVisitor::visit(SpirvFinishWritingNodePayload *inst) { + initInstruction(inst); + curInst.push_back(inst->getResultTypeId()); + curInst.push_back(getOrAssignResultId(inst)); + curInst.push_back(getOrAssignResultId(inst->getPayload())); + finalizeInstruction(&mainBinary); + emitDebugNameForInstruction(getOrAssignResultId(inst), + inst->getDebugName()); return true; } @@ -1014,6 +1081,13 @@ bool EmitVisitor::visit(SpirvConstantComposite *inst) { return true; } +bool EmitVisitor::visit(SpirvConstantString *inst) { + typeHandler.getOrCreateConstant(inst); + emitDebugNameForInstruction(getOrAssignResultId(inst), + inst->getDebugName()); + return true; +} + bool EmitVisitor::visit(SpirvConstantNull *inst) { typeHandler.getOrCreateConstant(inst); emitDebugNameForInstruction(getOrAssignResultId(inst), @@ -2081,6 +2155,8 @@ uint32_t EmitTypeHandler::getOrCreateConstant(SpirvConstant *inst) { return getOrCreateConstantNull(constNull); } else if (auto *constBool = dyn_cast(inst)) { return getOrCreateConstantBool(constBool); + } else if (auto *constString = dyn_cast(inst)) { + return getOrCreateConstantString(constString); } else if (auto *constUndef = dyn_cast(inst)) { return getOrCreateUndef(constUndef); } @@ -2119,6 +2195,36 @@ uint32_t EmitTypeHandler::getOrCreateConstantBool(SpirvConstantBoolean *inst) { return inst->getResultId(); } +uint32_t EmitTypeHandler::getOrCreateConstantString(SpirvConstantString *inst) { + const StringRef str = inst->getString(); + const bool isSpecConst = 
inst->isSpecConstant(); + + if (!isSpecConst && + emittedConstantStrings.find(str) != emittedConstantStrings.end()) { + // Already emitted this constant value. Reuse. + inst->setResultId(emittedConstantStrings[str]->getResultId()); + } else if (isSpecConst && emittedSpecConstantInstructions.find(inst) != + emittedSpecConstantInstructions.end()) { + // We've already emitted this SpecConstant. Reuse. + return inst->getResultId(); + } else { + // Constant wasn't emitted in the past. + const auto &words = string::encodeSPIRVString(inst->getString()); + initTypeInstruction(inst->getopcode()); + curTypeInst.push_back(getOrAssignResultId(inst)); + curTypeInst.insert(curTypeInst.end(), words.begin(), words.end()); + finalizeTypeInstruction(); + // Remember this constant for the future (if not a spec constant) + if (isSpecConst) { + emittedSpecConstantInstructions.insert(inst); + } else { + emittedConstantStrings[str] = inst; + } + } + + return inst->getResultId(); +} + uint32_t EmitTypeHandler::getOrCreateConstantNull(SpirvConstantNull *inst) { auto found = std::find_if(emittedConstantNulls.begin(), emittedConstantNulls.end(), @@ -2539,6 +2645,84 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) { if (stride.hasValue()) emitDecoration(id, spv::Decoration::ArrayStride, {stride.getValue()}); } + // NodePayloadArray types + else if (const auto *npaType = dyn_cast(type)) { + const uint32_t elemTypeId = emitType(npaType->getElementType()); + initTypeInstruction(spv::Op::OpTypeNodePayloadArrayAMDX); + curTypeInst.push_back(id); + curTypeInst.push_back(elemTypeId); + finalizeTypeInstruction(); + + // Emit decorations + const ParmVarDecl *nodeDecl = npaType->getNodeDecl(); + if (hlsl::IsHLSLNodeOutputType(nodeDecl->getType())) { + StringRef name = nodeDecl->getName(); + unsigned index = 0; + if (auto nodeID = nodeDecl->getAttr()) { + name = nodeID->getName(); + index = nodeID->getArrayIndex(); + } + + auto *str = new (context) SpirvConstantString(name); + uint32_t 
nodeName = getOrCreateConstantString(str); + emitDecoration(id, spv::Decoration::PayloadNodeNameAMDX, {nodeName}, + llvm::None, true); + if (index) { + uint32_t baseIndex = getOrCreateConstantInt( + llvm::APInt(32, index), context.getUIntType(32), false); + emitDecoration(id, spv::Decoration::PayloadNodeBaseIndexAMDX, + {baseIndex}, llvm::None, true); + } + } + + uint32_t maxRecords; + if (const auto *attr = nodeDecl->getAttr()) { + maxRecords = getOrCreateConstantInt(llvm::APInt(32, attr->getMaxCount()), + context.getUIntType(32), false); + } else { + maxRecords = getOrCreateConstantInt(llvm::APInt(32, 1), + context.getUIntType(32), false); + } + emitDecoration(id, spv::Decoration::NodeMaxPayloadsAMDX, {maxRecords}, + llvm::None, true); + + if (const auto *attr = nodeDecl->getAttr()) { + const DeclContext *dc = nodeDecl->getParentFunctionOrMethod(); + if (const auto *funDecl = dyn_cast_or_null(dc)) { + IdentifierInfo *ii = attr->getName(); + bool alreadyExists = false; + for (auto *paramDecl : funDecl->params()) { + if (paramDecl->getIdentifier() == ii) { + assert(paramDecl != nodeDecl); + auto otherType = context.getNodeDeclPayloadType(paramDecl); + const uint32_t otherId = + getResultIdForType(otherType, &alreadyExists); + assert(alreadyExists && "forward references not allowed in " + "MaxRecordsSharedWith attribute"); + emitDecoration(id, spv::Decoration::NodeSharesPayloadLimitsWithAMDX, + {otherId}, llvm::None, true); + break; + } + } + assert(alreadyExists && + "invalid reference in MaxRecordsSharedWith attribute"); + } + } + if (const auto *attr = nodeDecl->getAttr()) { + emitDecoration(id, spv::Decoration::PayloadNodeSparseArrayAMDX, {}, + llvm::None); + } + if (const auto *attr = nodeDecl->getAttr()) { + emitDecoration(id, spv::Decoration::PayloadNodeSparseArrayAMDX, {}, + llvm::None); + } + if (const auto *attr = nodeDecl->getAttr()) { + uint32_t arraySize = getOrCreateConstantInt( + llvm::APInt(32, attr->getCount()), context.getUIntType(32), false); + 
emitDecoration(id, spv::Decoration::PayloadNodeArraySizeAMDX, {arraySize}, + llvm::None, true); + } + } // Structure types else if (const auto *structType = dyn_cast(type)) { std::vector> @@ -2552,6 +2736,15 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) { } } + if (const auto recordDecl = dyn_cast_or_null( + context.getStructDeclForSpirvType(structType))) { + auto index = context.getDispatchGridIndex(recordDecl); + if (index.hasValue()) { + emitDecoration(id, spv::Decoration::PayloadDispatchIndirectAMDX, {}, + index); + } + } + // Emit OpMemberName for the struct members. for (size_t i = 0; i < fieldsToGenerate.size(); ++i) emitNameForType(fieldsToGenerate[i].get().name, id, i); @@ -2614,6 +2807,13 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) { else if (interfaceType == StructInterfaceType::UniformBuffer) emitDecoration(id, spv::Decoration::Block, {}); + // Emit TrackFinishWritingAMDX if NodeTrackRWInputSharing is present. + const auto *structDecl = dyn_cast_or_null( + context.getStructDeclForSpirvType(structType)); + if (structDecl && structDecl->hasAttr()) { + emitDecoration(id, spv::Decoration::TrackFinishWritingAMDX, {}); + } + initTypeInstruction(spv::Op::OpTypeStruct); curTypeInst.push_back(id); for (auto fieldTypeId : fieldTypeIds) @@ -2756,14 +2956,17 @@ void EmitTypeHandler::emitLiteral(const SpirvConstant *literal, void EmitTypeHandler::emitDecoration(uint32_t typeResultId, spv::Decoration decoration, llvm::ArrayRef decorationParams, - llvm::Optional memberIndex) { - + llvm::Optional memberIndex, + bool usesIdParams) { spv::Op op = memberIndex.hasValue() ? spv::Op::OpMemberDecorate : spv::Op::OpDecorate; if (decoration == spv::Decoration::UserTypeGOOGLE) { op = memberIndex.hasValue() ? 
spv::Op::OpMemberDecorateString : spv::Op::OpDecorateString; } + if (usesIdParams) { + op = spv::Op::OpDecorateId; + } assert(curDecorationInst.empty()); curDecorationInst.push_back(static_cast(op)); diff --git a/tools/clang/lib/SPIRV/EmitVisitor.h b/tools/clang/lib/SPIRV/EmitVisitor.h index bfa0710998..7b1ac3926d 100644 --- a/tools/clang/lib/SPIRV/EmitVisitor.h +++ b/tools/clang/lib/SPIRV/EmitVisitor.h @@ -86,7 +86,8 @@ class EmitTypeHandler { // parameters. void emitDecoration(uint32_t typeResultId, spv::Decoration, llvm::ArrayRef decorationParams, - llvm::Optional memberIndex = llvm::None); + llvm::Optional memberIndex = llvm::None, + bool usesIdParams = false); uint32_t getOrCreateConstant(SpirvConstant *); @@ -113,6 +114,7 @@ class EmitTypeHandler { uint32_t getOrCreateConstantNull(SpirvConstantNull *); uint32_t getOrCreateUndef(SpirvUndef *); uint32_t getOrCreateConstantBool(SpirvConstantBoolean *); + uint32_t getOrCreateConstantString(SpirvConstantString *); template void emitLiteral(const SpirvConstant *, vecType &outInst); template @@ -176,6 +178,7 @@ class EmitTypeHandler { emittedConstantInts; llvm::DenseMap, uint32_t> emittedConstantFloats; + llvm::DenseMap emittedConstantStrings; llvm::SmallVector emittedConstantComposites; llvm::SmallVector emittedConstantNulls; llvm::SmallVector emittedUndef; @@ -251,6 +254,11 @@ class EmitVisitor : public Visitor { bool visit(SpirvAccessChain *) override; bool visit(SpirvAtomic *) override; bool visit(SpirvBarrier *) override; + bool visit(SpirvIsNodePayloadValid *inst) override; + bool visit(SpirvNodePayloadArrayLength *inst) override; + bool visit(SpirvAllocateNodePayloads *inst) override; + bool visit(SpirvEnqueueNodePayloads *inst) override; + bool visit(SpirvFinishWritingNodePayload *inst) override; bool visit(SpirvBinaryOp *) override; bool visit(SpirvBitFieldExtract *) override; bool visit(SpirvBitFieldInsert *) override; @@ -258,6 +266,7 @@ class EmitVisitor : public Visitor { bool visit(SpirvConstantInteger 
*) override; bool visit(SpirvConstantFloat *) override; bool visit(SpirvConstantComposite *) override; + bool visit(SpirvConstantString *) override; bool visit(SpirvConstantNull *) override; bool visit(SpirvConvertPtrToU *) override; bool visit(SpirvConvertUToPtr *) override; @@ -458,6 +467,10 @@ class EmitVisitor : public Visitor { std::vector mainBinary; // String literals to SpirvString objects llvm::StringMap stringIdMap; + // String literals to SpirvConstantString objects + llvm::StringMap stringConstantIdMap; + // String spec constants + llvm::DenseSet stringSpecConstantInstructions; // Main file information for debugging that will be used by OpLine. uint32_t debugMainFileId; // Id for Vulkan DebugInfo extended instruction set. Used when generating diff --git a/tools/clang/lib/SPIRV/FeatureManager.cpp b/tools/clang/lib/SPIRV/FeatureManager.cpp index 7fb449fee9..b6aed4d8b6 100644 --- a/tools/clang/lib/SPIRV/FeatureManager.cpp +++ b/tools/clang/lib/SPIRV/FeatureManager.cpp @@ -214,6 +214,7 @@ Extension FeatureManager::getExtensionSymbol(llvm::StringRef name) { .Case("SPV_EXT_shader_image_int64", Extension::EXT_shader_image_int64) .Case("SPV_KHR_physical_storage_buffer", Extension::KHR_physical_storage_buffer) + .Case("SPV_AMDX_shader_enqueue", Extension::AMD_shader_enqueue) .Case("SPV_KHR_vulkan_memory_model", Extension::KHR_vulkan_memory_model) .Case("SPV_KHR_compute_shader_derivatives", Extension::KHR_compute_shader_derivatives) @@ -284,6 +285,8 @@ const char *FeatureManager::getExtensionName(Extension symbol) { return "SPV_EXT_shader_image_int64"; case Extension::KHR_physical_storage_buffer: return "SPV_KHR_physical_storage_buffer"; + case Extension::AMD_shader_enqueue: + return "SPV_AMDX_shader_enqueue"; case Extension::KHR_vulkan_memory_model: return "SPV_KHR_vulkan_memory_model"; case Extension::KHR_compute_shader_derivatives: diff --git a/tools/clang/lib/SPIRV/GlPerVertex.cpp b/tools/clang/lib/SPIRV/GlPerVertex.cpp index 09b09236b4..aa5a40d008 100644 --- 
a/tools/clang/lib/SPIRV/GlPerVertex.cpp +++ b/tools/clang/lib/SPIRV/GlPerVertex.cpp @@ -324,6 +324,9 @@ bool GlPerVertex::setClipCullDistanceType(SemanticIndexToTypeMap *typeMap, bool GlPerVertex::doGlPerVertexFacts(const NamedDecl *decl, QualType baseType, bool asInput) { + if (hlsl::IsHLSLNodeType(baseType)) { + return true; + } llvm::StringRef semanticStr; const hlsl::Semantic *semantic = {}; diff --git a/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp b/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp index b31d19b5d8..a415359f1a 100644 --- a/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp +++ b/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp @@ -365,6 +365,16 @@ const SpirvType *LowerTypeVisitor::lowerType(const SpirvType *type, return raType; return spvContext.getRuntimeArrayType(loweredElemType, raType->getStride()); } + // Node payload arrays could contain a hybrid type + else if (const auto *npaType = dyn_cast(type)) { + const auto *loweredElemType = + lowerType(npaType->getElementType(), rule, loc); + // If the node payload array didn't contain any hybrid types, return itself. + if (npaType->getElementType() == loweredElemType) + return npaType; + return spvContext.getNodePayloadArrayType(loweredElemType, + npaType->getNodeDecl()); + } // Pointer types could point to a hybrid type. 
else if (const auto *ptrType = dyn_cast(type)) { const auto *loweredPointee = diff --git a/tools/clang/lib/SPIRV/PreciseVisitor.cpp b/tools/clang/lib/SPIRV/PreciseVisitor.cpp index 34e6087990..f1869318a4 100644 --- a/tools/clang/lib/SPIRV/PreciseVisitor.cpp +++ b/tools/clang/lib/SPIRV/PreciseVisitor.cpp @@ -60,6 +60,9 @@ bool isAccessingPrecise(clang::spirv::SpirvAccessChain *inst) { } else if (auto *raType = llvm::dyn_cast(baseType)) { indexes.pop(); baseType = raType->getElementType(); + } else if (auto *npaType = llvm::dyn_cast(baseType)) { + indexes.pop(); + baseType = npaType->getElementType(); } else if (auto *structType = llvm::dyn_cast(baseType)) { SpirvInstruction *index = indexes.top(); if (auto *constInt = llvm::dyn_cast(index)) { diff --git a/tools/clang/lib/SPIRV/SpirvBuilder.cpp b/tools/clang/lib/SPIRV/SpirvBuilder.cpp index 689fc0715f..b6bd4c0268 100644 --- a/tools/clang/lib/SPIRV/SpirvBuilder.cpp +++ b/tools/clang/lib/SPIRV/SpirvBuilder.cpp @@ -84,7 +84,9 @@ SpirvBuilder::addFnParam(QualType ptrType, bool isPrecise, bool isNointerp, param = new (context) SpirvFunctionParameter(ptrType, isPrecise, isNointerp, loc); } - param->setStorageClass(spv::StorageClass::Function); + param->setStorageClass(hlsl::IsHLSLNodeInputType(ptrType) + ? 
spv::StorageClass::NodePayloadAMDX + : spv::StorageClass::Function); param->setDebugName(name); function->addParameter(param); return param; @@ -233,6 +235,13 @@ SpirvInstruction *SpirvBuilder::createLoad(QualType resultType, createEndInvocationInterlockEXT(loc, range); } + if (context.hasLoweredType(pointer)) { + // preserve distinct node payload array types + auto *ptrType = dyn_cast(pointer->getResultType()); + instruction->setResultType(ptrType->getPointeeType()); + context.addToInstructionsWithLoweredType(instruction); + } + const auto &bitfieldInfo = pointer->getBitfieldInfo(); if (!bitfieldInfo.hasValue()) return instruction; @@ -309,6 +318,12 @@ SpirvStore *SpirvBuilder::createStore(SpirvInstruction *address, auto *instruction = new (context) SpirvStore(loc, address, source, llvm::None, range); + if (context.hasLoweredType(source)) { + // preserve distinct node payload array types + address->setResultType(context.getPointerType(source->getResultType(), + address->getStorageClass())); + context.addToInstructionsWithLoweredType(address); + } insertPoint->addInstruction(instruction); if (address->getStorageClass() == spv::StorageClass::PhysicalStorageBuffer && @@ -875,6 +890,53 @@ SpirvInstruction *SpirvBuilder::createNonSemanticDebugPrintfExtInst( return extInst; } +SpirvInstruction * +SpirvBuilder::createIsNodePayloadValid(SpirvInstruction *payloadArray, + SpirvInstruction *nodeIndex, + SourceLocation loc) { + auto *inst = new (context) + SpirvIsNodePayloadValid(astContext.BoolTy, loc, payloadArray, nodeIndex); + insertPoint->addInstruction(inst); + return inst; +} + +SpirvInstruction * +SpirvBuilder::createNodePayloadArrayLength(SpirvInstruction *payloadArray, + SourceLocation loc) { + auto *inst = new (context) + SpirvNodePayloadArrayLength(astContext.UnsignedIntTy, loc, payloadArray); + insertPoint->addInstruction(inst); + return inst; +} + +SpirvInstruction *SpirvBuilder::createAllocateNodePayloads( + QualType resultType, spv::Scope allocationScope, + 
SpirvInstruction *shaderIndex, SpirvInstruction *recordCount, + SourceLocation loc) { + assert(insertPoint && "null insert point"); + auto *inst = new (context) SpirvAllocateNodePayloads( + resultType, loc, allocationScope, shaderIndex, recordCount); + insertPoint->addInstruction(inst); + return inst; +} + +void SpirvBuilder::createEnqueueOutputNodePayloads(SpirvInstruction *payload, + SourceLocation loc) { + assert(insertPoint && "null insert point"); + auto *inst = new (context) SpirvEnqueueNodePayloads(loc, payload); + insertPoint->addInstruction(inst); +} + +SpirvInstruction * +SpirvBuilder::createFinishWritingNodePayload(SpirvInstruction *payload, + SourceLocation loc) { + assert(insertPoint && "null insert point"); + auto *inst = new (context) + SpirvFinishWritingNodePayload(astContext.BoolTy, loc, payload); + insertPoint->addInstruction(inst); + return inst; +} + void SpirvBuilder::createBarrier(spv::Scope memoryScope, spv::MemorySemanticsMask memorySemantics, llvm::Optional exec, @@ -1869,6 +1931,14 @@ SpirvConstant *SpirvBuilder::getConstantNull(QualType type) { return nullConst; } +SpirvConstant *SpirvBuilder::getConstantString(llvm::StringRef str, + bool specConst) { + // We do not care about making unique constants at this point. + auto *stringConst = new (context) SpirvConstantString(str, specConst); + mod->addConstant(stringConst); + return stringConst; +} + SpirvUndef *SpirvBuilder::getUndef(QualType type) { // We do not care about making unique constants at this point. 
auto *undef = new (context) SpirvUndef(type); diff --git a/tools/clang/lib/SPIRV/SpirvContext.cpp b/tools/clang/lib/SPIRV/SpirvContext.cpp index 47dfc67433..61b2f1749f 100644 --- a/tools/clang/lib/SPIRV/SpirvContext.cpp +++ b/tools/clang/lib/SPIRV/SpirvContext.cpp @@ -65,6 +65,9 @@ SpirvContext::~SpirvContext() { for (auto *raType : runtimeArrayTypes) raType->~RuntimeArrayType(); + for (auto *npaType : nodePayloadArrayTypes) + npaType->~NodePayloadArrayType(); + for (auto *fnType : functionTypes) fnType->~FunctionType(); @@ -276,6 +279,19 @@ SpirvContext::getRuntimeArrayType(const SpirvType *elemType, return *(inserted.first); } +const NodePayloadArrayType * +SpirvContext::getNodePayloadArrayType(const SpirvType *elemType, + const ParmVarDecl *nodeDecl) { + NodePayloadArrayType type(elemType, nodeDecl); + auto found = nodePayloadArrayTypes.find(&type); + if (found != nodePayloadArrayTypes.end()) + return *found; + + auto inserted = nodePayloadArrayTypes.insert( + new (this) NodePayloadArrayType(elemType, nodeDecl)); + return *(inserted.first); +} + const StructType * SpirvContext::getStructType(llvm::ArrayRef fields, llvm::StringRef name, bool isReadOnly, diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index 9ffa978511..2b8ab0baae 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -268,7 +268,8 @@ inline bool canActAsInParmVar(const ParmVarDecl *param) { return !param->hasAttr() && // GS output streams are marked as inout, but it should not be // used as in parameter. 
- !hlsl::IsHLSLStreamOutputType(param->getType()); + !hlsl::IsHLSLStreamOutputType(param->getType()) && + !hlsl::IsHLSLNodeOutputType(param->getType()); } /// Returns true if the given function parameter can act as shader stage @@ -1415,6 +1416,83 @@ SpirvInstruction *SpirvEmitter::castToType(SpirvInstruction *value, return nullptr; } +static bool handleDispatchGrid(SpirvContext &spvContext, + const RecordDecl *recordDecl) { + unsigned index = 0; + for (auto fieldDecl : recordDecl->fields()) { + QualType fieldType = fieldDecl->getType(); + for (const hlsl::UnusualAnnotation *it : + fieldDecl->getUnusualAnnotations()) { + if (it->getKind() == hlsl::UnusualAnnotation::UA_SemanticDecl) { + const hlsl::SemanticDecl *sd = cast(it); + if (sd->SemanticName.equals("SV_DispatchGrid")) { + spvContext.registerDispatchGridIndex(recordDecl, index); + return true; + } + } + } + if (const auto *innerType = fieldType->getAs()) { + if (handleDispatchGrid(spvContext, innerType->getDecl())) + return true; + } + ++index; + } + return false; +} + +bool SpirvEmitter::handleNodePayloadArrayType(const ParmVarDecl *decl, + SpirvInstruction *instr) { + // Because SPIR-V node payload array types are node-specific, propagate + // lowered types + switch (instr->getKind()) { + case SpirvInstruction::Kind::IK_Load: { + SpirvInstruction *ptr = dyn_cast(instr)->getPointer(); + if (handleNodePayloadArrayType(decl, ptr)) { + const SpirvPointerType *ptrType = + dyn_cast(ptr->getResultType()); + instr->setResultType(ptrType->getPointeeType()); + spvContext.addToInstructionsWithLoweredType(instr); + return true; + } + return false; + } + case SpirvInstruction::Kind::IK_FunctionParameter: + case SpirvInstruction::Kind::IK_Variable: { + QualType varType = decl->getType(); + if (hlsl::IsHLSLNodeType(varType)) { + if (auto *type = spvContext.getNodeDeclPayloadType(decl)) { + instr->setResultType( + spvContext.getPointerType(type, instr->getStorageClass())); + } else { + LowerTypeVisitor 
lowerTypeVisitor(astContext, spvContext, spirvOptions, + spvBuilder); + QualType resultType = + hlsl::GetHLSLNodeIOResultType(astContext, varType); + const auto *recordType = resultType->getAs(); + assert(recordType); + if (hlsl::IsHLSLDispatchNodeInputRecordType(varType)) { + handleDispatchGrid(spvContext, recordType->getDecl()); + } + const SpirvType *elemType = lowerTypeVisitor.lowerType( + resultType, clang::spirv::SpirvLayoutRule::Scalar, llvm::None, + decl->getLocation()); + const NodePayloadArrayType *arrType = + spvContext.getNodePayloadArrayType(elemType, decl); + const SpirvType *ptrType = + spvContext.getPointerType(arrType, instr->getStorageClass()); + instr->setResultType(ptrType); + spvContext.registerNodeDeclPayloadType(arrType, decl); + } + spvContext.addToInstructionsWithLoweredType(instr); + return true; + } + return false; + } + default: + return false; + } +} + void SpirvEmitter::doFunctionDecl(const FunctionDecl *decl) { // Forward declaration of a function inside another. if (!decl->isThisDeclarationADefinition()) { @@ -1556,6 +1634,9 @@ void SpirvEmitter::doFunctionDecl(const FunctionDecl *decl) { QualType paramType = paramDecl->getType(); auto *param = declIdMapper.createFnParam(paramDecl, i + 1 + isNonStaticMemberFn); + if (isEntry) { + handleNodePayloadArrayType(paramDecl, param); + } #ifdef ENABLE_SPIRV_CODEGEN if (hlsl::IsVKBufferPointerType(paramType)) { Optional isRowMajor = llvm::None; @@ -4309,7 +4390,7 @@ SpirvEmitter::processTextureLevelOfDetail(const CXXMemberCallExpr *expr, spvBuilder.createImageQuery(spv::Op::OpImageQueryLod, queryResultType, expr->getExprLoc(), sampledImage, coordinate); - if (spvContext.isCS()) { + if (spvContext.isCS() || spvContext.isNode()) { addDerivativeGroupExecutionMode(); } // The first component of the float2 contains the mipmap array layer. 
@@ -5312,6 +5393,9 @@ SpirvEmitter::doCXXMemberCallExpr(const CXXMemberCallExpr *expr) { uint32_t opcode = static_cast(hlsl::IntrinsicOp::Num_Intrinsics); if (hlsl::GetIntrinsicOp(callee, opcode, group)) { + if (group == "subscript") { + return processIntrinsicExtractRecordStruct(expr); + } return processIntrinsicMemberCall(expr, static_cast(opcode)); } @@ -5508,6 +5592,28 @@ SpirvEmitter::processIntrinsicMemberCall(const CXXMemberCallExpr *expr, return processRayQueryIntrinsics(expr, opcode); case IntrinsicOp::MOP_GetBufferContents: return processIntrinsicGetBufferContents(expr); + case hlsl::IntrinsicOp::MOP_GetThreadNodeOutputRecords: + return processIntrinsicGetNodeOutputRecords(expr, false); + case hlsl::IntrinsicOp::MOP_GetGroupNodeOutputRecords: + return processIntrinsicGetNodeOutputRecords(expr, true); + case hlsl::IntrinsicOp::MOP_ThreadIncrementOutputCount: + retVal = processIntrinsicIncrementOutputCount(expr, false); + break; + case hlsl::IntrinsicOp::MOP_GroupIncrementOutputCount: + retVal = processIntrinsicIncrementOutputCount(expr, true); + break; + case hlsl::IntrinsicOp::MOP_IsValid: + retVal = processIntrinsicIsValid(expr); + break; + case hlsl::IntrinsicOp::MOP_Count: + retVal = processIntrinsicGetRecordCount(expr); + break; + case hlsl::IntrinsicOp::MOP_OutputComplete: + processIntrinsicOutputComplete(expr); + break; + case hlsl::IntrinsicOp::MOP_FinishedCrossGroupSharing: + retVal = processIntrinsicFinishedCrossGroupSharing(expr); + break; default: emitError("intrinsic '%0' method unimplemented", expr->getCallee()->getExprLoc()) @@ -5559,7 +5665,8 @@ SpirvInstruction *SpirvEmitter::createImageSample( const bool isExplicit = lod || (grad.first && grad.second); // Implicit-lod instructions are only allowed in pixel and compute shaders. 
- if (!spvContext.isPS() && !spvContext.isCS() && !isExplicit) + if (!spvContext.isPS() && !spvContext.isCS() && !spvContext.isNode() && + !isExplicit) emitError("sampling with implicit lod is only allowed in fragment and " "compute shaders", loc); @@ -5664,7 +5771,7 @@ SpirvEmitter::processTextureSampleGather(const CXXMemberCallExpr *expr, const auto retType = expr->getDirectCallee()->getReturnType(); if (isSample) { - if (spvContext.isCS()) { + if (spvContext.isCS() || spvContext.isNode()) { addDerivativeGroupExecutionMode(); } return createImageSample(retType, imageType, image, sampler, coordinate, @@ -5754,7 +5861,7 @@ SpirvEmitter::processTextureSampleBiasLevel(const CXXMemberCallExpr *expr, const auto retType = expr->getDirectCallee()->getReturnType(); - if (!lod && spvContext.isCS()) { + if (!lod && (spvContext.isCS() || spvContext.isNode())) { addDerivativeGroupExecutionMode(); } return createImageSample( @@ -8641,9 +8748,10 @@ const Expr *SpirvEmitter::collectArrayStructIndices( } { - // Indexing into ConstantBuffers and TextureBuffers involves an additional - // FlatConversion node which casts the handle to the underlying structure - // type. We can look past the FlatConversion to continue to collect indices. + // Indexing into ConstantBuffers, TextureBuffers, and node input/output + // types involves an additional FlatConversion node which casts the handle + // to the underlying structure type. We can look past the FlatConversion to + // continue to collect indices. 
// For example: MyConstantBufferArray[0].structMember1 // `-MemberExpr .structMember1 // `-ImplicitCastExpr 'const T' lvalue @@ -8652,7 +8760,8 @@ const Expr *SpirvEmitter::collectArrayStructIndices( if (castExpr->getCastKind() == CK_FlatConversion) { const auto *subExpr = castExpr->getSubExpr(); const QualType subExprType = subExpr->getType(); - if (isConstantTextureBuffer(subExprType)) { + if (isConstantTextureBuffer(subExprType) || + hlsl::IsHLSLNodeType(subExprType)) { return collectArrayStructIndices(subExpr, rawIndex, rawIndices, indices, isMSOutAttribute); } @@ -9054,6 +9163,9 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) { case hlsl::IntrinsicOp::IOP_udot: retVal = processIntrinsicDot(callExpr); break; + case hlsl::IntrinsicOp::IOP_Barrier: + retVal = processIntrinsicBarrier(callExpr); + break; case hlsl::IntrinsicOp::IOP_GroupMemoryBarrier: retVal = processIntrinsicMemoryBarrier(callExpr, /*isDevice*/ false, @@ -9086,6 +9198,9 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) { /*groupSync*/ true, /*isAllBarrier*/ true); break; + case hlsl::IntrinsicOp::IOP_GetRemainingRecursionLevels: + retVal = processIntrinsicGetRemainingRecursionLevels(callExpr); + break; case hlsl::IntrinsicOp::IOP_CheckAccessFullyMapped: retVal = spvBuilder.createImageSparseTexelsResident( doExpr(callExpr->getArg(0)), srcLoc, srcRange); @@ -9562,6 +9677,15 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) { return retVal; } +SpirvInstruction *SpirvEmitter::processIntrinsicGetRecordCount( + const CXXMemberCallExpr *callExpr) { + assert(callExpr->getNumArgs() == 0); + const auto obj = callExpr->getImplicitObjectArgument(); + const auto loc = callExpr->getExprLoc(); + SpirvInstruction *payload = doExpr(obj); + return spvBuilder.createNodePayloadArrayLength(payload, loc); +} + SpirvInstruction * SpirvEmitter::processIntrinsicFirstbit(const CallExpr *callExpr, GLSLstd450 glslOpcode) { @@ -11042,6 +11166,194 @@ SpirvInstruction 
*SpirvEmitter::processIntrinsicGetBufferContents( return retVal; } +SpirvInstruction *SpirvEmitter::processIntrinsicExtractRecordStruct( + const CXXMemberCallExpr *callExpr) { + Expr *obj = callExpr->getImplicitObjectArgument(); + QualType objType = obj->getType(); + unsigned n = callExpr->getNumArgs(); + assert(hlsl::IsHLSLNodeType(objType)); + assert(n == 0 || n == 1 && hlsl::IsHLSLNodeRecordArrayType(objType)); + + QualType recordType = hlsl::GetHLSLNodeIOResultType(astContext, objType); + SpirvInstruction *res = doExpr(obj); + SpirvInstruction *index = + n ? doExpr(callExpr->getArg(0)) + : spvBuilder.getConstantInt(astContext.UnsignedIntTy, + llvm::APInt(32, 0)); + res->setLayoutRule(SpirvLayoutRule::Scalar); + + return spvBuilder.createAccessChain(recordType, res, {index}, + callExpr->getExprLoc(), + callExpr->getSourceRange()); +} + +SpirvInstruction *SpirvEmitter::processIntrinsicGetRemainingRecursionLevels( + const CallExpr *callExpr) { + assert(callExpr->getNumArgs() == 0); + const auto loc = callExpr->getExprLoc(); + const QualType retType = callExpr->getCallReturnType(astContext); + auto *var = declIdMapper.getBuiltinVar( + spv::BuiltIn::RemainingRecursionLevelsAMDX, retType, loc); + return spvBuilder.createLoad(retType, var, loc); +} + +SpirvInstruction * +SpirvEmitter::processIntrinsicIsValid(const CXXMemberCallExpr *callExpr) { + assert(callExpr->getNumArgs() == 0); + const auto loc = callExpr->getExprLoc(); + const Expr *nodeOutputExpr = callExpr->getImplicitObjectArgument(); + Expr *baseExpr = const_cast(nodeOutputExpr); + SpirvInstruction *shaderIndex = nullptr; + + if (const auto subExpr = dyn_cast_or_null( + nodeOutputExpr->IgnoreParenNoopCasts(astContext))) { + if (subExpr->getOperator() == OverloadedOperatorKind::OO_Subscript) { + // special case: offset shader index by the array subscript + shaderIndex = doExpr(subExpr->getArg(1)); + baseExpr = const_cast(subExpr->getArg(0)); + } + } + + const auto *declRefExpr = 
dyn_cast(baseExpr->IgnoreImpCasts()); + const auto *paramDecl = dyn_cast(declRefExpr->getDecl()); + int nodeIndex = 0; + if (HLSLNodeIdAttr *nodeId = paramDecl->getAttr()) { + nodeIndex = nodeId->getArrayIndex(); + } + + SpirvInstruction *payload = doExpr(baseExpr); + if (!shaderIndex) { + shaderIndex = spvBuilder.getConstantInt(astContext.UnsignedIntTy, + llvm::APInt(32, nodeIndex)); + } + + return spvBuilder.createIsNodePayloadValid(payload, shaderIndex, loc); +} + +SpirvInstruction *SpirvEmitter::processIntrinsicGetNodeOutputRecords( + const CXXMemberCallExpr *callExpr, bool isGroupShared) { + assert(callExpr->getNumArgs() == 1); + const auto loc = callExpr->getExprLoc(); + const Expr *nodeOutputExpr = callExpr->getImplicitObjectArgument(); + Expr *baseExpr = const_cast(nodeOutputExpr); + SpirvInstruction *shaderIndex = nullptr; + + if (const auto subExpr = dyn_cast_or_null( + nodeOutputExpr->IgnoreParenNoopCasts(astContext))) { + if (subExpr->getOperator() == OverloadedOperatorKind::OO_Subscript) { + // special case: offset shader index by the array subscript + shaderIndex = doExpr(subExpr->getArg(1)); + baseExpr = const_cast(subExpr->getArg(0)); + } + } + + const auto *declRefExpr = dyn_cast(baseExpr->IgnoreImpCasts()); + const auto *paramDecl = dyn_cast(declRefExpr->getDecl()); + const auto *nodeID = paramDecl->getAttr(); + StringRef nodeName = paramDecl->getName(); + unsigned nodeIndex = 0; + if (nodeID) { + nodeName = nodeID->getName(); + nodeIndex = nodeID->getArrayIndex(); + } + + if (!shaderIndex) { + shaderIndex = spvBuilder.getConstantInt(astContext.UnsignedIntTy, + llvm::APInt(32, nodeIndex)); + } + + LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions, + spvBuilder); + const SpirvType *elemType = lowerTypeVisitor.lowerType( + hlsl::GetHLSLNodeIOResultType(astContext, baseExpr->getType()), + clang::spirv::SpirvLayoutRule::Scalar, llvm::None, + paramDecl->getLocation()); + const SpirvType *payloadType = spvContext.getPointerType( + 
spvContext.getNodePayloadArrayType(elemType, paramDecl), + spv::StorageClass::NodePayloadAMDX); + + spv::Scope scope = + isGroupShared ? spv::Scope::Workgroup : spv::Scope::Invocation; + SpirvInstruction *recordCount = doExpr(callExpr->getArg(0)); + SpirvInstruction *result = spvBuilder.createAllocateNodePayloads( + callExpr->getType(), scope, shaderIndex, recordCount, loc); + result->setResultType(payloadType); + spvContext.addToInstructionsWithLoweredType(result); + return result; +} + +SpirvInstruction *SpirvEmitter::processIntrinsicIncrementOutputCount( + const CXXMemberCallExpr *callExpr, bool isGroupShared) { + return processIntrinsicGetNodeOutputRecords(callExpr, isGroupShared); +} + +void SpirvEmitter::processIntrinsicOutputComplete( + const CXXMemberCallExpr *callExpr) { + Expr *payloadExpr = + callExpr->getImplicitObjectArgument()->IgnoreParenNoopCasts(astContext); + SpirvInstruction *payload = doExpr(payloadExpr); + spvBuilder.createEnqueueOutputNodePayloads(payload, callExpr->getExprLoc()); +} + +SpirvInstruction *SpirvEmitter::processIntrinsicFinishedCrossGroupSharing( + const CXXMemberCallExpr *callExpr) { + Expr *payloadExpr = callExpr->getImplicitObjectArgument(); + SpirvInstruction *payload = doExpr(payloadExpr); + return spvBuilder.createFinishWritingNodePayload(payload, + callExpr->getExprLoc()); +} + +SpirvInstruction * +SpirvEmitter::processIntrinsicBarrier(const CallExpr *callExpr) { + llvm::APSInt a1(32, true), a2(32, true); + int64_t i1, i2; + const Expr *e1 = callExpr->getArg(0), *e2 = callExpr->getArg(1); + + // object as first argument + if (!e1->EvaluateAsInt(a1, astContext)) { + assert(e1->getType()->isStructureOrClassType()); + a1.setAllBits(); + } + + if (e2->EvaluateAsInt(a2, astContext) && (i1 = a1.getExtValue()) >= 0 && + (i2 = a2.getExtValue()) >= 0) { + } else { + emitError("Barrier arguments must be non-negative integer constants", + callExpr->getExprLoc()); + return nullptr; + } + + if (!(i1 | i2)) { // all zero -> no-op + 
return nullptr; + } + + spv::Scope memScope = + (i2 & (unsigned)hlsl::DXIL::BarrierSemanticFlag::DeviceScope) + ? spv::Scope::Device + : (i2 & (unsigned)hlsl::DXIL::BarrierSemanticFlag::GroupScope) + ? spv::Scope::Workgroup + : spv::Scope::Invocation; + spv::MemorySemanticsMask memSemaMask = + spv::MemorySemanticsMask::AcquireRelease | + ((i1 & (unsigned)hlsl::DXIL::MemoryTypeFlag::UavMemory) + ? spv::MemorySemanticsMask::UniformMemory + : spv::MemorySemanticsMask::MaskNone) | + ((i1 & (unsigned)hlsl::DXIL::MemoryTypeFlag::GroupSharedMemory) + ? spv::MemorySemanticsMask::WorkgroupMemory + : spv::MemorySemanticsMask::MaskNone) | + ((i1 & (unsigned)hlsl::DXIL::MemoryTypeFlag::NodeOutputMemory) + ? spv::MemorySemanticsMask::OutputMemory + : spv::MemorySemanticsMask::MaskNone); + Optional execScope = + (i2 & (unsigned)hlsl::DXIL::BarrierSemanticFlag::GroupSync) + ? Optional(spv::Scope::Workgroup) + : None; + + spvBuilder.createBarrier(memScope, memSemaMask, execScope, + callExpr->getExprLoc()); + return nullptr; +} + SpirvInstruction * SpirvEmitter::processIntrinsicMemoryBarrier(const CallExpr *callExpr, bool isDevice, bool groupSync, @@ -12179,7 +12491,7 @@ SpirvInstruction *SpirvEmitter::processIntrinsicUsingSpirvInst( case spv::Op::OpFwidth: case spv::Op::OpFwidthFine: case spv::Op::OpFwidthCoarse: - if (spvContext.isCS()) + if (spvContext.isCS() || spvContext.isNode()) addDerivativeGroupExecutionMode(); needsLegalization = true; break; @@ -13264,6 +13576,7 @@ hlsl::ShaderModel::Kind SpirvEmitter::getShaderModelKind(StringRef stageName) { .Case("callable", hlsl::ShaderModel::Kind::Callable) .Case("mesh", hlsl::ShaderModel::Kind::Mesh) .Case("amplification", hlsl::ShaderModel::Kind::Amplification) + .Case("node", hlsl::ShaderModel::Kind::Node) .Default(hlsl::ShaderModel::Kind::Invalid); assert(SMK != hlsl::ShaderModel::Kind::Invalid); return SMK; @@ -13284,6 +13597,7 @@ SpirvEmitter::getSpirvShaderStage(hlsl::ShaderModel::Kind smk, case 
hlsl::ShaderModel::Kind::Pixel: return spv::ExecutionModel::Fragment; case hlsl::ShaderModel::Kind::Compute: + case hlsl::ShaderModel::Kind::Node: return spv::ExecutionModel::GLCompute; case hlsl::ShaderModel::Kind::RayGeneration: return spv::ExecutionModel::RayGenerationNV; @@ -13504,6 +13818,21 @@ void SpirvEmitter::processPixelShaderAttributes(const FunctionDecl *decl) { } } +void SpirvEmitter::checkForWaveSizeAttr(const FunctionDecl *decl) { + if (auto *waveSizeAttr = decl->getAttr()) { + // Not supported in Vulkan SPIR-V, warn and ignore. + + // SPIR-V SubgroupSize execution mode would work but it is Kernel only + // (requires the SubgroupDispatch capability, which implies the + // DeviceEnqueue capability, which is Kernel only). Subgroup sizes can be + // specified in Vulkan on the application side via + // VK_EXT_subgroup_size_control. + emitWarning("Wave size is not supported by Vulkan SPIR-V. Consider using " + "VK_EXT_subgroup_size_control.", + waveSizeAttr->getLocation()); + } +} + void SpirvEmitter::processComputeShaderAttributes(const FunctionDecl *decl) { auto *numThreadsAttr = decl->getAttr(); assert(numThreadsAttr && "thread group size missing from entry-point"); @@ -13515,19 +13844,82 @@ void SpirvEmitter::processComputeShaderAttributes(const FunctionDecl *decl) { spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize, {x, y, z}, decl->getLocation()); - auto *waveSizeAttr = decl->getAttr(); - if (waveSizeAttr) { - // Not supported in Vulkan SPIR-V, warn and ignore. + checkForWaveSizeAttr(decl); +} - // SPIR-V SubgroupSize execution mode would work but it is Kernel only - // (requires the SubgroupDispatch capability, which implies the - // DeviceEnqueue capability, which is Kernel only). Subgroup sizes can be - // specified in Vulkan on the application side via - // VK_EXT_subgroup_size_control. - emitWarning("Wave size is not supported by Vulkan SPIR-V. 
Consider using " - "VK_EXT_subgroup_size_control.", - waveSizeAttr->getLocation()); +void SpirvEmitter::processNodeShaderAttributes(const FunctionDecl *decl) { + uint32_t x = 1, y = 1, z = 1; + if (auto *numThreadsAttr = decl->getAttr()) { + x = static_cast(numThreadsAttr->getX()); + y = static_cast(numThreadsAttr->getY()); + z = static_cast(numThreadsAttr->getZ()); + } + spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize, + {x, y, z}, decl->getLocation()); + + auto *nodeLaunchAttr = decl->getAttr(); + StringRef launchType = nodeLaunchAttr ? nodeLaunchAttr->getLaunchType() : ""; + if (launchType.equals("coalescing") || launchType.equals("thread")) { + spvBuilder.addExecutionMode(entryFunction, + spv::ExecutionMode::CoalescingAMDX, {}, + decl->getLocation()); } + + uint64_t nodeId = 0; + if (const auto nodeIdAttr = decl->getAttr()) + nodeId = static_cast(nodeIdAttr->getArrayIndex()); + spvBuilder.addExecutionModeId( + entryFunction, spv::ExecutionMode::ShaderIndexAMDX, + {spvBuilder.getConstantInt(astContext.UnsignedIntTy, + llvm::APInt(32, nodeId))}, + decl->getLocation()); + + if (const auto *nodeMaxRecursionDepthAttr = + decl->getAttr()) { + SpirvInstruction *count = spvBuilder.getConstantInt( + astContext.UnsignedIntTy, + llvm::APInt(32, nodeMaxRecursionDepthAttr->getCount())); + spvBuilder.addExecutionModeId(entryFunction, + spv::ExecutionMode::MaxNodeRecursionAMDX, + {count}, decl->getLocation()); + } + + if (const auto *nodeShareInputOfAttr = + decl->getAttr()) { + SpirvInstruction *name = + spvBuilder.getConstantString(nodeShareInputOfAttr->getName()); + SpirvInstruction *index = spvBuilder.getConstantInt( + astContext.UnsignedIntTy, + llvm::APInt(32, nodeShareInputOfAttr->getArrayIndex())); + spvBuilder.addExecutionModeId(entryFunction, + spv::ExecutionMode::SharesInputWithAMDX, + {name, index}, decl->getLocation()); + } + + if (const auto *dispatchGrid = decl->getAttr()) { + SpirvInstruction *gridX = spvBuilder.getConstantInt( + 
astContext.UnsignedIntTy, llvm::APInt(32, dispatchGrid->getX())); + SpirvInstruction *gridY = spvBuilder.getConstantInt( + astContext.UnsignedIntTy, llvm::APInt(32, dispatchGrid->getY())); + SpirvInstruction *gridZ = spvBuilder.getConstantInt( + astContext.UnsignedIntTy, llvm::APInt(32, dispatchGrid->getZ())); + spvBuilder.addExecutionModeId(entryFunction, + spv::ExecutionMode::StaticNumWorkgroupsAMDX, + {gridX, gridY, gridZ}, decl->getLocation()); + } else if (const auto *maxDispatchGrid = + decl->getAttr()) { + SpirvInstruction *gridX = spvBuilder.getConstantInt( + astContext.UnsignedIntTy, llvm::APInt(32, maxDispatchGrid->getX())); + SpirvInstruction *gridY = spvBuilder.getConstantInt( + astContext.UnsignedIntTy, llvm::APInt(32, maxDispatchGrid->getY())); + SpirvInstruction *gridZ = spvBuilder.getConstantInt( + astContext.UnsignedIntTy, llvm::APInt(32, maxDispatchGrid->getZ())); + spvBuilder.addExecutionModeId(entryFunction, + spv::ExecutionMode::MaxNumWorkgroupsAMDX, + {gridX, gridY, gridZ}, decl->getLocation()); + } + + checkForWaveSizeAttr(decl); } bool SpirvEmitter::processTessellationShaderAttributes( @@ -13970,6 +14362,8 @@ SpirvFunction *SpirvEmitter::emitEntryFunctionWrapper( processPixelShaderAttributes(decl); } else if (spvContext.isCS()) { processComputeShaderAttributes(decl); + } else if (spvContext.isNode()) { + processNodeShaderAttributes(decl); } else if (spvContext.isHS()) { if (!processTessellationShaderAttributes(decl, &numOutputControlPoints)) return nullptr; @@ -14078,12 +14472,23 @@ SpirvFunction *SpirvEmitter::emitEntryFunctionWrapper( llvm::SmallVector params; for (const auto *param : decl->params()) { const auto paramType = param->getType(); + if (hlsl::IsHLSLNodeInputType(paramType)) { + SpirvInstruction *value = nullptr; + if (!declIdMapper.createStageInputVar(param, &value, false)) + return nullptr; + if (value && value->getKind() == SpirvInstruction::Kind::IK_Variable) { + handleNodePayloadArrayType(param, value); + 
params.push_back(value); + } + continue; + } + std::string tempVarName = "param.var." + param->getNameAsString(); auto *tempVar = spvBuilder.addFnVar(paramType, param->getLocation(), tempVarName, param->hasAttr(), param->hasAttr()); - + handleNodePayloadArrayType(param, tempVar); params.push_back(tempVar); // Create the stage input variable for parameter not marked as pure out and @@ -14101,6 +14506,9 @@ SpirvFunction *SpirvEmitter::emitEntryFunctionWrapper( if (!declIdMapper.createStageInputVar(param, &loadedValue, false)) return nullptr; + if (loadedValue) { + handleNodePayloadArrayType(param, loadedValue); + } // Only initialize the temporary variable if the parameter is indeed used, // or if it is an inout parameter. diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.h b/tools/clang/lib/SPIRV/SpirvEmitter.h index 10694313a8..892b81408d 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.h +++ b/tools/clang/lib/SPIRV/SpirvEmitter.h @@ -126,6 +126,8 @@ class SpirvEmitter : public ASTConsumer { SourceRange range = {}); private: + bool handleNodePayloadArrayType(const ParmVarDecl *decl, + SpirvInstruction *instr); void doFunctionDecl(const FunctionDecl *decl); void doVarDecl(const VarDecl *decl); void doRecordDecl(const RecordDecl *decl); @@ -504,6 +506,9 @@ class SpirvEmitter : public ASTConsumer { SpirvInstruction * processIntrinsicGetBufferContents(const CXXMemberCallExpr *); + /// Processes the 'Barrier' intrinsic function. + SpirvInstruction *processIntrinsicBarrier(const CallExpr *); + /// Processes the 'GroupMemoryBarrier', 'GroupMemoryBarrierWithGroupSync', /// 'DeviceMemoryBarrier', 'DeviceMemoryBarrierWithGroupSync', /// 'AllMemoryBarrier', and 'AllMemoryBarrierWithGroupSync' intrinsic @@ -512,6 +517,40 @@ class SpirvEmitter : public ASTConsumer { bool isDevice, bool groupSync, bool isAllBarrier); + /// Processes the 'GetRemainingRecursionLevels' intrinsic function. 
+ SpirvInstruction * + processIntrinsicGetRemainingRecursionLevels(const CallExpr *callExpr); + + /// Processes the 'IsValid' intrinsic function. + SpirvInstruction *processIntrinsicIsValid(const CXXMemberCallExpr *callExpr); + + /// Processes the 'Get' intrinsic function for (arrays of) node records and + /// the array subscript operator for node record arrays. + SpirvInstruction * + processIntrinsicExtractRecordStruct(const CXXMemberCallExpr *callExpr); + + /// Processes the 'GetGroupNodeOutputRecords' and 'GetThreadNodeOutputRecords' + /// intrinsic functions. + SpirvInstruction * + processIntrinsicGetNodeOutputRecords(const CXXMemberCallExpr *callExpr, + bool isGroupShared); + + /// Processes the 'IncrementOutputCount' intrinsic function. + SpirvInstruction * + processIntrinsicIncrementOutputCount(const CXXMemberCallExpr *callExpr, + bool isGroupShared); + + /// Processes the 'Count' intrinsic function for node input record arrays. + SpirvInstruction * + processIntrinsicGetRecordCount(const CXXMemberCallExpr *callExpr); + + /// Processes the 'OutputComplete' intrinsic function. + void processIntrinsicOutputComplete(const CXXMemberCallExpr *callExpr); + + /// Processes the 'FinishedCrossGroupSharing' intrinsic function. + SpirvInstruction * + processIntrinsicFinishedCrossGroupSharing(const CXXMemberCallExpr *callExpr); + /// Processes the 'mad' intrinsic function. SpirvInstruction *processIntrinsicMad(const CallExpr *); @@ -849,6 +888,7 @@ class SpirvEmitter : public ASTConsumer { static hlsl::ShaderModel::Kind getShaderModelKind(StringRef stageName); static spv::ExecutionModel getSpirvShaderStage(hlsl::ShaderModel::Kind smk, bool); + void checkForWaveSizeAttr(const FunctionDecl *decl); /// \brief Handle inline SPIR-V attributes for the entry function. void processInlineSpirvAttributes(const FunctionDecl *entryFunction); @@ -875,6 +915,10 @@ class SpirvEmitter : public ASTConsumer { /// HLSL attributes of the entry point function. 
void processComputeShaderAttributes(const FunctionDecl *entryFunction); + /// \brief Adds necessary execution modes for the node shader based on the + /// HLSL attributes of the entry point function. + void processNodeShaderAttributes(const FunctionDecl *entryFunction); + /// \brief Adds necessary execution modes for the mesh/amplification shader /// based on the HLSL attributes of the entry point function. bool diff --git a/tools/clang/lib/SPIRV/SpirvInstruction.cpp b/tools/clang/lib/SPIRV/SpirvInstruction.cpp index 3b5861710d..575c9c9652 100644 --- a/tools/clang/lib/SPIRV/SpirvInstruction.cpp +++ b/tools/clang/lib/SPIRV/SpirvInstruction.cpp @@ -55,6 +55,11 @@ DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvUnreachable) DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvAccessChain) DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvAtomic) DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvBarrier) +DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvIsNodePayloadValid) +DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvNodePayloadArrayLength) +DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvAllocateNodePayloads) +DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvEnqueueNodePayloads) +DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvFinishWritingNodePayload) DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvBinaryOp) DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvBitFieldExtract) DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvBitFieldInsert) @@ -62,6 +67,7 @@ DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvConstantBoolean) DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvConstantInteger) DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvConstantFloat) DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvConstantComposite) +DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvConstantString) DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvConstantNull) DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvConvertPtrToU) DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvConvertUToPtr) @@ -472,6 +478,41 @@ SpirvBarrier::SpirvBarrier(SourceLocation loc, spv::Scope memScope, memoryScope(memScope), memorySemantics(memSemantics), executionScope(execScope) {} +SpirvIsNodePayloadValid::SpirvIsNodePayloadValid(QualType 
resultType, + SourceLocation loc, + SpirvInstruction *payloadArray, + SpirvInstruction *nodeIndex) + : SpirvInstruction(IK_IsNodePayloadValid, spv::Op::OpIsNodePayloadValidAMDX, + resultType, loc), + payloadArray(payloadArray), nodeIndex(nodeIndex) {} + +SpirvNodePayloadArrayLength::SpirvNodePayloadArrayLength( + QualType resultType, SourceLocation loc, SpirvInstruction *payloadArray) + : SpirvInstruction(IK_NodePayloadArrayLength, + spv::Op::OpNodePayloadArrayLengthAMDX, resultType, loc), + payloadArray(payloadArray) {} + +SpirvAllocateNodePayloads::SpirvAllocateNodePayloads( + QualType resultType, SourceLocation loc, spv::Scope allocationScope, + SpirvInstruction *shaderIndex, SpirvInstruction *recordCount) + : SpirvInstruction(IK_AllocateNodePayloads, + spv::Op::OpAllocateNodePayloadsAMDX, resultType, loc), + allocationScope(allocationScope), shaderIndex(shaderIndex), + recordCount(recordCount) {} + +SpirvEnqueueNodePayloads::SpirvEnqueueNodePayloads(SourceLocation loc, + SpirvInstruction *payload) + : SpirvInstruction(IK_EnqueueNodePayloads, + spv::Op::OpEnqueueNodePayloadsAMDX, QualType(), loc), + payload(payload) {} + +SpirvFinishWritingNodePayload::SpirvFinishWritingNodePayload( + QualType resultType, SourceLocation loc, SpirvInstruction *payload) + : SpirvInstruction(IK_FinishWritingNodePayload, + spv::Op::OpFinishWritingNodePayloadAMDX, resultType, + loc), + payload(payload) {} + SpirvBinaryOp::SpirvBinaryOp(spv::Op opcode, QualType resultType, SourceLocation loc, SpirvInstruction *op1, SpirvInstruction *op2, SourceRange range) @@ -568,7 +609,8 @@ bool SpirvConstant::isSpecConstant() const { return opcode == spv::Op::OpSpecConstant || opcode == spv::Op::OpSpecConstantTrue || opcode == spv::Op::OpSpecConstantFalse || - opcode == spv::Op::OpSpecConstantComposite; + opcode == spv::Op::OpSpecConstantComposite || + opcode == spv::Op::OpSpecConstantStringAMDX; } SpirvConstantBoolean::SpirvConstantBoolean(QualType type, bool val, @@ -623,6 +665,19 @@ 
SpirvConstantComposite::SpirvConstantComposite( type), constituents(constituentsVec.begin(), constituentsVec.end()) {} +SpirvConstantString::SpirvConstantString(llvm::StringRef stringLiteral, + bool isSpecConst) + : SpirvConstant(IK_ConstantString, + isSpecConst ? spv::Op::OpSpecConstantStringAMDX + : spv::Op::OpConstantStringAMDX, + QualType()), + str(stringLiteral) {} + +bool SpirvConstantString::operator==(const SpirvConstantString &that) const { + return opcode == that.opcode && resultType == that.resultType && + str == that.str; +} + SpirvConstantNull::SpirvConstantNull(QualType type) : SpirvConstant(IK_ConstantNull, spv::Op::OpConstantNull, type) {} diff --git a/tools/clang/lib/SPIRV/SpirvType.cpp b/tools/clang/lib/SPIRV/SpirvType.cpp index cabeba4cda..286e6224a4 100644 --- a/tools/clang/lib/SPIRV/SpirvType.cpp +++ b/tools/clang/lib/SPIRV/SpirvType.cpp @@ -167,6 +167,10 @@ bool RuntimeArrayType::operator==(const RuntimeArrayType &that) const { (!stride.hasValue() || stride.getValue() == that.stride.getValue()); } +bool NodePayloadArrayType::operator==(const NodePayloadArrayType &that) const { + return elementType == that.elementType && nodeDecl == that.nodeDecl; +} + bool SpvIntrinsicTypeOperand::operator==( const SpvIntrinsicTypeOperand &that) const { if (isTypeOperand != that.isTypeOperand) diff --git a/tools/clang/test/CodeGenSPIRV/bezier.hull.hlsl2spv b/tools/clang/test/CodeGenSPIRV/bezier.hull.hlsl2spv index 30565394b4..1425137c68 100644 --- a/tools/clang/test/CodeGenSPIRV/bezier.hull.hlsl2spv +++ b/tools/clang/test/CodeGenSPIRV/bezier.hull.hlsl2spv @@ -161,7 +161,7 @@ BEZIER_CONTROL_POINT SubDToBezierHS(InputPatch + Barrier(UAV_MEMORY|GROUP_SHARED_MEMORY|NODE_INPUT_MEMORY|NODE_OUTPUT_MEMORY, + DEVICE_SCOPE); + + // AllMemoryBarrierWithGroupSync() -> + Barrier(UAV_MEMORY|GROUP_SHARED_MEMORY|NODE_INPUT_MEMORY|NODE_OUTPUT_MEMORY, + GROUP_SYNC|DEVICE_SCOPE); + + // DeviceMemoryBarrier() -> + Barrier(UAV_MEMORY, + DEVICE_SCOPE); + + // 
DeviceMemoryBarrierWithGroupSync() -> + Barrier(UAV_MEMORY, + GROUP_SYNC|DEVICE_SCOPE); + + // GroupMemoryBarrier() -> + Barrier(GROUP_SHARED_MEMORY, + GROUP_SCOPE); + + // GroupMemoryBarrierWithGroupSync() -> + Barrier(GROUP_SHARED_MEMORY, + GROUP_SYNC|GROUP_SCOPE); +} + + +// CHECK: [[UINT:%[^ ]*]] = OpTypeInt 32 0 +// CHECK-DAG: [[U2:%[^ ]*]] = OpConstant %uint 2 +// CHECK-DAG: [[U5:%[^ ]*]] = OpConstant %uint 5 +// CHECK-DAG: [[U72:%[^ ]*]] = OpConstant %uint 72 +// CHECK-DAG: [[U264:%[^ ]*]] = OpConstant %uint 264 +// CHECK-DAG: [[U328:%[^ ]*]] = OpConstant %uint 328 +// CHECK-DAG: [[U4424:%[^ ]*]] = OpConstant %uint 4424 + +// CHECK: OpControlBarrier [[U2]] [[U2]] [[U72]] +// CHECK: OpMemoryBarrier [[U2]] [[U328]] +// CHECK: OpMemoryBarrier [[U5]] [[U4424]] +// CHECK: OpControlBarrier [[U2]] [[U5]] [[U4424]] +// CHECK: OpMemoryBarrier [[U5]] [[U72]] +// CHECK: OpControlBarrier [[U2]] [[U5]] [[U72]] +// CHECK: OpMemoryBarrier [[U2]] [[U264]] +// CHECK: OpControlBarrier [[U2]] [[U2]] [[U264]] diff --git a/tools/clang/test/CodeGenSPIRV/node.barrier.object-arg.hlsl b/tools/clang/test/CodeGenSPIRV/node.barrier.object-arg.hlsl new file mode 100644 index 0000000000..215acf7bfd --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/node.barrier.object-arg.hlsl @@ -0,0 +1,213 @@ +// RUN: %dxc -spirv -Vd -Od -T lib_6_8 -fspv-target-env=vulkan1.3 %s | FileCheck %s +// Note: validation disabled until NodePayloadAMDX pointers are allowed +// as function arguments + +// Barrier is called with each node record and UAV type + +struct RECORD +{ + uint value; +}; + +// CHECK: [[UINT:%[^ ]*]] = OpTypeInt 32 0 +// CHECK-DAG: [[U256:%[^ ]*]] = OpConstant [[UINT]] 256 +// CHECK-DAG: [[U1:%[^ ]*]] = OpConstant [[UINT]] 1 +// CHECK-DAG: [[U0:%[^ ]*]] = OpConstant [[UINT]] 0 +// CHECK-DAG: [[U3:%[^ ]*]] = OpConstant [[UINT]] 3 +// CHECK-DAG: [[U4:%[^ ]*]] = OpConstant [[UINT]] 4 +// CHECK-DAG: [[U2:%[^ ]*]] = OpConstant [[UINT]] 2 +// CHECK-DAG: [[U4424:%[^ ]*]] = OpConstant [[UINT]] 4424 
+// CHECK-DAG: [[U5:%[^ ]*]] = OpConstant [[UINT]] 5 + +[Shader("node")] +[NodeLaunch("broadcasting")] +[NumThreads(256,1,1)] +[NodeDispatchGrid(256,1,1)] +void node01(DispatchNodeInputRecord input) +{ + Barrier(input, 5); +} + +// CHECK: OpControlBarrier %uint_2 %uint_5 %uint_4424 + +[Shader("node")] +[NodeLaunch("coalescing")] +[NumThreads(256,1,1)] +void node02([MaxRecords(8)] GroupNodeInputRecords input) +{ + Barrier(input, 3); +} + +// CHECK: OpControlBarrier %uint_2 %uint_2 %uint_4424 + +[Shader("node")] +[NodeLaunch("thread")] +void node03(RWThreadNodeInputRecord input) +{ + Barrier(input, 0); +} + +// CHECK: OpMemoryBarrier %uint_4 %uint_4424 + +[Shader("node")] +[NodeLaunch("coalescing")] +[NumThreads(256,1,1)] +void node04([MaxRecords(6)] RWGroupNodeInputRecords input) +{ + Barrier(input, 0); +} + +// CHECK: OpMemoryBarrier %uint_4 %uint_4424 + +[Shader("node")] +[NodeLaunch("broadcasting")] +[NumThreads(256,1,1)] +[NodeDispatchGrid(256,1,1)] +void node05([MaxRecords(5)] NodeOutput outputs) +{ + ThreadNodeOutputRecords outrec = outputs.GetThreadNodeOutputRecords(1); + Barrier(outrec, 0); +} + +// CHECK: OpMemoryBarrier %uint_4 %uint_4424 + +[Shader("node")] +[NodeLaunch("thread")] +void node06([MaxRecords(5)] NodeOutput outputs) +{ + ThreadNodeOutputRecords outrec = outputs.GetThreadNodeOutputRecords(3); + Barrier(outrec, 0); +} + +// CHECK: OpMemoryBarrier %uint_4 %uint_4424 + +[Shader("node")] +[NodeLaunch("coalescing")] +[NumThreads(256,1,3)] +void node07([MaxRecords(5)] NodeOutput outputs) +{ + GroupNodeOutputRecords outrec = outputs.GetGroupNodeOutputRecords(1); + Barrier(outrec, 3); +} + +// CHECK: OpControlBarrier %uint_2 %uint_2 %uint_4424 + +[Shader("node")] +[NodeLaunch("broadcasting")] +[NumThreads(256,1,4)] +[NodeDispatchGrid(256,1,1)] +void node08([MaxRecords(5)] NodeOutput outputs) +{ + GroupNodeOutputRecords outrec = outputs.GetGroupNodeOutputRecords(4); + Barrier(outrec, 3); +} + +// CHECK: OpControlBarrier %uint_2 %uint_2 %uint_4424 + 
+RWBuffer obj09; +[Shader("node")] +[NodeLaunch("broadcasting")] +[NumThreads(256,1,4)] +[NodeDispatchGrid(256,1,1)] +void node09() +{ + Barrier(obj09, 5); +} + +// CHECK: OpControlBarrier %uint_2 %uint_5 %uint_4424 + +RWTexture1D obj10; +[Shader("node")] +[NodeLaunch("broadcasting")] +[NumThreads(256,1,4)] +[NodeDispatchGrid(256,1,1)] +void node10() +{ + Barrier(obj10, 5); +} + +// CHECK: OpControlBarrier %uint_2 %uint_5 %uint_4424 + +RWTexture1DArray obj11; +[Shader("node")] +[NodeLaunch("broadcasting")] +[NumThreads(256,1,4)] +[NodeDispatchGrid(256,1,1)] +void node11() +{ + Barrier(obj11, 5); +} + +// CHECK: OpControlBarrier %uint_2 %uint_5 %uint_4424 + +RWTexture2D obj12; +[Shader("node")] +[NodeLaunch("broadcasting")] +[NumThreads(256,1,4)] +[NodeDispatchGrid(256,1,1)] +void node12() +{ + Barrier(obj12, 5); +} + +// CHECK: OpControlBarrier %uint_2 %uint_5 %uint_4424 + +RWTexture2DArray obj13; +[Shader("node")] +[NodeLaunch("broadcasting")] +[NumThreads(256,1,4)] +[NodeDispatchGrid(256,1,1)] +void node13() +{ + Barrier(obj13, 5); +} + +// CHECK: OpControlBarrier %uint_2 %uint_5 %uint_4424 + +RWTexture3D obj14; +[Shader("node")] +[NodeLaunch("broadcasting")] +[NumThreads(256,1,4)] +[NodeDispatchGrid(256,1,1)] +void node14() +{ + Barrier(obj14, 5); +} + +// CHECK: OpControlBarrier %uint_2 %uint_5 %uint_4424 + +RWStructuredBuffer obj15; +[Shader("node")] +[NodeLaunch("broadcasting")] +[NumThreads(256,1,4)] +[NodeDispatchGrid(256,1,1)] +void node15() +{ + Barrier(obj15, 5); +} + +// CHECK: OpControlBarrier %uint_2 %uint_5 %uint_4424 + +RWByteAddressBuffer obj16; +[Shader("node")] +[NodeLaunch("broadcasting")] +[NumThreads(256,1,4)] +[NodeDispatchGrid(256,1,1)] +void node16() +{ + Barrier(obj16, 5); +} + +// CHECK: OpControlBarrier %uint_2 %uint_5 %uint_4424 + +AppendStructuredBuffer obj17; +[Shader("node")] +[NodeLaunch("broadcasting")] +[NumThreads(256,1,4)] +[NodeDispatchGrid(256,1,1)] +void node17() +{ + Barrier(obj17, 5); +} + +// CHECK: OpControlBarrier 
%uint_2 %uint_5 %uint_4424 diff --git a/tools/clang/test/CodeGenSPIRV/node.broadcasting.no-input.hlsl b/tools/clang/test/CodeGenSPIRV/node.broadcasting.no-input.hlsl new file mode 100644 index 0000000000..a3c369b252 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/node.broadcasting.no-input.hlsl @@ -0,0 +1,15 @@ +// RUN: %dxc -spirv -Od -T lib_6_8 -fspv-target-env=vulkan1.3 %s | FileCheck %s + +// Broadcasting launch node with no input + +[Shader("node")] +[NodeLaunch("broadcasting")] +[NodeDispatchGrid(3,4,5)] +[NumThreads(6,7,1)] +[NodeIsProgramEntry] +void node070_broadcasting_noinput() +{ +} + +// CHECK: OpReturn + diff --git a/tools/clang/test/CodeGenSPIRV/node.coalescing.num-threads.hlsl b/tools/clang/test/CodeGenSPIRV/node.coalescing.num-threads.hlsl new file mode 100644 index 0000000000..14e899da02 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/node.coalescing.num-threads.hlsl @@ -0,0 +1,16 @@ +// RUN: %dxc -spirv -Od -T lib_6_8 -fspv-target-env=vulkan1.3 external %s | FileCheck %s + +// Coalescing launch node with thread group defined in the shader + +[Shader("node")] +[NodeLaunch("coalescing")] +[NumThreads(1024,1,1)] +[NodeIsProgramEntry] +void node008_coalescing_numthreads_shader() +{ +} + +// CHECK: OpEntryPoint GLCompute [[SHADER:%[0-9A-Za-z_]*]] +// CHECK-DAG: OpExecutionMode [[SHADER]] CoalescingAMDX +// CHECK-DAG: OpExecutionMode [[SHADER]] LocalSize 1024 1 1 +// CHECK: OpReturn diff --git a/tools/clang/test/CodeGenSPIRV/node.dispatch-grid.hlsl b/tools/clang/test/CodeGenSPIRV/node.dispatch-grid.hlsl new file mode 100644 index 0000000000..302c8ea698 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/node.dispatch-grid.hlsl @@ -0,0 +1,28 @@ +// RUN: %dxc -spirv -Vd -Od -T lib_6_8 -fspv-target-env=vulkan1.3 external %s | FileCheck %s +// Note: validation disabled until NodePayloadAMDX pointers are allowed +// as function arguments + +// Broadcasting launch node with dispatch grid defined in shader + +struct INPUT_NOGRID +{ + uint textureIndex; +}; 
+ +[Shader("node")] +[NodeLaunch("broadcasting")] +[NodeDispatchGrid(2,3,2)] +[NumThreads(1024,1,1)] +[NodeIsProgramEntry] +void node001_dispatchgrid_shader(DispatchNodeInputRecord input) +{ +} + +// CHECK: OpEntryPoint GLCompute [[SHADER:%[0-9A-Za-z_]*]] +// CHECK-DAG: OpExecutionMode [[SHADER]] LocalSize 1024 1 1 +// CHECK-DAG: OpExecutionModeId [[SHADER]] StaticNumWorkgroupsAMDX [[U2:%[0-9A-Za-z_]*]] +// CHECK-SAME: [[U3:%[^ ]*]] [[U2]] +// CHECK: [[UINT:%[^ ]*]] = OpTypeInt 32 0 +// CHECK-DAG: [[U2]] = OpConstant [[UINT]] 2 +// CHECK-DAG: [[U3]] = OpConstant [[UINT]] 3 +// CHECK: OpReturn diff --git a/tools/clang/test/CodeGenSPIRV/node.empty-node-input.hlsl b/tools/clang/test/CodeGenSPIRV/node.empty-node-input.hlsl new file mode 100644 index 0000000000..fa16429a1b --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/node.empty-node-input.hlsl @@ -0,0 +1,28 @@ +// RUN: %dxc -spirv -Vd -Od -T lib_6_8 -fspv-target-env=vulkan1.3 %s | FileCheck %s +// Note: validation disabled until NodePayloadAMDX pointers are allowed +// as function arguments + +// Coalescing launch node declares EmptyNodeInput + +RWBuffer buf0; + +[Shader("node")] +[NodeLaunch("coalescing")] +[NodeIsProgramEntry] +[NumThreads(2,1,1)] +void emptynodeinput(EmptyNodeInput input) +{ + // input.Count should always return 1 here, so there is + // an opportunity for an optimization. 
+ buf0[0] = input.Count(); +} + +// CHECK-DAG: [[UINT:%[^ ]*]] = OpTypeInt 32 0 +// CHECK-DAG: [[U0:%[^ ]*]] = OpConstant [[UINT]] 0 +// CHECK-DAG: [[IMG:%[^ ]*]] = OpTypeImage [[UINT]] Buffer 2 0 0 2 R32ui +// CHECK-DAG: [[IMGPTR:%[^ ]*]] = OpTypePointer UniformConstant [[IMG]] +// CHECK-DAG: [[BUF:%[^ ]*]] = OpVariable [[IMGPTR]] UniformConstant + +// CHECK: [[COUNT:%[^ ]*]] = OpNodePayloadArrayLengthAMDX [[UINT]] +// CHECK: [[IMAGE:%[^ ]*]] = OpLoad [[IMG]] [[BUF]] +// CHECK: OpImageWrite [[IMAGE]] [[U0]] [[COUNT]] None diff --git a/tools/clang/test/CodeGenSPIRV/node.finished-cross-group-sharing.hlsl b/tools/clang/test/CodeGenSPIRV/node.finished-cross-group-sharing.hlsl new file mode 100644 index 0000000000..8e1ce56307 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/node.finished-cross-group-sharing.hlsl @@ -0,0 +1,32 @@ +// RUN: %dxc -spirv -Vd -Od -T lib_6_8 -fspv-target-env=vulkan1.3 %s | FileCheck %s +// Note: validation disabled until NodePayloadAMDX pointers are allowed +// as function arguments + +// FinishedCrossGroupSharing() is called with RWDispatchNodeInputRecord + +RWBuffer buf0; + +struct [NodeTrackRWInputSharing] INPUT_RECORD +{ + uint value; +}; + +[Shader("node")] +[NodeLaunch("broadcasting")] +[NodeDispatchGrid(256,1,1)] +[NumThreads(1,1,1)] +void node037_finishedcrossgroupsharing(RWDispatchNodeInputRecord input) +{ + bool b = input.FinishedCrossGroupSharing(); + buf0[0] = 0 ? 
b : 1; +} + +// CHECK: OpName [[INPUT:%[^ ]*]] "input" +// CHECK: OpDecorate [[STRUCT:%[^ ]*]] TrackFinishWritingAMDX +// CHECK: [[UINT:%[^ ]*]] = OpTypeInt 32 0 +// CHECK: [[STRUCT]] = OpTypeStruct [[UINT]] +// CHECK: [[ARR:%[^ ]*]] = OpTypeNodePayloadArrayAMDX [[STRUCT]] +// CHECK: [[PTR:%[^ ]*]] = OpTypePointer NodePayloadAMDX [[ARR]] +// CHECK: [[BOOL:%[^ ]*]] = OpTypeBool +// CHECK: [[INPUT]] = OpFunctionParameter [[PTR]] +// CHECK: OpFinishWritingNodePayloadAMDX [[BOOL]] [[INPUT]] diff --git a/tools/clang/test/CodeGenSPIRV/node.get-input-record-count.hlsl b/tools/clang/test/CodeGenSPIRV/node.get-input-record-count.hlsl new file mode 100644 index 0000000000..a3af668c46 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/node.get-input-record-count.hlsl @@ -0,0 +1,25 @@ +// RUN: %dxc -spirv -Vd -Od -T lib_6_8 external -fspv-target-env=vulkan1.3 %s | FileCheck %s +// Note: validation disabled until NodePayloadAMDX pointers are allowed +// as function arguments + +// GetInputRecordCount() called with NodeInputRecordArray + +RWBuffer buf0; + +struct INPUT_RECORD +{ + uint textureIndex; +}; + +[Shader("node")] +[NodeLaunch("coalescing")] +[NumThreads(1024,1,1)] +[NodeIsProgramEntry] +void node014_getinputrecordcount([MaxRecords(256)] GroupNodeInputRecords inputs) +{ + uint numRecords = inputs.Count(); + buf0[0] = numRecords; +} + +// CHECK: [[UINT:%[^ ]*]] = OpTypeInt 32 0 +// CHECK: OpNodePayloadArrayLengthAMDX [[UINT]] diff --git a/tools/clang/test/CodeGenSPIRV/node.get-node-output-record.multiple.hlsl b/tools/clang/test/CodeGenSPIRV/node.get-node-output-record.multiple.hlsl new file mode 100644 index 0000000000..d029bd20bb --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/node.get-node-output-record.multiple.hlsl @@ -0,0 +1,72 @@ +// RUN: %dxc -spirv -Od -T lib_6_8 -fspv-target-env=vulkan1.3 %s | FileCheck %s + +// Multiple calls to Get*NodeOutputRecords(array) + +struct RECORD { + int i; + float3 foo; +}; + +[Shader("node")] +[NodeLaunch("broadcasting")] 
+[NumThreads(64, 1, 1)] +[NodeDispatchGrid(8, 1, 1)] +void node150_a(NodeOutput output) +{ + GroupNodeOutputRecords outRec1 = output.GetGroupNodeOutputRecords(1); + GroupNodeOutputRecords outRec2 = output.GetGroupNodeOutputRecords(4); + outRec1.OutputComplete(); + outRec2.OutputComplete(); +} + +[Shader("node")] +[NodeLaunch("broadcasting")] +[NumThreads(64, 1, 1)] +[NodeDispatchGrid(8, 1, 1)] +void node150_b(NodeOutput output) +{ + ThreadNodeOutputRecords outRec1 = output.GetThreadNodeOutputRecords(5); + ThreadNodeOutputRecords outRec2 = output.GetThreadNodeOutputRecords(1); + outRec1.OutputComplete(); + outRec1 = outRec2; + outRec1.OutputComplete(); +} + +// CHECK: OpDecorateId [[ARR_A:%[^ ]*]] PayloadNodeNameAMDX [[STR:%[0-9A-Za-z_]*]] +// CHECK: OpDecorateId [[ARR_B:%[^ ]*]] PayloadNodeNameAMDX [[STR]] + +// CHECK: [[UINT:%[^ ]*]] = OpTypeInt 32 0 +// CHECK-DAG: [[U0:%[^ ]*]] = OpConstant [[UINT]] 0 +// CHECK-DAG: [[U1:%[^ ]*]] = OpConstant [[UINT]] 1 +// CHECK-DAG: [[U2:%[^ ]*]] = OpConstant [[UINT]] 2 +// CHECK-DAG: [[U4:%[^ ]*]] = OpConstant [[UINT]] 4 +// CHECK-DAG: [[U5:%[^ ]*]] = OpConstant [[UINT]] 5 +// CHECK-DAG: [[STR]] = OpConstantStringAMDX "output" +// CHECK-DAG: [[ARR_A]] = OpTypeNodePayloadArrayAMDX +// CHECK-DAG: [[ARR_B]] = OpTypeNodePayloadArrayAMDX +// CHECK-DAG: [[FPTR_A:%[^ ]*]] = OpTypePointer Function [[ARR_A]] +// CHECK-DAG: [[NPTR_A:%[^ ]*]] = OpTypePointer NodePayloadAMDX [[ARR_A]] +// CHECK-DAG: [[FPTR_B:%[^ ]*]] = OpTypePointer Function [[ARR_B]] +// CHECK-DAG: [[NPTR_B:%[^ ]*]] = OpTypePointer NodePayloadAMDX [[ARR_B]] + +// checking for OpFunctionCall skips over the entry function wrapper and +// thereby avoids matching wrapper variables +// CHECK: OpFunctionCall +// CHECK: [[OUT1:%[^ ]*]] = OpVariable [[FPTR_A]] +// CHECK: [[OUT2:%[^ ]*]] = OpVariable [[FPTR_A]] +// CHECK: [[PAY:%[^ ]*]] = OpAllocateNodePayloadsAMDX [[NPTR_A]] [[U2]] [[U1]] [[U0]] +// CHECK: [[VAL:%[^ ]*]] = OpLoad [[ARR_A]] [[PAY]] +// CHECK: OpStore [[OUT1]] 
[[VAL]] +// CHECK: [[PAY:%[^ ]*]] = OpAllocateNodePayloadsAMDX [[NPTR_A]] [[U2]] [[U4]] [[U0]] +// CHECK: [[VAL:%[^ ]*]] = OpLoad [[ARR_A]] [[PAY]] +// CHECK: OpStore [[OUT2]] [[VAL]] +// CHECK: OpFunctionCall +// CHECK: [[OUT1:%[^ ]*]] = OpVariable [[FPTR_B]] +// CHECK: [[OUT2:%[^ ]*]] = OpVariable [[FPTR_B]] +// CHECK: [[PAY:%[^ ]*]] = OpAllocateNodePayloadsAMDX [[NPTR_B]] [[U4]] [[U5]] [[U0]] +// CHECK: [[VAL:%[^ ]*]] = OpLoad [[ARR_B]] [[PAY]] +// CHECK: OpStore [[OUT1]] [[VAL]] +// CHECK: [[PAY:%[^ ]*]] = OpAllocateNodePayloadsAMDX [[NPTR_B]] [[U4]] [[U1]] [[U0]] +// CHECK: [[VAL:%[^ ]*]] = OpLoad [[ARR_B]] [[PAY]] +// CHECK: OpStore [[OUT2]] [[VAL]] +// CHECK: OpFunctionEnd diff --git a/tools/clang/test/CodeGenSPIRV/node.get-remaining-recursion-levels.hlsl b/tools/clang/test/CodeGenSPIRV/node.get-remaining-recursion-levels.hlsl new file mode 100644 index 0000000000..f981282748 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/node.get-remaining-recursion-levels.hlsl @@ -0,0 +1,26 @@ +// RUN: %dxc -spirv -T lib_6_8 external -fspv-target-env=vulkan1.3 %s | FileCheck %s + +// GetRemainingRecursionLevels() called + +RWBuffer buf0; + +[Shader("node")] +[NodeLaunch("broadcasting")] +[NumThreads(8,1,1)] +[NodeDispatchGrid(32,2,2)] +[NodeMaxRecursionDepth(16)] +void node133_getremainingrecursionlevels() +{ + uint remaining = GetRemainingRecursionLevels(); + // Use resource as a way of preventing DCE + buf0[0] = remaining; +} + +// CHECK: OpEntryPoint GLCompute [[SHADER:%[^ ]*]] "node133_getremainingrecursionlevels" [[RRL:%[^ ]*]] +// CHECK: OpExecutionModeId [[SHADER]] MaxNodeRecursionAMDX [[U16:%[^ ]*]] +// CHECK: OpDecorate [[RRL]] BuiltIn RemainingRecursionLevelsAMDX +// CHECK: [[UINT:%[^ ]*]] = OpTypeInt 32 0 +// CHECK: [[U16]] = OpConstant [[UINT]] 16 +// CHECK: [[PTR:%[^ ]*]] = OpTypePointer Input [[UINT]] +// CHECK: [[RRL]] = OpVariable [[PTR]] Input +// CHECK: OpLoad [[UINT]] [[RRL]] diff --git a/tools/clang/test/CodeGenSPIRV/node.group-shared.barrier.hlsl 
b/tools/clang/test/CodeGenSPIRV/node.group-shared.barrier.hlsl new file mode 100644 index 0000000000..cf1638d75c --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/node.group-shared.barrier.hlsl @@ -0,0 +1,18 @@ +// RUN: %dxc -spirv -Od -T lib_6_8 -fspv-target-env=vulkan1.3 %s | FileCheck %s + +// Check that a barrier can be used on a groupshared object from a +// work graph node + +groupshared uint Test; + +[Shader("node")] +[NodeLaunch("broadcasting")] +[NodeDispatchGrid(1, 1, 1)] +[NumThreads(1, 1, 1)] +void firstNode() +{ + Test = 1; + AllMemoryBarrierWithGroupSync(); +} + +// CHECK: OpReturn diff --git a/tools/clang/test/CodeGenSPIRV/node.group-shared.hlsl b/tools/clang/test/CodeGenSPIRV/node.group-shared.hlsl new file mode 100644 index 0000000000..81fc0e39a2 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/node.group-shared.hlsl @@ -0,0 +1,24 @@ +// RUN: %dxc -spirv -Vd -Od -T lib_6_8 -fspv-target-env=vulkan1.3 %s | FileCheck %s +// Note: validation disabled until NodePayloadAMDX pointers are allowed +// as function arguments + +// Check that group shared memory is allowed from a work graph node + +struct Record +{ + uint index; +}; + +groupshared uint testLds[512]; + +[Shader("node")] +[NodeLaunch("broadcasting")] +[NodeDispatchGrid(2, 1, 1)] +[NumThreads(1,1,1)] +void firstNode(DispatchNodeInputRecord inputData) +{ + testLds[inputData.Get().index] = 99; +} + +// CHECK: OpReturn + diff --git a/tools/clang/test/CodeGenSPIRV/node.increment-output-count.group.hlsl b/tools/clang/test/CodeGenSPIRV/node.increment-output-count.group.hlsl new file mode 100644 index 0000000000..d6a2ea759e --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/node.increment-output-count.group.hlsl @@ -0,0 +1,22 @@ +// RUN: %dxc -spirv -Od -T lib_6_8 external -fspv-target-env=vulkan1.3 %s | FileCheck %s + +// Node with EmptyNodeOutput calls GroupIncrementOutputCount + + +[Shader("node")] +[NodeLaunch("coalescing")] +[NumThreads(1024,1,1)] +[NodeIsProgramEntry] +void 
node028_incrementoutputcount([MaxRecords(32)] EmptyNodeOutput empty) +{ + empty.GroupIncrementOutputCount(1); +} + +// CHECK: [[UINT:%[^ ]*]] = OpTypeInt 32 0 +// CHECK-DAG: [[U0:%[^ ]*]] = OpConstant [[UINT]] 0 +// CHECK-DAG: [[U1:%[^ ]*]] = OpConstant [[UINT]] 1 +// CHECK-DAG: [[STRUCT:%[^ ]*]] = OpTypeStruct +// CHECK-DAG: [[ARR:%[^ ]*]] = OpTypeNodePayloadArrayAMDX [[STRUCT]] +// CHECK-DAG: [[PTR:%[^ ]*]] = OpTypePointer NodePayloadAMDX [[ARR]] +// CHECK-DAG: [[U2:%[^ ]*]] = OpConstant [[UINT]] 2 +// CHECK: OpAllocateNodePayloadsAMDX [[PTR]] [[U2]] [[U1]] [[U0]] diff --git a/tools/clang/test/CodeGenSPIRV/node.increment-output-count.thread.hlsl b/tools/clang/test/CodeGenSPIRV/node.increment-output-count.thread.hlsl new file mode 100644 index 0000000000..6cd984fe69 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/node.increment-output-count.thread.hlsl @@ -0,0 +1,22 @@ +// RUN: %dxc -spirv -Od -T lib_6_8 external -fspv-target-env=vulkan1.3 %s | FileCheck %s + +// Node with EmptyNodeOutput calls ThreadIncrementOutputCount + + +[Shader("node")] +[NodeLaunch("thread")] +[NodeIsProgramEntry] +void node028_incrementoutputcount([MaxRecords(32)] EmptyNodeOutput empty) +{ + empty.ThreadIncrementOutputCount(1); +} + +// CHECK: [[UINT:%[^ ]*]] = OpTypeInt 32 0 +// CHECK-DAG: [[U0:%[^ ]*]] = OpConstant [[UINT]] 0 +// CHECK-DAG: [[U1:%[^ ]*]] = OpConstant [[UINT]] 1 +// CHECK-DAG: [[STRUCT:%[^ ]*]] = OpTypeStruct +// CHECK-DAG: [[ARR:%[^ ]*]] = OpTypeNodePayloadArrayAMDX [[STRUCT]] +// CHECK-DAG: [[PTR:%[^ ]*]] = OpTypePointer NodePayloadAMDX [[ARR]] +// CHECK-DAG: OpConstantStringAMDX "empty" +// CHECK-DAG: [[U4:%[^ ]*]] = OpConstant [[UINT]] 4 +// CHECK: OpAllocateNodePayloadsAMDX [[PTR]] [[U4]] [[U1]] [[U0]] diff --git a/tools/clang/test/CodeGenSPIRV/node.input-record.dispatch-grid.array.hlsl b/tools/clang/test/CodeGenSPIRV/node.input-record.dispatch-grid.array.hlsl new file mode 100644 index 0000000000..bae3f759b8 --- /dev/null +++ 
b/tools/clang/test/CodeGenSPIRV/node.input-record.dispatch-grid.array.hlsl @@ -0,0 +1,26 @@ +// RUN: %dxc -spirv -Vd -Od -T lib_6_8 -fspv-target-env=vulkan1.3 %s | FileCheck %s +// Note: validation disabled until NodePayloadAMDX pointers are allowed +// as function arguments + +// Check that SV_DispatchGrid supports array + +struct RECORD +{ + uint a[3] : SV_DispatchGrid; + uint b[3]; +}; + +[Shader("node")] +[NodeLaunch("coalescing")] +[numthreads(4,4,4)] +void node01(RWGroupNodeInputRecords input) +{ + input.Get().a = input.Get().b; +} + +// CHECK: OpName [[RECORD:%[^ ]*]] "RECORD" +// CHECK: OpMemberDecorate [[RECORD]] 0 PayloadDispatchIndirectAMDX +// CHECK: [[UINT:%[^ ]*]] = OpTypeInt 32 0 +// CHECK: [[U3:%[^ ]*]] = OpConstant [[UINT]] 3 +// CHECK: [[ARRAY:%[^ ]*]] = OpTypeArray [[UINT]] [[U3]] +// CHECK: [[RECORD]] = OpTypeStruct [[ARRAY]] [[ARRAY]] diff --git a/tools/clang/test/CodeGenSPIRV/node.input-record.dispatch-grid.nested.hlsl b/tools/clang/test/CodeGenSPIRV/node.input-record.dispatch-grid.nested.hlsl new file mode 100644 index 0000000000..aee7e0d014 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/node.input-record.dispatch-grid.nested.hlsl @@ -0,0 +1,32 @@ +// RUN: %dxc -spirv -Vd -Od -T lib_6_8 -fspv-target-env=vulkan1.3 %s | FileCheck %s +// Note: validation disabled until NodePayloadAMDX pointers are allowed +// as function arguments + +// Check that SV_DispatchGrid in nested struct is recognized + +struct INNER { + uint c; + uint3 grid : SV_DispatchGrid; +}; + +struct RECORD +{ + uint a; + INNER b; +}; + +[Shader("node")] +[NodeLaunch("coalescing")] +[numthreads(4,4,4)] +void node01(RWGroupNodeInputRecords input) +{ + input.Get().a = input.Get().b.grid.x; +} + +// CHECK: OpName [[RECORD:%[^ ]*]] "RECORD" +// CHECK: OpName [[INNER:%[^ ]*]] "INNER" +// CHECK: OpMemberDecorate [[INNER]] 1 PayloadDispatchIndirectAMDX +// CHECK: [[UINT:%[^ ]*]] = OpTypeInt 32 0 +// CHECK: [[VECTOR:%[^ ]*]] = OpTypeVector [[UINT]] 3 +// CHECK: [[INNER]] = OpTypeStruct 
[[UINT]] [[VECTOR]] +// CHECK: [[RECORD]] = OpTypeStruct [[UINT]] [[INNER]] diff --git a/tools/clang/test/CodeGenSPIRV/node.max-dispatch-grid.hlsl b/tools/clang/test/CodeGenSPIRV/node.max-dispatch-grid.hlsl new file mode 100644 index 0000000000..e2440a31c0 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/node.max-dispatch-grid.hlsl @@ -0,0 +1,30 @@ +// RUN: %dxc -spirv -Vd -Od -T lib_6_8 -fspv-target-env=vulkan1.3 %s | FileCheck %s +// Note: validation disabled until NodePayloadAMDX pointers are allowed +// as function arguments + +// Broadcasting launch node with dispatch grid defined in input +// and max dispatch grid defined in the shader + +struct INPUT_GRID +{ + uint3 DispatchGrid : SV_DispatchGrid; + uint textureIndex; +}; + +[Shader("node")] +[NodeLaunch("broadcasting")] +[NodeMaxDispatchGrid(2,3,4)] +[NumThreads(1024,1,1)] +void node002_dispatchgrid_input_maxdispatchgrid_shader(DispatchNodeInputRecord input) +{ +} + +// CHECK: OpEntryPoint GLCompute [[SHADER:%[^ ]*]] "node002_dispatchgrid_input_maxdispatchgrid_shader" +// CHECK-DAG: OpExecutionMode [[SHADER]] LocalSize 1024 1 1 +// CHECK-DAG: OpExecutionModeId [[SHADER]] MaxNumWorkgroupsAMDX [[U2:%[^ ]*]] [[U3:%[^ ]*]] [[U4:%[0-9A-Za-z_]*]] +// CHECK: OpMemberDecorate %{{[^ ]*}} 0 PayloadDispatchIndirectAMDX +// CHECK: [[UINT:%[^ ]*]] = OpTypeInt 32 0 +// CHECK-DAG: [[U2]] = OpConstant [[UINT]] 2 +// CHECK-DAG: [[U3]] = OpConstant [[UINT]] 3 +// CHECK-DAG: [[U4]] = OpConstant [[UINT]] 4 +// CHECK: OpReturn diff --git a/tools/clang/test/CodeGenSPIRV/node.max-records.hlsl b/tools/clang/test/CodeGenSPIRV/node.max-records.hlsl new file mode 100644 index 0000000000..7d8449afab --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/node.max-records.hlsl @@ -0,0 +1,45 @@ +// RUN: %dxc -spirv -T lib_6_8 -fspv-target-env=vulkan1.3 %s | FileCheck %s + +// Test referencing params with MaxOutputRecordsSharedWith + +struct rec0 +{ + int i0; + float f0; +}; + +struct rec1 +{ + float f1; + int i1; +}; + +[Shader("node")] 
+[NodeLaunch("thread")] +void BackwardRef( + RWThreadNodeInputRecord InputyMcInputFace, + [MaxRecords(5)] NodeOutput Output1, + [MaxRecordsSharedWith(Output1)] NodeOutput Output2) +{ +} + +// CHECK: OpDecorateId [[TYPE1:%[^ ]*]] PayloadNodeNameAMDX [[STR1:%[^ ]*]] +// CHECK: OpDecorateId [[TYPE1]] NodeMaxPayloadsAMDX [[U5:%[^ ]*]] +// CHECK: OpDecorateId [[TYPE2:%[^ ]*]] PayloadNodeNameAMDX [[STR2:%[^ ]*]] +// CHECK: OpDecorateId [[TYPE2]] NodeSharesPayloadLimitsWithAMDX [[TYPE1]] +// CHECK: [[UINT:%[^ ]*]] = OpTypeInt 32 0 +// CHECK-DAG: [[U5]] = OpConstant [[UINT]] 5 +// CHECK-DAG: [[STR1]] = OpConstantStringAMDX "Output1" +// CHECK-DAG: [[STR2]] = OpConstantStringAMDX "Output2" + +#if 0 +// copied from DXIL test but doesn't seem to conform to spec +[Shader("node")] +[NodeLaunch("thread")] +void ForwardRef( + RWThreadNodeInputRecord InputyMcInputFace, + [MaxRecordsSharedWith(Output2)] NodeOutput Output1, + [MaxRecords(5)] NodeOutput Output2) +{ +} +#endif diff --git a/tools/clang/test/CodeGenSPIRV/node.member.read.hlsl b/tools/clang/test/CodeGenSPIRV/node.member.read.hlsl new file mode 100644 index 0000000000..ac2474b29b --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/node.member.read.hlsl @@ -0,0 +1,150 @@ +// RUN: %dxc -spirv -Vd -Od -T lib_6_8 -fspv-target-env=vulkan1.3 %s | FileCheck %s +// Note: validation disabled until NodePayloadAMDX pointers are allowed +// as function arguments + +// Read access to members of node input/output records + +RWBuffer buf0; + +struct RECORD +{ + uint a; + uint b; + uint c; +}; + +// CHECK: OpName [[BUF0:%[^ ]*]] "buf0" +// CHECK: [[UINT:%[^ ]*]] = OpTypeInt 32 0 +// CHECK: [[U0:%[^ ]*]] = OpConstant [[UINT]] 0 +// CHECK: [[U16:%[^ ]*]] = OpConstant [[UINT]] 16 +// CHECK-DAG: [[INT:%[^ ]*]] = OpTypeInt 32 1 +// CHECK-DAG: [[S0:%[^ ]*]] = OpConstant [[INT]] 0 +// CHECK-DAG: [[U1:%[^ ]*]] = OpConstant [[UINT]] 1 +// CHECK-DAG: [[S1:%[^ ]*]] = OpConstant [[INT]] 1 +// CHECK-DAG: [[U2:%[^ ]*]] = OpConstant [[UINT]] 2 +// 
CHECK-DAG: [[S2:%[^ ]*]] = OpConstant [[INT]] 2 +// CHECK-DAG: [[U4:%[^ ]*]] = OpConstant [[UINT]] 4 +// CHECK-DAG: [[U7:%[^ ]*]] = OpConstant [[UINT]] 7 +// CHECK-DAG: [[TBI:%[^ ]*]] = OpTypeImage [[UINT]] Buffer + +[Shader("node")] +[NumThreads(1024,1,1)] +[NodeLaunch("broadcasting")] +[NodeDispatchGrid(16,1,1)] +void node01(DispatchNodeInputRecord input) +{ + buf0[0] = input.Get().a; +} + +// CHECK: OpFunction +// CHECK: [[PTR:%[^ ]*]] = OpAccessChain %{{[^ ]*}} %{{[^ ]*}} [[S0]] +// CHECK: [[VAL:%[^ ]*]] = OpLoad [[UINT]] [[PTR]] +// CHECK: [[IMG:%[^ ]*]] = OpLoad [[TBI]] [[BUF0]] +// CHECK: OpImageWrite [[IMG]] [[U0]] [[VAL]] +// CHECK: OpFunctionEnd + + +[Shader("node")] +[NumThreads(1024,1,1)] +[NodeLaunch("broadcasting")] +[NodeDispatchGrid(16,1,1)] +void node02(RWDispatchNodeInputRecord input) +{ + buf0[0] = input.Get().b; +} + +// CHECK: OpFunction +// CHECK: [[PTR:%[^ ]*]] = OpAccessChain %{{[^ ]*}} %{{[^ ]*}} [[S1]] +// CHECK: [[VAL:%[^ ]*]] = OpLoad [[UINT]] [[PTR]] +// CHECK: [[IMG:%[^ ]*]] = OpLoad [[TBI]] [[BUF0]] +// CHECK: OpImageWrite [[IMG]] [[U0]] [[VAL]] +// CHECK: OpFunctionEnd + +[Shader("node")] +[NumThreads(1024, 1, 1)] +[NodeLaunch("coalescing")] +void node03([MaxRecords(3)] GroupNodeInputRecords input) +{ + buf0[0] = input[1].c; +} + +// CHECK: OpFunction +// CHECK: [[PTR:%[^ ]*]] = OpAccessChain %{{[^ ]*}} %{{[^ ]*}} [[U1]] [[S2]] +// CHECK: [[VAL:%[^ ]*]] = OpLoad [[UINT]] [[PTR]] +// CHECK: [[IMG:%[^ ]*]] = OpLoad [[TBI]] [[BUF0]] +// CHECK: OpImageWrite [[IMG]] [[U0]] [[VAL]] +// CHECK: OpFunctionEnd + +[Shader("node")] +[NumThreads(1,1,1)] +[NodeLaunch("coalescing")] +void node04([MaxRecords(4)] RWGroupNodeInputRecords input) +{ + buf0[0] = input[2].c; +} + +// CHECK: OpFunction +// CHECK: [[PTR:%[^ ]*]] = OpAccessChain %{{[^ ]*}} %{{[^ ]*}} [[U2]] [[S2]] +// CHECK: [[VAL:%[^ ]*]] = OpLoad [[UINT]] [[PTR]] +// CHECK: [[IMG:%[^ ]*]] = OpLoad [[TBI]] [[BUF0]] +// CHECK: OpImageWrite [[IMG]] [[U0]] [[VAL]] +// CHECK: OpFunctionEnd + 
+[Shader("node")] +[NumThreads(1,1,1)] +[NodeLaunch("coalescing")] +void node05(NodeOutput output) +{ + ThreadNodeOutputRecords outrec = output.GetThreadNodeOutputRecords(1); + buf0[0] = outrec.Get().a; +} + +// CHECK: OpFunction +// CHECK: [[PAY:%[^ ]*]] = OpAllocateNodePayloadsAMDX %{{[^ ]*}} [[U4]] [[U1]] [[U0]] +// CHECK: [[TEMP:%[^ ]*]] = OpLoad %{{[^ ]*}} [[PAY]] +// CHECK: OpStore [[OUT:%[^ ]*]] [[TEMP]] +// CHECK: [[PTR1:%[^ ]*]] = OpAccessChain %{{[^ ]*}} [[OUT]] [[U0]] +// CHECK: [[PTR2:%[^ ]*]] = OpAccessChain %{{[^ ]*}} [[PTR1]] [[S0]] +// CHECK-DAG: [[VAL:%[^ ]*]] = OpLoad [[UINT]] [[PTR2]] +// CHECK-DAG: [[IMG:%[^ ]*]] = OpLoad [[TBI]] [[BUF0]] +// CHECK: OpImageWrite [[IMG]] [[U0]] [[VAL]] +// CHECK: OpFunctionEnd + +[Shader("node")] +[NumThreads(1,1,1)] +[NodeLaunch("coalescing")] +void node06(NodeOutput output) +{ + ThreadNodeOutputRecords outrec = output.GetThreadNodeOutputRecords(7); + buf0[0] = outrec[2].b; +} + +// CHECK: OpFunction +// CHECK: [[PAY:%[^ ]*]] = OpAllocateNodePayloadsAMDX %{{[^ ]*}} [[U4]] [[U7]] [[U0]] +// CHECK: [[TEMP:%[^ ]*]] = OpLoad %{{[^ ]*}} [[PAY]] +// CHECK: OpStore [[OUT:%[^ ]*]] [[TEMP]] +// CHECK: [[PTR:%[^ ]*]] = OpAccessChain %{{[^ ]*}} [[OUT]] [[U2]] [[S1]] +// CHECK-DAG: [[VAL:%[^ ]*]] = OpLoad [[UINT]] [[PTR]] +// CHECK-DAG: [[IMG:%[^ ]*]] = OpLoad [[TBI]] [[BUF0]] +// CHECK: OpImageWrite [[IMG]] [[U0]] [[VAL]] +// CHECK: OpFunctionEnd + +[Shader("node")] +[NumThreads(1,1,1)] +[NodeLaunch("coalescing")] +void node07(NodeOutput output) +{ + GroupNodeOutputRecords outrec = output.GetGroupNodeOutputRecords(1); + buf0[0] = outrec.Get().c; +} + +// CHECK: OpFunction +// CHECK: [[PAY:%[^ ]*]] = OpAllocateNodePayloadsAMDX %{{[^ ]*}} [[U2]] [[U1]] [[U0]] +// CHECK: [[TEMP:%[^ ]*]] = OpLoad %{{[^ ]*}} [[PAY]] +// CHECK: OpStore [[OUT:%[^ ]*]] [[TEMP]] +// CHECK: [[PTR1:%[^ ]*]] = OpAccessChain %{{[^ ]*}} [[OUT]] [[U0]] +// CHECK: [[PTR2:%[^ ]*]] = OpAccessChain %{{[^ ]*}} [[PTR1]] [[S2]] +// CHECK-DAG: [[VAL:%[^ ]*]] = 
OpLoad [[UINT]] [[PTR2]] +// CHECK-DAG: [[IMG:%[^ ]*]] = OpLoad [[TBI]] [[BUF0]] +// CHECK: OpImageWrite [[IMG]] [[U0]] [[VAL]] +// CHECK: OpFunctionEnd diff --git a/tools/clang/test/CodeGenSPIRV/node.member.read.types.hlsl b/tools/clang/test/CodeGenSPIRV/node.member.read.types.hlsl new file mode 100644 index 0000000000..5f7d434bd2 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/node.member.read.types.hlsl @@ -0,0 +1,193 @@ +// RUN: %dxc -spirv -Vd -Od -T lib_6_8 -fspv-target-env=vulkan1.3 -enable-16bit-types %s | FileCheck %s +// Note: validation disabled until NodePayloadAMDX pointers are allowed +// as function arguments + +// Read access of members of input/output record with different type +// sizes - we check the function specializations generated + +RWBuffer buf0; + +struct RECORD +{ + half h; + float f; + double d; + bool b; + uint16_t i16; + int i; + int64_t i64; + uint64_t u64; +}; + +// CHECK: OpName [[BUF0:%[^ ]*]] "buf0" +// CHECK-DAG: OpName [[RECORD:%[^ ]*]] "RECORD" +// CHECK-DAG: OpMemberName [[RECORD]] 0 "h" +// CHECK-DAG: OpMemberName [[RECORD]] 1 "f" +// CHECK-DAG: OpMemberName [[RECORD]] 2 "d" +// CHECK-DAG: OpMemberName [[RECORD]] 3 "b" +// CHECK-DAG: OpMemberName [[RECORD]] 4 "i16" +// CHECK-DAG: OpMemberName [[RECORD]] 5 "i" +// CHECK-DAG: OpMemberName [[RECORD]] 6 "i64" +// CHECK-DAG: OpMemberName [[RECORD]] 7 "u64" + +// CHECK-DAG: [[UINT:%[^ ]*]] = OpTypeInt 32 0 +// CHECK-DAG: [[INT:%[^ ]*]] = OpTypeInt 32 1 +// CHECK-DAG: [[S0:%[^ ]*]] = OpConstant [[INT]] 0 +// CHECK-DAG: [[S1:%[^ ]*]] = OpConstant [[INT]] 1 +// CHECK-DAG: [[S2:%[^ ]*]] = OpConstant [[INT]] 2 +// CHECK-DAG: [[S3:%[^ ]*]] = OpConstant [[INT]] 3 +// CHECK-DAG: [[S4:%[^ ]*]] = OpConstant [[INT]] 4 +// CHECK-DAG: [[S5:%[^ ]*]] = OpConstant [[INT]] 5 +// CHECK-DAG: [[S6:%[^ ]*]] = OpConstant [[INT]] 6 +// CHECK-DAG: [[S7:%[^ ]*]] = OpConstant [[INT]] 7 +// CHECK-DAG: [[U0:%[^ ]*]] = OpConstant [[UINT]] 0 +// CHECK-DAG: [[U1:%[^ ]*]] = OpConstant [[UINT]] 1 +// CHECK-DAG: 
[[TBI:%[^ ]*]] = OpTypeImage [[UINT]] Buffer + +// CHECK-DAG: [[HALF:%[^ ]*]] = OpTypeFloat 16 +// CHECK-DAG: [[FLOAT:%[^ ]*]] = OpTypeFloat 32 +// CHECK-DAG: [[DOUBLE:%[^ ]*]] = OpTypeFloat 64 +// CHECK-DAG: [[USHORT:%[^ ]*]] = OpTypeInt 16 0 +// CHECK-DAG: [[LONG:%[^ ]*]] = OpTypeInt 64 1 +// CHECK-DAG: [[ULONG:%[^ ]*]] = OpTypeInt 64 0 +// CHECK: [[RECORD]] = OpTypeStruct [[HALF]] [[FLOAT]] [[DOUBLE]] [[UINT]] [[USHORT]] [[INT]] [[LONG]] [[ULONG]] +// CHECK: [[BOOL:%[^ ]*]] = OpTypeBool + +[Shader("node")] +[NumThreads(1024,1,1)] +[NodeLaunch("broadcasting")] +[NodeDispatchGrid(64,1,1)] +void node01(DispatchNodeInputRecord input) +{ + buf0[0] = input.Get().h; +} + +// CHECK: OpFunction +// CHECK: [[PTR:%[^ ]*]] = OpAccessChain %{{[^ ]*}} %{{[^ ]*}} [[S0]] +// CHECK: [[VAL0:%[^ ]*]] = OpLoad [[HALF]] [[PTR]] +// CHECK: [[VAL1:%[^ ]*]] = OpConvertFToU [[UINT]] [[VAL0]] +// CHECK: [[VAL2:%[^ ]*]] = OpLoad [[TBI]] [[BUF0]] +// CHECK: OpImageWrite [[VAL2]] [[U0]] [[VAL1]] None +// CHECK: OpFunctionEnd + +[Shader("node")] +[NumThreads(1024,1,1)] +[NodeLaunch("broadcasting")] +[NodeDispatchGrid(64,1,1)] +void node02(DispatchNodeInputRecord input) +{ + buf0[0] = input.Get().f; +} + +// CHECK: OpFunction +// CHECK: [[PTR:%[^ ]*]] = OpAccessChain %{{[^ ]*}} %{{[^ ]*}} [[S1]] +// CHECK: [[VAL0:%[^ ]*]] = OpLoad [[FLOAT]] [[PTR]] +// CHECK: [[VAL1:%[^ ]*]] = OpConvertFToU [[UINT]] [[VAL0]] +// CHECK: [[VAL2:%[^ ]*]] = OpLoad [[TBI]] [[BUF0]] +// CHECK: OpImageWrite [[VAL2]] [[U0]] [[VAL1]] None +// CHECK: OpFunctionEnd + +[Shader("node")] +[NumThreads(1024,1,1)] +[NodeLaunch("broadcasting")] +[NodeDispatchGrid(64,1,1)] +void node03(DispatchNodeInputRecord input) +{ + buf0[0] = input.Get().d; +} + +// CHECK: OpFunction +// CHECK: [[PTR:%[^ ]*]] = OpAccessChain %{{[^ ]*}} %{{[^ ]*}} [[S2]] +// CHECK: [[VAL0:%[^ ]*]] = OpLoad [[DOUBLE]] [[PTR]] +// CHECK: [[VAL1:%[^ ]*]] = OpConvertFToU [[UINT]] [[VAL0]] +// CHECK: [[VAL2:%[^ ]*]] = OpLoad [[TBI]] [[BUF0]] +// CHECK: 
OpImageWrite [[VAL2]] [[U0]] [[VAL1]] None +// CHECK: OpFunctionEnd + +[Shader("node")] +[NumThreads(1024,1,1)] +[NodeLaunch("broadcasting")] +[NodeDispatchGrid(64,1,1)] +void node04(DispatchNodeInputRecord input) +{ + buf0[0] = input.Get().b; +} + +// CHECK: OpFunction +// CHECK: [[PTR:%[^ ]*]] = OpAccessChain %{{[^ ]*}} %{{[^ ]*}} [[S3]] +// CHECK: [[VAL0:%[^ ]*]] = OpLoad [[UINT]] [[PTR]] +// CHECK: [[VAL1:%[^ ]*]] = OpINotEqual [[BOOL]] [[VAL0]] [[U0]] +// CHECK: [[VAL2:%[^ ]*]] = OpSelect [[UINT]] [[VAL1]] [[U1]] [[U0]] +// CHECK: [[VAL3:%[^ ]*]] = OpLoad [[TBI]] [[BUF0]] +// CHECK: OpImageWrite [[VAL3]] [[U0]] [[VAL2]] None +// CHECK: OpFunctionEnd + +[Shader("node")] +[NumThreads(1024,1,1)] +[NodeLaunch("broadcasting")] +[NodeDispatchGrid(64,1,1)] +void node05(DispatchNodeInputRecord input) +{ + buf0[0] = input.Get().i16; +} + +// CHECK: OpFunction +// CHECK: [[PTR:%[^ ]*]] = OpAccessChain %{{[^ ]*}} %{{[^ ]*}} [[S4]] +// CHECK: [[VAL0:%[^ ]*]] = OpLoad [[USHORT]] [[PTR]] +// CHECK: [[VAL1:%[^ ]*]] = OpUConvert [[UINT]] [[VAL0]] +// CHECK: [[VAL2:%[^ ]*]] = OpLoad [[TBI]] [[BUF0]] +// CHECK: OpImageWrite [[VAL2]] [[U0]] [[VAL1]] None +// CHECK: OpFunctionEnd + +[Shader("node")] +[NumThreads(1024,1,1)] +[NodeLaunch("broadcasting")] +[NodeDispatchGrid(64,1,1)] +void node06(DispatchNodeInputRecord input) +{ + buf0[0] = input.Get().i; +} + +// CHECK: OpFunction +// CHECK: [[PTR:%[^ ]*]] = OpAccessChain %{{[^ ]*}} %{{[^ ]*}} [[S5]] +// CHECK: [[VAL0:%[^ ]*]] = OpLoad [[INT]] [[PTR]] +// CHECK: [[VAL1:%[^ ]*]] = OpBitcast [[UINT]] [[VAL0]] +// CHECK: [[VAL2:%[^ ]*]] = OpLoad [[TBI]] [[BUF0]] +// CHECK: OpImageWrite [[VAL2]] [[U0]] [[VAL1]] None +// CHECK: OpFunctionEnd + +[Shader("node")] +[NumThreads(1024,1,1)] +[NodeLaunch("broadcasting")] +[NodeDispatchGrid(64,1,1)] +void node07(DispatchNodeInputRecord input) +{ + buf0[0] = input.Get().i64; +} + +// CHECK: OpFunction +// CHECK: [[PTR:%[^ ]*]] = OpAccessChain %{{[^ ]*}} %{{[^ ]*}} [[S6]] +// CHECK: [[VAL0:%[^ 
]*]] = OpLoad [[LONG]] [[PTR]] +// CHECK: [[VAL1:%[^ ]*]] = OpSConvert [[INT]] [[VAL0]] +// CHECK: [[VAL2:%[^ ]*]] = OpBitcast [[UINT]] [[VAL1]] +// CHECK: [[VAL3:%[^ ]*]] = OpLoad [[TBI]] [[BUF0]] +// CHECK: OpImageWrite [[VAL3]] [[U0]] [[VAL2]] None +// CHECK: OpFunctionEnd + +[Shader("node")] +[NumThreads(1024,1,1)] +[NodeLaunch("broadcasting")] +[NodeDispatchGrid(64,1,1)] +void node08(DispatchNodeInputRecord input) +{ + buf0[0] = input.Get().u64; +} + +// CHECK: OpFunction +// CHECK: [[PTR:%[^ ]*]] = OpAccessChain %{{[^ ]*}} %{{[^ ]*}} [[S7]] +// CHECK: [[VAL0:%[^ ]*]] = OpLoad [[ULONG]] [[PTR]] +// CHECK: [[VAL1:%[^ ]*]] = OpUConvert [[UINT]] [[VAL0]] +// CHECK: [[VAL2:%[^ ]*]] = OpLoad [[TBI]] [[BUF0]] +// CHECK: OpImageWrite [[VAL2]] [[U0]] [[VAL1]] None +// CHECK: OpFunctionEnd + diff --git a/tools/clang/test/CodeGenSPIRV/node.member.write.hlsl b/tools/clang/test/CodeGenSPIRV/node.member.write.hlsl new file mode 100644 index 0000000000..33fc2dd9ff --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/node.member.write.hlsl @@ -0,0 +1,88 @@ +// RUN: %dxc -spirv -Vd -Od -T lib_6_8 -fspv-target-env=vulkan1.3 %s | FileCheck %s +// Note: validation disabled until NodePayloadAMDX pointers are allowed +// as function arguments + +// Writes to members of the various read-write node records + +struct RECORD +{ + uint a; + uint b; +}; + +// CHECK-DAG: [[UINT:%[^ ]*]] = OpTypeInt 32 0 +// CHECK-DAG: [[INT:%[^ ]*]] = OpTypeInt 32 1 +// CHECK-DAG: [[U0:%[^ ]*]] = OpConstant [[UINT]] 0 +// CHECK-DAG: [[S0:%[^ ]*]] = OpConstant [[INT]] 0 +// CHECK-DAG: [[U1:%[^ ]*]] = OpConstant [[UINT]] 1 +// CHECK-DAG: [[S1:%[^ ]*]] = OpConstant [[INT]] 1 +// CHECK-DAG: [[U2:%[^ ]*]] = OpConstant [[UINT]] 2 +// CHECK-DAG: [[U4:%[^ ]*]] = OpConstant [[UINT]] 4 +// CHECK-DAG: [[U5:%[^ ]*]] = OpConstant [[UINT]] 5 +// CHECK-DAG: [[U7:%[^ ]*]] = OpConstant [[UINT]] 7 +// CHECK-DAG: [[U8:%[^ ]*]] = OpConstant [[UINT]] 8 +// CHECK-DAG: [[U9:%[^ ]*]] = OpConstant [[UINT]] 9 +// CHECK-DAG: 
[[U11:%[^ ]*]] = OpConstant [[UINT]] 11 + +[Shader("node")] +[NumThreads(1024,1,1)] +[NodeLaunch("broadcasting")] +[NodeDispatchGrid(64,1,1)] +void node01(RWDispatchNodeInputRecord input1) +{ + input1.Get().a = 5; +} + +// CHECK: OpFunction +// CHECK: [[PTR:%[^ ]*]] = OpAccessChain %{{[^ ]*}} %{{[^ ]*}} [[S0]] +// CHECK: OpStore [[PTR]] [[U5]] +// CHECK: OpFunctionEnd + +[Shader("node")] +[NumThreads(2,1,1)] +[NodeLaunch("coalescing")] +void node02([MaxRecords(4)] RWGroupNodeInputRecords input2) +{ + input2[1].b = 7; +} + +// CHECK: OpFunction +// CHECK: [[PTR:%[^ ]*]] = OpAccessChain %{{[^ ]*}} %{{[^ ]*}} [[U1]] [[S1]] +// CHECK: OpStore [[PTR]] [[U7]] +// CHECK: OpFunctionEnd + +[Shader("node")] +[NumThreads(3,1,1)] +[NodeLaunch("coalescing")] +void node03(NodeOutput output) +{ + ThreadNodeOutputRecords output3 = output.GetThreadNodeOutputRecords(2); + output3.Get().b = 9; +} + +// CHECK: OpFunction +// CHECK: [[PAY:%[^ ]*]] = OpAllocateNodePayloadsAMDX %{{[^ ]*}} [[U4]] [[U2]] [[U0]] +// CHECK: [[VAL:%[^ ]*]] = OpLoad %{{[^ ]*}} [[PAY]] +// CHECK: OpStore [[OUT:%[^ ]*]] [[VAL]] +// CHECK: [[PTR0:%[^ ]*]] = OpAccessChain %{{[^ ]*}} [[OUT]] [[U0]] +// CHECK: [[PTR1:%[^ ]*]] = OpAccessChain %{{[^ ]*}} [[PTR0]] [[S1]] +// CHECK: OpStore [[PTR1]] [[U9]] +// CHECK: OpFunctionEnd + +[Shader("node")] +[NumThreads(4,1,1)] +[NodeLaunch("coalescing")] +void node04(NodeOutput output) +{ + GroupNodeOutputRecords output4 = output.GetGroupNodeOutputRecords(8); + output4[0].a = 11; +} + +// CHECK: OpFunction +// CHECK: [[PAY:%[^ ]*]] = OpAllocateNodePayloadsAMDX %{{[^ ]*}} [[U2]] [[U8]] [[U0]] +// CHECK: [[VAL:%[^ ]*]] = OpLoad %{{[^ ]*}} [[PAY]] +// CHECK: OpStore [[OUT:%[^ ]*]] [[VAL]] +// CHECK: [[PTR:%[^ ]*]] = OpAccessChain %{{[^ ]*}} [[OUT]] [[U0]] [[S0]] +// CHECK: OpStore [[PTR]] [[U11]] +// CHECK: OpFunctionEnd + diff --git a/tools/clang/test/CodeGenSPIRV/node.member.write.matrix.hlsl b/tools/clang/test/CodeGenSPIRV/node.member.write.matrix.hlsl new file mode 100644 
index 0000000000..d875f27d4e --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/node.member.write.matrix.hlsl @@ -0,0 +1,123 @@ +// RUN: %dxc -spirv -Vd -Od -T lib_6_8 -fspv-target-env=vulkan1.3 %s | FileCheck %s +// Note: validation disabled until NodePayloadAMDX pointers are allowed +// as function arguments +// ================================================================== +// Test writing to matrix members of node records +// ================================================================== + +// CHECK: OpName [[NODE01:%[^ ]*]] "node01" +// CHECK: OpName [[INPUT1:%[^ ]*]] "input1" +// CHECK: OpName [[NODE02:%[^ ]*]] "node02" +// CHECK: OpName [[INPUT2:%[^ ]*]] "input2" +// CHECK: OpName [[NODE03:%[^ ]*]] "node03" +// CHECK: OpName [[OUTPUT3:%[^ ]*]] "output3" +// CHECK: OpName [[NODE04:%[^ ]*]] "node04" +// CHECK: OpName [[OUTPUTS4:%[^ ]*]] "outputs4" + +struct RECORD +{ + row_major float2x2 m0; + row_major float2x2 m1; + column_major float2x2 m2; +}; + +// CHECK-DAG: [[UINT:%[^ ]*]] = OpTypeInt 32 0 +// CHECK-DAG: [[U64:%[^ ]*]] = OpConstant [[UINT]] 64 +// CHECK-DAG: [[U1:%[^ ]*]] = OpConstant [[UINT]] 1 +// CHECK-DAG: [[FLOAT:%[^ ]*]] = OpTypeFloat 32 +// CHECK-DAG: [[F111:%[^ ]*]] = OpConstant [[FLOAT]] 111 +// CHECK-DAG: [[V2FLOAT:%[^ ]*]] = OpTypeVector [[FLOAT]] 2 +// CHECK-DAG: [[C1:%[^ ]*]] = OpConstantComposite [[V2FLOAT]] [[F111]] [[F111]] +// CHECK-DAG: [[MAT2V2FLOAT:[^ ]*]] = OpTypeMatrix [[V2FLOAT]] 2 +// CHECK-DAG: [[M1:%[^ ]*]] = OpConstantComposite [[MAT2V2FLOAT]] [[C1]] [[C1]] +// CHECK-DAG: [[INT:%[^ ]*]] = OpTypeInt 32 1 +// CHECK-DAG: [[I1:%[^ ]*]] = OpConstant [[INT]] 1 +// CHECK-DAG: [[I0:%[^ ]*]] = OpConstant [[INT]] 0 +// CHECK-DAG: [[I2:%[^ ]*]] = OpConstant [[INT]] 2 +// CHECK-DAG: [[U0:%[^ ]*]] = OpConstant [[UINT]] 0 +// CHECK-DAG: [[F222:%[^ ]*]] = OpConstant [[FLOAT]] 222 +// CHECK-DAG: [[C2:%[^ ]*]] = OpConstantComposite [[V2FLOAT]] [[F222]] [[F222]] +// CHECK-DAG: [[M2:%[^ ]*]] = OpConstantComposite [[MAT2V2FLOAT]] [[C2]] 
[[C2]] +// CHECK-DAG: [[U4:%[^ ]*]] = OpConstant [[UINT]] 4 +// CHECK-DAG: [[U2:%[^ ]*]] = OpConstant [[UINT]] 2 + +[Shader("node")] +[NumThreads(1024,1,1)] +[NodeLaunch("broadcasting")] +[NodeDispatchGrid(64,1,1)] +void node01(RWDispatchNodeInputRecord<RECORD> input1) +{ + // CHECK: [[NODE01]] = OpFunction + // CHECK: [[P0:%[^ ]*]] = OpAccessChain %{{[^ ]*}} [[INPUT1]] [[U0]] + // CHECK: [[P1:%[^ ]*]] = OpAccessChain %{{[^ ]*}} [[P0]] [[I1]] + // CHECK: OpStore [[P1]] [[M1]] + // CHECK: [[P0:%[^ ]*]] = OpAccessChain %{{[^ ]*}} [[INPUT1]] [[U0]] + // CHECK: [[P2:%[^ ]*]] = OpAccessChain %{{[^ ]*}} [[P0]] [[I0]] + // CHECK: [[VAL:%[^ ]*]] = OpLoad [[MAT2V2FLOAT]] [[P2]] + // CHECK: [[P0:%[^ ]*]] = OpAccessChain %{{[^ ]*}} [[INPUT1]] [[U0]] + // CHECK: [[P3:%[^ ]*]] = OpAccessChain %{{[^ ]*}} [[P0]] [[I2]] + // CHECK: OpStore [[P3]] [[VAL]] + // CHECK: OpFunctionEnd + input1.Get().m1 = 111; + input1.Get().m2 = input1.Get().m0; +} + +[Shader("node")] +[NumThreads(1,1,1)] +[NodeLaunch("coalescing")] +void node02([MaxRecords(4)] RWGroupNodeInputRecords<RECORD> input2) +{ + // CHECK: [[NODE02]] = OpFunction + // CHECK: [[P1:%[^ ]*]] = OpAccessChain %{{[^ ]*}} [[INPUT2]] [[U0]] [[I1]] + // CHECK: OpStore [[P1]] [[M1]] + // CHECK: [[P2:%[^ ]*]] = OpAccessChain %{{[^ ]*}} [[INPUT2]] [[U1]] [[I0]] + // CHECK: [[VAL:%[^ ]*]] = OpLoad [[MAT2V2FLOAT]] [[P2]] + // CHECK: [[P3:%[^ ]*]] = OpAccessChain %{{[^ ]*}} [[INPUT2]] [[U1]] [[I2]] + // CHECK: OpStore [[P3]] [[VAL]] + // CHECK: OpFunctionEnd + input2[0].m1 = 111; + input2[1].m2 = input2[1].m0; +} + +[Shader("node")] +[NumThreads(1024,1,1)] +[NodeDispatchGrid(64,1,1)] +[NodeLaunch("broadcasting")] +void node03(NodeOutput<RECORD> output3) +{ + // CHECK: [[NODE03]] = OpFunction + // CHECK: [[PAY:%[^ ]*]] = OpAllocateNodePayloadsAMDX %{{[^ ]*}} [[U4]] [[U1]] [[U0]] + // CHECK: [[VAL:%[^ ]*]] = OpLoad %{{[^ ]*}} [[PAY]] + // CHECK: OpStore [[OUTREC3:%[^ ]*]] [[VAL]] + // CHECK: [[P0:%[^ ]*]] = OpAccessChain %{{[^ ]*}} [[OUTREC3]] [[U0]] + // CHECK:
[[P1:%[^ ]*]] = OpAccessChain %{{[^ ]*}} [[P0]] [[I1]] + // CHECK: OpStore [[P1]] [[M1]] + // CHECK: [[P0:%[^ ]*]] = OpAccessChain %{{[^ ]*}} [[OUTREC3]] [[U0]] + // CHECK: [[P2:%[^ ]*]] = OpAccessChain %{{[^ ]*}} [[P0]] [[I2]] + // CHECK: OpStore [[P2]] [[M2]] + // CHECK: OpFunctionEnd + ThreadNodeOutputRecords<RECORD> outrec = output3.GetThreadNodeOutputRecords(1); + outrec.Get().m1 = 111; + outrec.Get().m2 = 222; +} + +[Shader("node")] +[NumThreads(1024,1,1)] +[NodeLaunch("coalescing")] +void node04([MaxRecords(5)] NodeOutput<RECORD> outputs4) +{ + // CHECK: [[NODE04]] = OpFunction + // CHECK: [[PAY:%[^ ]*]] = OpAllocateNodePayloadsAMDX %{{[^ ]*}} [[U2]] [[U1]] [[U0]] + // CHECK: [[VAL:%[^ ]*]] = OpLoad %{{[^ ]*}} [[PAY]] + // CHECK: OpStore [[OUTREC4:%[^ ]*]] [[VAL]] + // CHECK: [[P0:%[^ ]*]] = OpAccessChain %{{[^ ]*}} [[OUTREC4]] [[U0]] + // CHECK: [[P1:%[^ ]*]] = OpAccessChain %{{[^ ]*}} [[P0]] [[I1]] + // CHECK: OpStore [[P1]] [[M1]] + // CHECK: [[P0:%[^ ]*]] = OpAccessChain %{{[^ ]*}} [[OUTREC4]] [[U0]] + // CHECK: [[P2:%[^ ]*]] = OpAccessChain %{{[^ ]*}} [[P0]] [[I2]] + // CHECK: OpStore [[P2]] [[M2]] + // CHECK: OpFunctionEnd + GroupNodeOutputRecords<RECORD> outrec = outputs4.GetGroupNodeOutputRecords(1); + outrec.Get().m1 = 111; + outrec.Get().m2 = 222; +} diff --git a/tools/clang/test/CodeGenSPIRV/node.member.write.types.hlsl b/tools/clang/test/CodeGenSPIRV/node.member.write.types.hlsl new file mode 100644 index 0000000000..ec95c3d758 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/node.member.write.types.hlsl @@ -0,0 +1,150 @@ +// RUN: %dxc -spirv -T lib_6_8 -fspv-target-env=vulkan1.3 -enable-16bit-types %s | FileCheck %s + +// Writes to node record members of various types + + +struct RECORD +{ + half h; + float f; + double d; + bool b; + int16_t i16; + uint16_t u16; + int i; + int64_t i64; + uint64_t u64; + float3 f3; + int ia[7]; +}; + +// CHECK: OpName [[RECORD:%[^ ]*]] "RECORD" +// CHECK: OpMemberName [[RECORD]] 0 "h" +// CHECK: OpMemberName [[RECORD]] 1 "f" +// CHECK:
OpMemberName [[RECORD]] 2 "d" +// CHECK: OpMemberName [[RECORD]] 3 "b" +// CHECK: OpMemberName [[RECORD]] 4 "i16" +// CHECK: OpMemberName [[RECORD]] 5 "u16" +// CHECK: OpMemberName [[RECORD]] 6 "i" +// CHECK: OpMemberName [[RECORD]] 7 "i64" +// CHECK: OpMemberName [[RECORD]] 8 "u64" +// CHECK: OpMemberName [[RECORD]] 9 "f3" +// CHECK: OpMemberName [[RECORD]] 10 "ia" + +// CHECK-DAG: [[UINT:%[^ ]*]] = OpTypeInt 32 0 +// CHECK-DAG: [[HALF:%[^ ]*]] = OpTypeFloat 16 +// CHECK-DAG: [[INT:%[^ ]*]] = OpTypeInt 32 1 +// CHECK-DAG: [[FLOAT:%[^ ]*]] = OpTypeFloat 32 +// CHECK-DAG: [[DOUBLE:%[^ ]*]] = OpTypeFloat 64 +// CHECK-DAG: [[SHORT:%[^ ]*]] = OpTypeInt 16 1 +// CHECK-DAG: [[USHORT:%[^ ]*]] = OpTypeInt 16 0 +// CHECK-DAG: [[LONG:%[^ ]*]] = OpTypeInt 64 1 +// CHECK-DAG: [[ULONG:%[^ ]*]] = OpTypeInt 64 0 +// CHECK-DAG: [[V3FLOAT:%[^ ]*]] = OpTypeVector [[FLOAT]] 3 + +// CHECK-DAG: [[U0:%[^ ]*]] = OpConstant [[UINT]] 0 +// CHECK-DAG: [[U1:%[^ ]*]] = OpConstant [[UINT]] 1 +// CHECK-DAG: [[HALF_0X1_8P_1:%[^ ]*]] = OpConstant [[HALF]] 0x1.8p+1 +// CHECK-DAG: [[I0:%[^ ]*]] = OpConstant [[INT]] 0 +// CHECK-DAG: [[FN5:%[^ ]*]] = OpConstant [[FLOAT]] -5 +// CHECK-DAG: [[I1:%[^ ]*]] = OpConstant [[INT]] 1 +// CHECK-DAG: [[D7:%[^ ]*]] = OpConstant [[DOUBLE]] 7 +// CHECK-DAG: [[I2:%[^ ]*]] = OpConstant [[INT]] 2 +// CHECK-DAG: [[I3:%[^ ]*]] = OpConstant [[INT]] 3 +// CHECK-DAG: [[S11:%[^ ]*]] = OpConstant [[SHORT]] 11 +// CHECK-DAG: [[I4:%[^ ]*]] = OpConstant [[INT]] 4 +// CHECK-DAG: [[US13:%[^ ]*]] = OpConstant [[USHORT]] 13 +// CHECK-DAG: [[I5:%[^ ]*]] = OpConstant [[INT]] 5 +// CHECK-DAG: [[I17:%[^ ]*]] = OpConstant [[INT]] 17 +// CHECK-DAG: [[I6:%[^ ]*]] = OpConstant [[INT]] 6 +// CHECK-DAG: [[LN19:%[^ ]*]] = OpConstant [[LONG]] -19 +// CHECK-DAG: [[I7:%[^ ]*]] = OpConstant [[INT]] 7 +// CHECK-DAG: [[UL21:%[^ ]*]] = OpConstant [[ULONG]] 21 +// CHECK-DAG: [[I8:%[^ ]*]] = OpConstant [[INT]] 8 +// CHECK-DAG: [[F23:%[^ ]*]] = OpConstant [[FLOAT]] 23 +// CHECK-DAG: [[I9:%[^ ]*]] = 
OpConstant [[INT]] 9 +// CHECK-DAG: [[I29:%[^ ]*]] = OpConstant [[INT]] 29 +// CHECK-DAG: [[I10:%[^ ]*]] = OpConstant [[INT]] 10 +// CHECK-DAG: [[U7:%[^ ]*]] = OpConstant [[UINT]] 7 + +// CHECK-DAG: [[AI7:%[^ ]*]] = OpTypeArray [[INT]] [[U7]] +// CHECK-DAG: [[RECORD]] = OpTypeStruct [[HALF]] [[FLOAT]] [[DOUBLE]] [[UINT]] [[SHORT]] [[USHORT]] [[INT]] [[LONG]] [[ULONG]] [[V3FLOAT]] [[AI7]] +// CHECK-DAG: [[RAR:%[^ ]*]] = OpTypeNodePayloadArrayAMDX [[RECORD]] +// CHECK-DAG: [[RARP:%[^ ]*]] = OpTypePointer NodePayloadAMDX [[RAR]] +// CHECK-DAG: [[U2:%[^ ]*]] = OpConstant [[UINT]] 2 +// CHECK-DAG: [[HALFP:%[^ ]*]] = OpTypePointer Function [[HALF]] +// CHECK-DAG: [[FLOATP:%[^ ]*]] = OpTypePointer Function [[FLOAT]] +// CHECK-DAG: [[DOUBLEP:%[^ ]*]] = OpTypePointer Function [[DOUBLE]] +// CHECK-DAG: [[UINTP:%[^ ]*]] = OpTypePointer Function [[UINT]] +// CHECK-DAG: [[SHORTP:%[^ ]*]] = OpTypePointer Function [[SHORT]] +// CHECK-DAG: [[USHORTP:%[^ ]*]] = OpTypePointer Function [[USHORT]] +// CHECK-DAG: [[INTP:%[^ ]*]] = OpTypePointer Function [[INT]] +// CHECK-DAG: [[LONGP:%[^ ]*]] = OpTypePointer Function [[LONG]] +// CHECK-DAG: [[ULONGP:%[^ ]*]] = OpTypePointer Function [[ULONG]] + +[Shader("node")] +[NumThreads(1024,1,1)] +[NodeLaunch("broadcasting")] +[NodeDispatchGrid(512,1,1)] +void node125(NodeOutput<RECORD> output) +{ + GroupNodeOutputRecords<RECORD> output01 = output.GetGroupNodeOutputRecords(1); + // CHECK: OpAllocateNodePayloadsAMDX [[RARP]] [[U2]] [[U1]] [[U0]] + + output01.Get().h = 3.0; + // CHECK: [[PTR:%[^ ]*]] = OpAccessChain [[HALFP]] + // CHECK-SAME: [[I0]] + // CHECK: OpStore [[PTR]] [[HALF_0X1_8P_1]] + + output01.Get().f = -5.0; + // CHECK: [[PTR:%[^ ]*]] = OpAccessChain [[FLOATP]] + // CHECK-SAME: [[I1]] + // CHECK: OpStore [[PTR]] [[FN5]] + + output01.Get().d = 7.0; + // CHECK: [[PTR:%[^ ]*]] = OpAccessChain [[DOUBLEP]] + // CHECK-SAME: [[I2]] + // CHECK: OpStore [[PTR]] [[D7]] + + output01.Get().b = true; + // CHECK: [[PTR:%[^ ]*]] = OpAccessChain [[UINTP]] + //
CHECK-SAME: [[I3]] + // CHECK: OpStore [[PTR]] [[U1]] + + output01.Get().i16 = 11; + // CHECK: [[PTR:%[^ ]*]] = OpAccessChain [[SHORTP]] + // CHECK-SAME: [[I4]] + // CHECK: OpStore [[PTR]] [[S11]] + + output01.Get().u16 = 13; + // CHECK: [[PTR:%[^ ]*]] = OpAccessChain [[USHORTP]] + // CHECK-SAME: [[I5]] + // CHECK: OpStore [[PTR]] [[US13]] + + output01.Get().i = 17; + // CHECK: [[PTR:%[^ ]*]] = OpAccessChain [[INTP]] + // CHECK-SAME: [[I6]] + // CHECK: OpStore [[PTR]] [[I17]] + + output01.Get().i64 = -19; + // CHECK: [[PTR:%[^ ]*]] = OpAccessChain [[LONGP]] + // CHECK-SAME: [[I7]] + // CHECK: OpStore [[PTR]] [[LN19]] + + output01.Get().u64 = 21; + // CHECK: [[PTR:%[^ ]*]] = OpAccessChain [[ULONGP]] + // CHECK-SAME: [[I8]] + // CHECK: OpStore [[PTR]] [[UL21]] + + output01.Get().f3.y = 23; + // CHECK: [[PTR:%[^ ]*]] = OpAccessChain [[FLOATP]] + // CHECK-SAME: [[I9]] + // CHECK-SAME: [[I1]] + // CHECK: OpStore [[PTR]] [[F23]] + + output01.Get().ia[5] = 29; + // CHECK: [[PTR:%[^ ]*]] = OpAccessChain [[INTP]] + // CHECK-SAME: [[I10]] + // CHECK-SAME: [[I5]] + // CHECK: OpStore [[PTR]] [[I29]] +} diff --git a/tools/clang/test/CodeGenSPIRV/node.mesh.hlsl b/tools/clang/test/CodeGenSPIRV/node.mesh.hlsl new file mode 100644 index 0000000000..4d1726abb2 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/node.mesh.hlsl @@ -0,0 +1,88 @@ +// RUN: %dxc -spirv -T lib_6_9 -fspv-target-env=vulkan1.3 %s | FileCheck %s +// XFAIL: * +// disabled until mesh nodes are implemented + +// Test loading of node input and funneling into mesh outputs +// Essentially an end-to-end mesh node test. 
+ + +RWBuffer<uint> buf0; + +#define MAX_VERT 32 +#define MAX_PRIM 16 + +struct MeshPerVertex { + float4 position : SV_Position; + float color[4] : COLOR; +}; + +struct MeshPerPrimitive { + float normal : NORMAL; + float malnor : MALNOR; + float alnorm : ALNORM; + float ormaln : ORMALN; + int layer[6] : LAYER; +}; + +struct MeshPayload { + float normal; + float malnor; + float alnorm; + float ormaln; + int layer[6]; +}; + +groupshared float gsMem[MAX_PRIM]; + +[Shader("node")] +[NodeLaunch("mesh")] +[outputtopology("triangle")] +[numthreads(128, 1, 1)] +[NodeDispatchGrid(64,1,1)] +void node_setmeshoutputcounts(DispatchNodeInputRecord<MeshPayload> mpl, + out indices uint3 primIndices[MAX_PRIM], + out vertices MeshPerVertex verts[MAX_VERT], + out primitives MeshPerPrimitive prims[MAX_PRIM], + in uint tig : SV_GroupIndex) { + SetMeshOutputCounts(32, 16); + + // create mpl + + MeshPerVertex ov; + ov.position = float4(14.0,15.0,16.0,17.0); + ov.color[0] = 14.0; + ov.color[1] = 15.0; + ov.color[2] = 16.0; + ov.color[3] = 17.0; + + if (tig % 3) { + primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2); + + MeshPerPrimitive op; + op.normal = mpl.Get().normal; + op.malnor = gsMem[tig / 3 + 1]; + op.alnorm = mpl.Get().alnorm; + op.ormaln = mpl.Get().ormaln; + op.layer[0] = mpl.Get().layer[0]; + op.layer[1] = mpl.Get().layer[1]; + op.layer[2] = mpl.Get().layer[2]; + op.layer[3] = mpl.Get().layer[3]; + op.layer[4] = mpl.Get().layer[4]; + op.layer[5] = mpl.Get().layer[5]; + + gsMem[tig / 3] = op.normal; + prims[tig / 3] = op; + } + verts[tig] = ov; +} + +// CHECK: OpEntryPoint MeshExt [[ENTRY:%[^ ]*]] +// CHECK-DAG: OpExecutionMode [[ENTRY]] OutputVertices 32 +// CHECK-DAG: OpExecutionMode [[ENTRY]] OutputPrimitivesNV 16 +// CHECK-DAG: OpExecutionMode [[ENTRY]] OutputTrianglesNV +// CHECK: [[UINT:%[^ ]*]] = OpTypeInt 32 0 +// CHECK-DAG: [[U16:%[^ ]*]] = OpConstant [[UINT]] 16 +// CHECK-DAG: [[U32:%[^ ]*]] = OpConstant [[UINT]] 32 +// CHECK: [[ENTRY]] = OpFunction +// CHECK: OpSetMeshOutputsEXT
[[U32]] [[U16]] +// CHECK: OpFunctionEnd diff --git a/tools/clang/test/CodeGenSPIRV/node.output-complete.hlsl b/tools/clang/test/CodeGenSPIRV/node.output-complete.hlsl new file mode 100644 index 0000000000..17db15e7db --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/node.output-complete.hlsl @@ -0,0 +1,33 @@ +// RUN: %dxc -spirv -Od -T lib_6_8 -fspv-target-env=vulkan1.3 %s | FileCheck %s + +// OutputComplete() is called with NodeOutput + +struct OUTPUT_RECORD +{ + uint value; +}; + +[Shader("node")] +[NodeLaunch("broadcasting")] +[NodeDispatchGrid(256,1,1)] +[NumThreads(1024,1,1)] +void outputcomplete([MaxRecords(256)] NodeOutput<OUTPUT_RECORD> output) +{ + ThreadNodeOutputRecords<OUTPUT_RECORD> outputrecords = output.GetThreadNodeOutputRecords(1); + // ... + outputrecords.OutputComplete(); +} + +// CHECK: OpName [[RECORDS:%[^ ]*]] "outputrecords" +// CHECK: OpDecorateId [[ARR:%[^ ]*]] PayloadNodeNameAMDX [[STR:%[0-9A-Za-z_]*]] +// CHECK-DAG: [[UINT:%[^ ]*]] = OpTypeInt 32 0 +// CHECK-DAG: [[U1:%[^ ]*]] = OpConstant [[UINT]] 1 +// CHECK-DAG: [[U0:%[^ ]*]] = OpConstant [[UINT]] 0 +// CHECK-DAG: [[REC:%[^ ]*]] = OpTypeStruct [[UINT]] +// CHECK-DAG: [[ARR]] = OpTypeNodePayloadArrayAMDX [[REC]] +// CHECK-DAG: [[PTR:%[^ ]*]] = OpTypePointer NodePayloadAMDX [[ARR]] +// CHECK-DAG: [[U4:%[^ ]*]] = OpConstant [[UINT]] 4 +// CHECK: [[V0:%[^ ]*]] = OpAllocateNodePayloadsAMDX [[PTR]] [[U4]] [[U1]] [[U0]] +// CHECK: [[V1:%[^ ]*]] = OpLoad [[ARR]] [[V0]] +// CHECK: OpStore [[RECORDS]] [[V1]] +// CHECK: OpEnqueueNodePayloadsAMDX [[RECORDS]] diff --git a/tools/clang/test/CodeGenSPIRV/node.output.is-valid.empty.hlsl b/tools/clang/test/CodeGenSPIRV/node.output.is-valid.empty.hlsl new file mode 100644 index 0000000000..08a103cf5e --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/node.output.is-valid.empty.hlsl @@ -0,0 +1,19 @@ +// RUN: %dxc -spirv -T lib_6_8 -fspv-target-env=vulkan1.3 %s | FileCheck %s + +// NodeOutputIsValid() is called with EmptyNodeOutput + +RWBuffer<uint> buf0; + +[Shader("node")]
+[NodeLaunch("broadcasting")] +[NodeDispatchGrid(256,1,1)] +[NumThreads(1,1,1)] +void node131_nodeoutputisvalid_emptynodeoutput(EmptyNodeOutput output) +{ + buf0[0] = output.IsValid(); +} + +// CHECK: [[UINT:%[^ ]*]] = OpTypeInt 32 0 +// CHECK: [[U0:%[^ ]*]] = OpConstant [[UINT]] 0 +// CHECK: [[BOOL:%[^ ]*]] = OpTypeBool +// CHECK: OpIsNodePayloadValidAMDX [[BOOL]] %{{[^ ]*}} [[U0]] diff --git a/tools/clang/test/CodeGenSPIRV/node.output.is-valid.hlsl b/tools/clang/test/CodeGenSPIRV/node.output.is-valid.hlsl new file mode 100644 index 0000000000..40e3a74fcb --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/node.output.is-valid.hlsl @@ -0,0 +1,24 @@ +// RUN: %dxc -spirv -T lib_6_8 -fspv-target-env=vulkan1.3 %s | FileCheck %s + +// IsValid() is invoked on NodeOutput + +RWBuffer<uint> buf0; + +struct RECORD +{ + uint value; +}; + +[Shader("node")] +[NodeLaunch("broadcasting")] +[NodeDispatchGrid(256,1,1)] +[NumThreads(1,1,1)] +void node129_nodeoutputisvalid_nodeoutput(NodeOutput<RECORD> output) +{ + buf0[0] = output.IsValid(); +} + +// CHECK: [[UINT:%[^ ]*]] = OpTypeInt 32 0 +// CHECK: [[U0:%[^ ]*]] = OpConstant [[UINT]] 0 +// CHECK: [[BOOL:%[^ ]*]] = OpTypeBool +// CHECK: OpIsNodePayloadValidAMDX [[BOOL]] %{{[^ ]*}} [[U0]] diff --git a/tools/clang/test/CodeGenSPIRV/node.renamed.hlsl b/tools/clang/test/CodeGenSPIRV/node.renamed.hlsl new file mode 100644 index 0000000000..953288929d --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/node.renamed.hlsl @@ -0,0 +1,23 @@ +// RUN: %dxc -spirv -Od -T lib_6_8 -fspv-target-env=vulkan1.3 external %s | FileCheck %s + +// Renamed node, unnamed index defaults to 0 + +struct RECORD { + uint i; +}; + +[Shader("node")] +[NodeLaunch("thread")] +[NodeID("new_node_name")] +[NodeIsProgramEntry] +void node017_renamed_node([NodeID("output_node_name", 2)] NodeOutput<RECORD> r) +{ + r.GetThreadNodeOutputRecords(1); +} + +// CHECK: OpEntryPoint GLCompute %{{[^ ]*}} "node017_renamed_node" +// CHECK-DAG: OpDecorateId [[TYPE:%[^ ]*]] PayloadNodeNameAMDX
[[STR:%[0-9A-Za-z_]*]] +// CHECK-DAG: OpDecorateId [[TYPE]] PayloadNodeBaseIndexAMDX [[U2:%[0-9A-Za-z_]*]] +// CHECK: [[UINT:%[^ ]*]] = OpTypeInt 32 0 +// CHECK-DAG: [[STR]] = OpConstantStringAMDX "output_node_name" +// CHECK-DAG: [[U2]] = OpConstant [[UINT]] 2 diff --git a/tools/clang/test/CodeGenSPIRV/node.share-input.hlsl b/tools/clang/test/CodeGenSPIRV/node.share-input.hlsl new file mode 100644 index 0000000000..c439bef017 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/node.share-input.hlsl @@ -0,0 +1,42 @@ +// RUN: %dxc -spirv -T lib_6_8 -fspv-target-env=vulkan1.3 %s | FileCheck %s + +// Check that the NodeShareInputOf metadata entry is populated correctly + +struct entryRecord +{ + int data0; +}; + +[Shader("node")] +[NodeLaunch("broadcasting")] +[NodeDispatchGrid(2, 1, 1)] +[NumThreads(1, 1, 1)] +void firstNode(DispatchNodeInputRecord<entryRecord> inputData) +{ } + +[Shader("node")] +[NodeLaunch("broadcasting")] +[NodeDispatchGrid(2, 1, 1)] +[NumThreads(1, 1, 1)] +[NodeShareInputOf("firstNode")] +void secondNode(DispatchNodeInputRecord<entryRecord> inputData) +{ } + +[Shader("node")] +[NodeLaunch("broadcasting")] +[NodeDispatchGrid(2, 1, 1)] +[NumThreads(1, 1, 1)] +[NodeShareInputOf("firstNode", 3)] +void thirdNode(DispatchNodeInputRecord<entryRecord> inputData) +{ } + + +// CHECK: OpEntryPoint GLCompute %firstNode "firstNode" +// CHECK: OpEntryPoint GLCompute %secondNode "secondNode" +// CHECK: OpEntryPoint GLCompute %thirdNode "thirdNode" +// CHECK-NOT: OpExecutionModeId %firstNode SharesInputWithAMDX +// CHECK: OpExecutionModeId %secondNode SharesInputWithAMDX [[STR:%[^ ]*]] [[U0:%[^ ]*]] +// CHECK: OpExecutionModeId %thirdNode SharesInputWithAMDX [[STR]] [[U3:%[^ ]*]] +// CHECK: [[UINT:%[^ ]*]] = OpTypeInt 32 0 +// CHECK-DAG: [[U0:%[^ ]*]] = OpConstant [[UINT]] 0 +// CHECK-DAG: [[U3:%[^ ]*]] = OpConstant [[UINT]] 3 diff --git a/tools/clang/test/CodeGenSPIRV/node.sparse-nodes.hlsl b/tools/clang/test/CodeGenSPIRV/node.sparse-nodes.hlsl new file mode 100644 index 0000000000..ca3c14b8da ---
/dev/null +++ b/tools/clang/test/CodeGenSPIRV/node.sparse-nodes.hlsl @@ -0,0 +1,141 @@ +// RUN: %dxc -spirv -T lib_6_8 -fspv-target-env=vulkan1.3 %s | FileCheck %s + +struct RECORD1 +{ + uint value; + uint value2; +}; + +// CHECK: OpEntryPoint GLCompute [[NODE10:%[^ ]*]] "node_1_0" +// CHECK: OpEntryPoint GLCompute [[NODE11:%[^ ]*]] "node_1_1" +// CHECK: OpEntryPoint GLCompute [[NODE12:%[^ ]*]] "node_1_2" +// CHECK: OpEntryPoint GLCompute [[NODE20:%[^ ]*]] "node_2_0" +// CHECK: OpEntryPoint GLCompute [[NODE21:%[^ ]*]] "node_2_1" +// CHECK: OpEntryPoint GLCompute [[NODE22:%[^ ]*]] "node_2_2" +// CHECK: OpDecorateId [[A10:%[^ ]*]] PayloadNodeNameAMDX [[S10:%[^ ]*]] +// CHECK: OpDecorateId [[A10]] NodeMaxPayloadsAMDX [[U31:%[^ ]*]] +// CHECK: OpDecorate [[A10]] PayloadNodeSparseArrayAMDX +// CHECK: OpDecorateId [[A10]] PayloadNodeArraySizeAMDX [[U129:%[^ ]*]] +// CHECK: OpDecorateId [[A11:%[^ ]*]] PayloadNodeNameAMDX [[S11:%[^ ]*]] +// CHECK: OpDecorateId [[A11]] NodeMaxPayloadsAMDX [[U37:%[^ ]*]] +// CHECK: OpDecorate [[A11]] PayloadNodeSparseArrayAMDX +// CHECK: OpDecorateId [[A12:%[^ ]*]] PayloadNodeNameAMDX [[S12:%[^ ]*]] +// CHECK: OpDecorateId [[A12]] NodeMaxPayloadsAMDX [[U47:%[^ ]*]] +// CHECK: OpDecorate [[A12]] PayloadNodeSparseArrayAMDX +// CHECK: OpDecorateId [[A20:%[^ ]*]] PayloadNodeNameAMDX [[S20:%[^ ]*]] +// CHECK: OpDecorateId [[A20]] NodeMaxPayloadsAMDX [[U41:%[^ ]*]] +// CHECK: OpDecorate [[A20]] PayloadNodeSparseArrayAMDX +// CHECK: OpDecorateId [[A20]] PayloadNodeArraySizeAMDX [[U131:%[^ ]*]] +// CHECK: OpDecorateId [[A21:%[^ ]*]] PayloadNodeNameAMDX [[S21:%[^ ]*]] +// CHECK: OpDecorateId [[A21]] NodeMaxPayloadsAMDX [[U43:%[^ ]*]] +// CHECK: OpDecorate [[A21]] PayloadNodeSparseArrayAMDX +// CHECK: OpDecorateId [[A22:%[^ ]*]] PayloadNodeNameAMDX [[S22:%[^ ]*]] +// CHECK: OpDecorateId [[A22]] NodeMaxPayloadsAMDX [[U53:%[^ ]*]] +// CHECK: OpDecorate [[A22]] PayloadNodeSparseArrayAMDX +// CHECK: [[UINT:%[^ ]*]] = OpTypeInt 32 0 +// CHECK: [[U0:%[^ 
]*]] = OpConstant [[UINT]] 0 +// CHECK: [[RECORD:%[^ ]*]] = OpTypeStruct [[UINT]] [[UINT]] +// CHECK-DAG: [[A10]] = OpTypeNodePayloadArrayAMDX [[RECORD]] +// CHECK-DAG: [[S10]] = OpConstantStringAMDX "OutputArray_1_0" +// CHECK-DAG: [[U31]] = OpConstant [[UINT]] 31 +// CHECK-DAG: [[U129]] = OpConstant [[UINT]] 129 +// CHECK-DAG: [[A11]] = OpTypeNodePayloadArrayAMDX [[RECORD]] +// CHECK-DAG: [[S11]] = OpConstantStringAMDX "OutputArray_1_1" +// CHECK-DAG: [[U37]] = OpConstant [[UINT]] 37 +// CHECK-DAG: [[A12]] = OpTypeNodePayloadArrayAMDX [[RECORD]] +// CHECK-DAG: [[S12]] = OpConstantStringAMDX "Output_1_2" +// CHECK-DAG: [[U47]] = OpConstant [[UINT]] 47 +// CHECK-DAG: [[EMPTY:%[^ ]*]] = OpTypeStruct +// CHECK-DAG: [[A20]] = OpTypeNodePayloadArrayAMDX [[EMPTY]] +// CHECK-DAG: [[S20]] = OpConstantStringAMDX "OutputArray_2_0" +// CHECK-DAG: [[U41]] = OpConstant [[UINT]] 41 +// CHECK-DAG: [[U131]] = OpConstant [[UINT]] 131 +// CHECK-DAG: [[A21]] = OpTypeNodePayloadArrayAMDX [[EMPTY]] +// CHECK-DAG: [[S21]] = OpConstantStringAMDX "OutputArray_2_1" +// CHECK-DAG: [[U43]] = OpConstant [[UINT]] 43 +// CHECK-DAG: [[A22]] = OpTypeNodePayloadArrayAMDX [[EMPTY]] +// CHECK-DAG: [[S22]] = OpConstantStringAMDX "Output_2_2" +// CHECK-DAG: [[U53]] = OpConstant [[UINT]] 53 + +[Shader("node")] +[NodeLaunch("broadcasting")] +[NodeDispatchGrid(1, 1, 1)] +[NumThreads(1, 1, 1)] +void node_1_0( + [AllowSparseNodes] [NodeArraySize(129)] [MaxRecords(31)] + NodeOutputArray<RECORD1> OutputArray_1_0) { + ThreadNodeOutputRecords<RECORD1> outRec = OutputArray_1_0[1].GetThreadNodeOutputRecords(2); + outRec.OutputComplete(); +} + +// CHECK: [[NODE10]] = OpFunction %void None +// CHECK: OpFunctionEnd + +[Shader("node")] +[NodeLaunch("broadcasting")] +[NodeDispatchGrid(1, 1, 1)] +[NumThreads(1, 1, 1)] +void node_1_1( + [UnboundedSparseNodes] [MaxRecords(37)] + NodeOutputArray<RECORD1> OutputArray_1_1) { + ThreadNodeOutputRecords<RECORD1> outRec = OutputArray_1_1[1].GetThreadNodeOutputRecords(2); + outRec.OutputComplete(); +} + +//
CHECK: [[NODE11]] = OpFunction %void None +// CHECK: OpFunctionEnd + +[Shader("node")] +[NodeLaunch("broadcasting")] +[NodeDispatchGrid(1, 1, 1)] +[NumThreads(1, 1, 1)] +void node_1_2( + [AllowSparseNodes] [MaxRecords(47)] + NodeOutput<RECORD1> Output_1_2) { + ThreadNodeOutputRecords<RECORD1> outRec = Output_1_2.GetThreadNodeOutputRecords(2); + outRec.OutputComplete(); +} + +// CHECK: [[NODE12]] = OpFunction %void None +// CHECK: %{{[^ ]*}} = OpAllocateNodePayloadsAMDX %{{[^ ]*}} %{{[^ ]*}} %{{[^ ]*}} [[U0]] +// CHECK: OpFunctionEnd + +[Shader("node")] +[NodeLaunch("broadcasting")] +[NodeDispatchGrid(1, 1, 1)] +[NumThreads(1, 1, 1)] +void node_2_0( + [AllowSparseNodes] [NodeArraySize(131)] [MaxRecords(41)] + EmptyNodeOutputArray OutputArray_2_0) { + OutputArray_2_0[1].GroupIncrementOutputCount(10); +} + +// CHECK: [[NODE20]] = OpFunction %void None +// CHECK: OpFunctionEnd + +[Shader("node")] +[NodeLaunch("broadcasting")] +[NodeDispatchGrid(1, 1, 1)] +[NumThreads(1, 1, 1)] +void node_2_1( + [UnboundedSparseNodes] [MaxRecords(43)] + EmptyNodeOutputArray OutputArray_2_1) { + OutputArray_2_1[1].GroupIncrementOutputCount(10); +} + +// CHECK: [[NODE21]] = OpFunction %void None +// CHECK: OpFunctionEnd + +[Shader("node")] +[NodeLaunch("broadcasting")] +[NodeDispatchGrid(1, 1, 1)] +[NumThreads(1, 1, 1)] +void node_2_2( + [AllowSparseNodes] [MaxRecords(53)] + EmptyNodeOutput Output_2_2) { + Output_2_2.GroupIncrementOutputCount(10); +} + +// CHECK: [[NODE22]] = OpFunction %void None +// CHECK: %{{[^ ]*}} = OpAllocateNodePayloadsAMDX %{{[^ ]*}} %{{[^ ]*}} %{{[^ ]*}} [[U0]] +// CHECK: OpFunctionEnd diff --git a/tools/clang/test/CodeGenSPIRV/node.thread.num-threads.hlsl b/tools/clang/test/CodeGenSPIRV/node.thread.num-threads.hlsl new file mode 100644 index 0000000000..8732cf3478 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/node.thread.num-threads.hlsl @@ -0,0 +1,15 @@ +// RUN: %dxc -spirv -Od -T lib_6_8 -fspv-target-env=vulkan1.3 %s | FileCheck %s + +// NumThreads + +[Shader("node")]
+[NodeLaunch("thread")] +[NumThreads(1,1,1)] +[NodeIsProgramEntry] +void node010_thread_numthreads_shader() +{ +} + +// CHECK: OpEntryPoint GLCompute [[SHADER:%[0-9A-Za-z_]*]] +// CHECK: OpExecutionMode [[SHADER]] LocalSize 1 1 1 +// CHECK: OpReturn diff --git a/tools/clang/test/CodeGenSPIRV/node.thread.num-threads.none.hlsl b/tools/clang/test/CodeGenSPIRV/node.thread.num-threads.none.hlsl new file mode 100644 index 0000000000..0b230479c4 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/node.thread.num-threads.none.hlsl @@ -0,0 +1,15 @@ +// RUN: %dxc -spirv -Od -T lib_6_8 -fspv-target-env=vulkan1.3 %s | FileCheck %s + +// Thread launch node without NumThreads specified should use a +// default of (1,1,1) + +[Shader("node")] +[NodeLaunch("thread")] +[NodeIsProgramEntry] +void node011_thread_numthreads_none() +{ +} + +// CHECK: OpEntryPoint GLCompute [[SHADER:%[0-9A-Za-z_]*]] +// CHECK: OpExecutionMode [[SHADER]] LocalSize 1 1 1 +// CHECK: OpReturn diff --git a/tools/clang/test/CodeGenSPIRV/vk.attribute.image-format.hlsl b/tools/clang/test/CodeGenSPIRV/vk.attribute.image-format.hlsl index 12b03fffda..4d10dc446b 100644 --- a/tools/clang/test/CodeGenSPIRV/vk.attribute.image-format.hlsl +++ b/tools/clang/test/CodeGenSPIRV/vk.attribute.image-format.hlsl @@ -60,12 +60,6 @@ RWBuffer Buf_r64i; [[vk::image_format("r64ui")]] RWBuffer Buf_r64ui; -[[vk::image_format("r16f")]] -// CHECK: [[ImgType:%[0-9a-zA-Z_]+]] = OpTypeImage %float 2D 2 0 0 2 R16f -// CHECK: [[ArrayType:%[0-9a-zA-Z_]+]] = OpTypeRuntimeArray [[ImgType]] -// CHECK: [[PtrType:%[0-9a-zA-Z_]+]] = OpTypePointer UniformConstant [[ArrayType]] -RWTexture2D Buf_r16f_bindless[]; - struct S { RWBuffer b; }; From edc97d82833b2531f5ea61e7c0b5414c79da0908 Mon Sep 17 00:00:00 2001 From: Dan Brown Date: Mon, 2 Jun 2025 15:48:09 -0600 Subject: [PATCH 2/2] Adds experimental feature warning. 
--- tools/clang/include/clang/Basic/DiagnosticSemaKinds.td | 2 ++ tools/clang/lib/Sema/SemaHLSL.cpp | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/tools/clang/include/clang/Basic/DiagnosticSemaKinds.td b/tools/clang/include/clang/Basic/DiagnosticSemaKinds.td index 48412facad..094570775c 100644 --- a/tools/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/tools/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -8063,6 +8063,8 @@ def err_hlsl_vk_pointer_cast_alignment: Error< "Vulkan buffer pointer cannot be cast to greater alignment">; def err_hlsl_vk_static_pointer_cast_type: Error< "vk::static_pointer_cast() content type must be base class of argument's content type">; +def warn_spirv_node_shaders_experimental : Warning< + "SPIR-V implementation of node shaders is experimental and subject to change">; // SPIRV Change Ends let CategoryName = "OpenMP Issue" in { diff --git a/tools/clang/lib/Sema/SemaHLSL.cpp b/tools/clang/lib/Sema/SemaHLSL.cpp index fa59aa6ef7..41802c2f7e 100644 --- a/tools/clang/lib/Sema/SemaHLSL.cpp +++ b/tools/clang/lib/Sema/SemaHLSL.cpp @@ -17140,6 +17140,12 @@ void DiagnoseNodeEntry(Sema &S, FunctionDecl *FD, llvm::StringRef StageName, DXIL::ShaderKind shaderKind = ShaderModel::KindFromFullName(StageName); if (shaderKind == DXIL::ShaderKind::Node) { NodeLoc = pAttr->getLocation(); +#ifdef ENABLE_SPIRV_CODEGEN + // SPIR-V node shader support is experimental + if (S.getLangOpts().SPIRV) { + S.Diag(NodeLoc, diag::warn_spirv_node_shaders_experimental); + } +#endif // ENABLE_SPIRV_CODEGEN } if (NodeLoc.isInvalid()) { return;