diff --git a/include/dxc/DXIL/DxilFunctionProps.h b/include/dxc/DXIL/DxilFunctionProps.h index 425ec4e391..e0223e9dd5 100644 --- a/include/dxc/DXIL/DxilFunctionProps.h +++ b/include/dxc/DXIL/DxilFunctionProps.h @@ -117,6 +117,7 @@ struct DxilFunctionProps { memset(&Node, 0, sizeof(Node)); Node.LaunchType = DXIL::NodeLaunchType::Invalid; Node.LocalRootArgumentsTableIndex = -1; + groupSharedLimitBytes = 0; } union { // Geometry shader. @@ -174,6 +175,8 @@ struct DxilFunctionProps { // numThreads shared between multiple shader types and node shaders. unsigned numThreads[3]; + unsigned groupSharedLimitBytes; + struct NodeProps { DXIL::NodeLaunchType LaunchType = DXIL::NodeLaunchType::Invalid; bool IsProgramEntry; diff --git a/include/dxc/DXIL/DxilMetadataHelper.h b/include/dxc/DXIL/DxilMetadataHelper.h index e17db016d8..5132cedb44 100644 --- a/include/dxc/DXIL/DxilMetadataHelper.h +++ b/include/dxc/DXIL/DxilMetadataHelper.h @@ -320,6 +320,7 @@ class DxilMDHelper { static const unsigned kDxilNodeOutputsTag = 21; static const unsigned kDxilNodeMaxDispatchGridTag = 22; static const unsigned kDxilRangedWaveSizeTag = 23; + static const unsigned kDxilGroupSharedLimitTag = 24; // Node Input/Output State. static const unsigned kDxilNodeOutputIDTag = 0; diff --git a/include/dxc/DXIL/DxilModule.h b/include/dxc/DXIL/DxilModule.h index 3f1ba12f86..a2cf80ae82 100644 --- a/include/dxc/DXIL/DxilModule.h +++ b/include/dxc/DXIL/DxilModule.h @@ -254,6 +254,10 @@ class DxilModule { void SetNumThreads(unsigned x, unsigned y, unsigned z); unsigned GetNumThreads(unsigned idx) const; + unsigned GetGroupSharedLimit() const; + // The total amount of group shared memory (in bytes) used by the shader. + unsigned GetTGSMSizeInBytes() const; + // Compute shader DxilWaveSize &GetWaveSize(); const DxilWaveSize &GetWaveSize() const; diff --git a/include/dxc/DxilContainer/DxilPipelineStateValidation.h b/include/dxc/DxilContainer/DxilPipelineStateValidation.h index 83d0dae6e9..87919f7df8 100644 --- a/include/dxc/DxilContainer/DxilPipelineStateValidation.h +++ b/include/dxc/DxilContainer/DxilPipelineStateValidation.h @@ -175,6 +175,10 @@ struct PSVRuntimeInfo3 : public PSVRuntimeInfo2 { uint32_t EntryFunctionName; }; +struct PSVRuntimeInfo4 : public PSVRuntimeInfo3 { + uint32_t NumBytesGroupSharedMemory; +}; + enum class PSVResourceType { Invalid = 0, @@ -474,7 +478,7 @@ class PSVSignatureElement { const uint32_t *SemanticIndexes) const; }; -#define MAX_PSV_VERSION 3 +#define MAX_PSV_VERSION 4 struct PSVInitInfo { PSVInitInfo(uint32_t psvVersion) : PSVVersion(psvVersion) {} @@ -491,7 +495,7 @@ struct PSVInitInfo { uint8_t SigPatchConstOrPrimVectors = 0; uint8_t SigOutputVectors[PSV_GS_MAX_STREAMS] = {0, 0, 0, 0}; - static_assert(MAX_PSV_VERSION == 3, "otherwise this needs updating."); + static_assert(MAX_PSV_VERSION == 4, "otherwise this needs updating."); uint32_t RuntimeInfoSize() const { switch (PSVVersion) { case 0: @@ -500,10 +504,12 @@ struct PSVInitInfo { return sizeof(PSVRuntimeInfo1); case 2: return sizeof(PSVRuntimeInfo2); + case 3: + return sizeof(PSVRuntimeInfo3); default: break; } - return sizeof(PSVRuntimeInfo3); + return sizeof(PSVRuntimeInfo4); } uint32_t ResourceBindInfoSize() const { if (PSVVersion < 2) @@ -519,6 +525,7 @@ class DxilPipelineStateValidation { PSVRuntimeInfo1 *m_pPSVRuntimeInfo1 = nullptr; PSVRuntimeInfo2 *m_pPSVRuntimeInfo2 = nullptr; PSVRuntimeInfo3 *m_pPSVRuntimeInfo3 = nullptr; + PSVRuntimeInfo4 *m_pPSVRuntimeInfo4 = nullptr; uint32_t m_uResourceCount = 0; uint32_t m_uPSVResourceBindInfoSize = 0; void *m_pPSVResourceBindInfo = nullptr; @@ -634,6 +641,8 @@ class DxilPipelineStateValidation { PSVRuntimeInfo3 *GetPSVRuntimeInfo3() const { return m_pPSVRuntimeInfo3; } + PSVRuntimeInfo4 *GetPSVRuntimeInfo4() const { return m_pPSVRuntimeInfo4; } + uint32_t GetBindCount() const { return m_uResourceCount; } template @@ -949,6 +958,8 @@ DxilPipelineStateValidation::ReadOrWrite(const void *pBits, uint32_t *pSize, m_uPSVRuntimeInfoSize); // failure ok AssignDerived(&m_pPSVRuntimeInfo3, m_pPSVRuntimeInfo0, m_uPSVRuntimeInfoSize); // failure ok + AssignDerived(&m_pPSVRuntimeInfo4, m_pPSVRuntimeInfo0, + m_uPSVRuntimeInfoSize); // failure ok // In RWMode::CalcSize, use temp runtime info to hold needed values from // initInfo @@ -1137,11 +1148,13 @@ void SetupPSVInitInfo(PSVInitInfo &InitInfo, const DxilModule &DM); void SetShaderProps(PSVRuntimeInfo0 *pInfo, const DxilModule &DM); void SetShaderProps(PSVRuntimeInfo1 *pInfo1, const DxilModule &DM); void SetShaderProps(PSVRuntimeInfo2 *pInfo2, const DxilModule &DM); +void SetShaderProps(PSVRuntimeInfo4 *pInfo4, const DxilModule &DM); void PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0, PSVRuntimeInfo1 *pInfo1, PSVRuntimeInfo2 *pInfo2, - PSVRuntimeInfo3 *pInfo3, uint8_t ShaderKind, - const char *EntryName, const char *Comment); + PSVRuntimeInfo3 *pInfo3, PSVRuntimeInfo4 *pInfo4, + uint8_t ShaderKind, const char *EntryName, + const char *Comment); } // namespace hlsl diff --git a/lib/DXIL/DxilMetadataHelper.cpp b/lib/DXIL/DxilMetadataHelper.cpp index c1282a980a..c683cae871 100644 --- a/lib/DXIL/DxilMetadataHelper.cpp +++ b/lib/DXIL/DxilMetadataHelper.cpp @@ -1624,6 +1624,13 @@ MDTuple *DxilMDHelper::EmitDxilEntryProperties(uint64_t rawShaderFlag, } MDVals.emplace_back(MDNode::get(m_Ctx, WaveSizeVal)); } + + const hlsl::ShaderModel *SM = GetShaderModel(); + if (SM->IsSMAtLeast(6, 10)) { + MDVals.emplace_back( + Uint32ToConstMD(DxilMDHelper::kDxilGroupSharedLimitTag)); + MDVals.emplace_back(Uint32ToConstMD(props.groupSharedLimitBytes)); + } } break; // Geometry shader. case DXIL::ShaderKind::Geometry: { @@ -1773,6 +1780,13 @@ void DxilMDHelper::LoadDxilEntryProperties(const MDOperand &MDO, props.numThreads[2] = ConstMDToUint32(pNode->getOperand(2)); } break; + case DxilMDHelper::kDxilGroupSharedLimitTag: { + DXASSERT(props.IsCS(), "else invalid shader kind"); + props.groupSharedLimitBytes = ConstMDToUint32(MDO); + if (!m_pSM->IsSMAtLeast(6, 10)) + m_bExtraMetadata = true; + } break; + case DxilMDHelper::kDxilGSStateTag: { DXASSERT(props.IsGS(), "else invalid shader kind"); auto &GS = props.ShaderProps.GS; diff --git a/lib/DXIL/DxilModule.cpp b/lib/DXIL/DxilModule.cpp index f4abdd15aa..4421147275 100644 --- a/lib/DXIL/DxilModule.cpp +++ b/lib/DXIL/DxilModule.cpp @@ -412,6 +412,28 @@ unsigned DxilModule::GetNumThreads(unsigned idx) const { return props.numThreads[idx]; } +unsigned DxilModule::GetGroupSharedLimit() const { + DXASSERT(m_DxilEntryPropsMap.size() == 1 && + (m_pSM->IsCS() || m_pSM->IsMS() || m_pSM->IsAS()), + "only works for CS/MS/AS profiles"); + const DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props; + DXASSERT_NOMSG(m_pSM->GetKind() == props.shaderKind); + return props.groupSharedLimitBytes; +} + +unsigned DxilModule::GetTGSMSizeInBytes() const { + const DataLayout &DL = m_pModule->getDataLayout(); + unsigned TGSMSize = 0; + + for (GlobalVariable &GV : m_pModule->globals()) { + if (GV.getType()->getAddressSpace() == DXIL::kTGSMAddrSpace) { + TGSMSize += DL.getTypeAllocSize(GV.getType()->getElementType()); + } + } + + return TGSMSize; +} + DxilWaveSize &DxilModule::GetWaveSize() { return const_cast( static_cast(this)->GetWaveSize()); diff --git a/lib/DxilContainer/DxilContainerAssembler.cpp b/lib/DxilContainer/DxilContainerAssembler.cpp index 48d8872733..736940b325 100644 --- a/lib/DxilContainer/DxilContainerAssembler.cpp +++ b/lib/DxilContainer/DxilContainerAssembler.cpp @@ -798,6 +798,8 @@ class DxilPSVWriter : public DxilPartWriter { PSVRuntimeInfo1 *pInfo1 = m_PSV.GetPSVRuntimeInfo1(); PSVRuntimeInfo2 *pInfo2 = m_PSV.GetPSVRuntimeInfo2(); PSVRuntimeInfo3 *pInfo3 = m_PSV.GetPSVRuntimeInfo3(); + PSVRuntimeInfo4 *pInfo4 = m_PSV.GetPSVRuntimeInfo4(); + if (pInfo) hlsl::SetShaderProps(pInfo, m_Module); if (pInfo1) @@ -806,6 +808,8 @@ class DxilPSVWriter : public DxilPartWriter { hlsl::SetShaderProps(pInfo2, m_Module); if (pInfo3) pInfo3->EntryFunctionName = EntryFunctionName; + if (pInfo4) + hlsl::SetShaderProps(pInfo4, m_Module); // Set resource binding information UINT uResIndex = 0; diff --git a/lib/DxilContainer/DxilPipelineStateValidation.cpp b/lib/DxilContainer/DxilPipelineStateValidation.cpp index 66186549f2..78be73c41d 100644 --- a/lib/DxilContainer/DxilPipelineStateValidation.cpp +++ b/lib/DxilContainer/DxilPipelineStateValidation.cpp @@ -33,6 +33,8 @@ uint32_t hlsl::GetPSVVersion(uint32_t ValMajor, uint32_t ValMinor) { PSVVersion = 1; else if (DXIL::CompareVersions(ValMajor, ValMinor, 1, 8) < 0) PSVVersion = 2; + else if (DXIL::CompareVersions(ValMajor, ValMinor, 1, 10) < 0) + PSVVersion = 3; return PSVVersion; } @@ -305,6 +307,20 @@ void hlsl::SetShaderProps(PSVRuntimeInfo2 *pInfo2, const DxilModule &DM) { } } +void hlsl::SetShaderProps(PSVRuntimeInfo4 *pInfo4, const DxilModule &DM) { + assert(pInfo4); + const ShaderModel *SM = DM.GetShaderModel(); + switch (SM->GetKind()) { + case ShaderModel::Kind::Compute: + case ShaderModel::Kind::Mesh: + case ShaderModel::Kind::Amplification: + pInfo4->NumBytesGroupSharedMemory = DM.GetTGSMSizeInBytes(); + break; + default: + break; + } +} + void PSVResourceBindInfo0::Print(raw_ostream &OS) const { OS << "PSVResourceBindInfo:\n"; OS << " Space: " << Space << "\n"; @@ -584,8 +600,9 @@ void PSVDependencyTable::Print(raw_ostream &OS, const char *InputSetName, void hlsl::PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0, PSVRuntimeInfo1 *pInfo1, PSVRuntimeInfo2 *pInfo2, - PSVRuntimeInfo3 *pInfo3, uint8_t ShaderKind, - const char *EntryName, const char *Comment) { + PSVRuntimeInfo3 *pInfo3, PSVRuntimeInfo4 *pInfo4, + uint8_t ShaderKind, const char *EntryName, + const char *Comment) { if (pInfo1 && pInfo1->ShaderStage != ShaderKind) ShaderKind = pInfo1->ShaderStage; OS << Comment << "PSVRuntimeInfo:\n"; @@ -808,6 +825,11 @@ void hlsl::PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0, OS << Comment << " NumThreads=(" << pInfo2->NumThreadsX << "," << pInfo2->NumThreadsY << "," << pInfo2->NumThreadsZ << ")\n"; } + if (pInfo4) { + OS << Comment + << " NumBytesGroupSharedMemory: " << pInfo4->NumBytesGroupSharedMemory + << "\n"; + } break; case PSVShaderKind::Amplification: OS << Comment << " Amplification Shader\n"; @@ -815,6 +837,11 @@ void hlsl::PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0, OS << Comment << " NumThreads=(" << pInfo2->NumThreadsX << "," << pInfo2->NumThreadsY << "," << pInfo2->NumThreadsZ << ")\n"; } + if (pInfo4) { + OS << Comment + << " NumBytesGroupSharedMemory: " << pInfo4->NumBytesGroupSharedMemory + << "\n"; + } break; case PSVShaderKind::Mesh: OS << Comment << " Mesh Shader\n"; @@ -841,6 +868,11 @@ void hlsl::PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0, OS << Comment << " NumThreads=(" << pInfo2->NumThreadsX << "," << pInfo2->NumThreadsY << "," << pInfo2->NumThreadsZ << ")\n"; } + if (pInfo4) { + OS << Comment + << " NumBytesGroupSharedMemory: " << pInfo4->NumBytesGroupSharedMemory + << "\n"; + } break; case PSVShaderKind::Library: case PSVShaderKind::Invalid: @@ -887,9 +919,10 @@ void DxilPipelineStateValidation::PrintPSVRuntimeInfo( PSVRuntimeInfo1 *pInfo1 = m_pPSVRuntimeInfo1; PSVRuntimeInfo2 *pInfo2 = m_pPSVRuntimeInfo2; PSVRuntimeInfo3 *pInfo3 = m_pPSVRuntimeInfo3; + PSVRuntimeInfo4 *pInfo4 = m_pPSVRuntimeInfo4; hlsl::PrintPSVRuntimeInfo( - OS, pInfo0, pInfo1, pInfo2, pInfo3, ShaderKind, + OS, pInfo0, pInfo1, pInfo2, pInfo3, pInfo4, ShaderKind, m_pPSVRuntimeInfo3 ? m_StringTable.Get(pInfo3->EntryFunctionName) : "", Comment); } diff --git a/lib/DxilValidation/DxilContainerValidation.cpp b/lib/DxilValidation/DxilContainerValidation.cpp index 89e23767fe..a138de6d6b 100644 --- a/lib/DxilValidation/DxilContainerValidation.cpp +++ b/lib/DxilValidation/DxilContainerValidation.cpp @@ -185,7 +185,8 @@ class PSVContentVerifier { unsigned PSVVersion); void VerifyViewIDDependence(PSVRuntimeInfo1 *PSV1, unsigned PSVVersion); void VerifyEntryProperties(const ShaderModel *SM, PSVRuntimeInfo0 *PSV0, - PSVRuntimeInfo1 *PSV1, PSVRuntimeInfo2 *PSV2); + PSVRuntimeInfo1 *PSV1, PSVRuntimeInfo2 *PSV2, + PSVRuntimeInfo3 *PSV3, PSVRuntimeInfo4 *PSV4); void EmitMismatchError(StringRef Name, StringRef PartContent, StringRef ModuleContent) { ValCtx.EmitFormatError(ValidationRule::ContainerContentMatches, @@ -409,16 +410,16 @@ void PSVContentVerifier::VerifyResources(unsigned PSVVersion) { VerifyResourceTable(DM.GetUAVs(), ResIndex, PSVVersion); } -void PSVContentVerifier::VerifyEntryProperties(const ShaderModel *SM, - PSVRuntimeInfo0 *PSV0, - PSVRuntimeInfo1 *PSV1, - PSVRuntimeInfo2 *PSV2) { - PSVRuntimeInfo3 DMPSV; - memset(&DMPSV, 0, sizeof(PSVRuntimeInfo3)); +void PSVContentVerifier::VerifyEntryProperties( + const ShaderModel *SM, PSVRuntimeInfo0 *PSV0, PSVRuntimeInfo1 *PSV1, + PSVRuntimeInfo2 *PSV2, PSVRuntimeInfo3 *PSV3, PSVRuntimeInfo4 *PSV4) { + PSVRuntimeInfo4 DMPSV; + memset(&DMPSV, 0, sizeof(PSVRuntimeInfo4)); hlsl::SetShaderProps((PSVRuntimeInfo0 *)&DMPSV, DM); hlsl::SetShaderProps((PSVRuntimeInfo1 *)&DMPSV, DM); hlsl::SetShaderProps((PSVRuntimeInfo2 *)&DMPSV, DM); + hlsl::SetShaderProps((PSVRuntimeInfo4 *)&DMPSV, DM); if (PSV1) { // Init things not set in InitPSVRuntimeInfo. DMPSV.ShaderStage = static_cast(SM->GetKind()); @@ -444,10 +445,14 @@ void PSVContentVerifier::VerifyEntryProperties(const ShaderModel *SM, else Mismatched = memcmp(PSV0, &DMPSV, sizeof(PSVRuntimeInfo0)) != 0; + if (PSV4 && + PSV4->NumBytesGroupSharedMemory != DMPSV.NumBytesGroupSharedMemory) + Mismatched = true; + if (Mismatched) { std::string Str; raw_string_ostream OS(Str); - hlsl::PrintPSVRuntimeInfo(OS, &DMPSV, &DMPSV, &DMPSV, &DMPSV, + hlsl::PrintPSVRuntimeInfo(OS, &DMPSV, &DMPSV, &DMPSV, &DMPSV, &DMPSV, static_cast(SM->GetKind()), DM.GetEntryFunctionName().c_str(), ""); OS.flush(); @@ -476,9 +481,11 @@ void PSVContentVerifier::Verify(unsigned ValMajor, unsigned ValMinor, PSVRuntimeInfo0 *PSV0 = PSV.GetPSVRuntimeInfo0(); PSVRuntimeInfo1 *PSV1 = PSV.GetPSVRuntimeInfo1(); PSVRuntimeInfo2 *PSV2 = PSV.GetPSVRuntimeInfo2(); + PSVRuntimeInfo3 *PSV3 = PSV.GetPSVRuntimeInfo3(); + PSVRuntimeInfo4 *PSV4 = PSV.GetPSVRuntimeInfo4(); const ShaderModel *SM = DM.GetShaderModel(); - VerifyEntryProperties(SM, PSV0, PSV1, PSV2); + VerifyEntryProperties(SM, PSV0, PSV1, PSV2, PSV3, PSV4); if (PSVVersion > 0) { if (((PSV.GetSigInputElements() + PSV.GetSigOutputElements() + PSV.GetSigPatchConstOrPrimElements()) > 0) && diff --git a/lib/DxilValidation/DxilValidation.cpp b/lib/DxilValidation/DxilValidation.cpp index ac528c1117..fd7b000427 100644 --- a/lib/DxilValidation/DxilValidation.cpp +++ b/lib/DxilValidation/DxilValidation.cpp @@ -3935,6 +3935,18 @@ static void ValidateGlobalVariables(ValidationContext &ValCtx) { Rule = ValidationRule::SmMaxMSSMSize; MaxSize = DXIL::kMaxMSSMSize; } + + // Check if the entry function has attribute to override TGSM size. + if (M.HasDxilEntryProps(M.GetEntryFunction())) { + DxilEntryProps &EntryProps = M.GetDxilEntryProps(M.GetEntryFunction()); + if (EntryProps.props.IsCS()) { + unsigned SpecifiedTGSMSize = EntryProps.props.groupSharedLimitBytes; + if (SpecifiedTGSMSize > 0) { + MaxSize = SpecifiedTGSMSize; + } + } + } + if (TGSMSize > MaxSize) { Module::global_iterator GI = M.GetModule()->global_end(); GlobalVariable *GV = &*GI; diff --git a/tools/clang/include/clang/Basic/Attr.td b/tools/clang/include/clang/Basic/Attr.td index 83137dbc3a..33f1594cac 100644 --- a/tools/clang/include/clang/Basic/Attr.td +++ b/tools/clang/include/clang/Basic/Attr.td @@ -671,6 +671,11 @@ def HLSLNumThreads: InheritableAttr { let Args = [IntArgument<"X">, IntArgument<"Y">, IntArgument<"Z">]; let Documentation = [Undocumented]; } +def HLSLGroupSharedLimit: InheritableAttr { + let Spellings = [CXX11<"", "GroupSharedLimit", 2017>]; + let Args = [IntArgument<"Limit">]; + let Documentation = [Undocumented]; +} def HLSLRootSignature: InheritableAttr { let Spellings = [CXX11<"", "RootSignature", 2015>]; let Args = [StringArgument<"SignatureName">]; diff --git a/tools/clang/lib/CodeGen/CGHLSLMS.cpp b/tools/clang/lib/CodeGen/CGHLSLMS.cpp index 6c68381a20..8be1117ad4 100644 --- a/tools/clang/lib/CodeGen/CGHLSLMS.cpp +++ b/tools/clang/lib/CodeGen/CGHLSLMS.cpp @@ -1646,6 +1646,19 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { } } + if (const HLSLGroupSharedLimitAttr *Attr = + FD->getAttr()) { + funcProps->groupSharedLimitBytes = Attr->getLimit(); + } else { + if (SM->IsMS()) { // Fallback to default limits + funcProps->groupSharedLimitBytes = DXIL::kMaxMSSMSize; // 28k For MS + } else if (SM->IsAS() || SM->IsCS()) { + funcProps->groupSharedLimitBytes = DXIL::kMaxTGSMSize; // 32k For AS/CS + } else { + funcProps->groupSharedLimitBytes = 0; + } + } + // Hull shader. if (const HLSLPatchConstantFuncAttr *Attr = FD->getAttr()) { diff --git a/tools/clang/lib/Parse/ParseDecl.cpp b/tools/clang/lib/Parse/ParseDecl.cpp index d3f8d68443..eed90d0a78 100644 --- a/tools/clang/lib/Parse/ParseDecl.cpp +++ b/tools/clang/lib/Parse/ParseDecl.cpp @@ -843,6 +843,7 @@ void Parser::ParseGNUAttributeArgs(IdentifierInfo *AttrName, case AttributeList::AT_HLSLMaxVertexCount: case AttributeList::AT_HLSLUnroll: case AttributeList::AT_HLSLWaveSize: + case AttributeList::AT_HLSLGroupSharedLimit: case AttributeList::AT_NoInline: // The following are not accepted in [attribute(param)] syntax: // case AttributeList::AT_HLSLCentroid: diff --git a/tools/clang/lib/Sema/SemaHLSL.cpp b/tools/clang/lib/Sema/SemaHLSL.cpp index e9c8c90a2d..23e2dedc43 100644 --- a/tools/clang/lib/Sema/SemaHLSL.cpp +++ b/tools/clang/lib/Sema/SemaHLSL.cpp @@ -12971,6 +12971,22 @@ void DiagnoseEntryAttrAllowedOnStage(clang::Sema *self, } break; } + case clang::attr::HLSLGroupSharedLimit: { + switch (shaderKind) { + case DXIL::ShaderKind::Compute: + case DXIL::ShaderKind::Mesh: + case DXIL::ShaderKind::Amplification: + case DXIL::ShaderKind::Node: + break; + default: + self->Diag(pAttr->getRange().getBegin(), + diag::err_hlsl_attribute_unsupported_stage) + << "GroupSharedLimit" + << "compute, mesh, node, or amplification"; + break; + } + break; + } } } } @@ -14656,6 +14672,11 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A, S.Context.getAddrSpaceQualType(VD->getType(), DXIL::kTGSMAddrSpace)); } break; + case AttributeList::AT_HLSLGroupSharedLimit: + declAttr = ::new (S.Context) HLSLGroupSharedLimitAttr( + A.getRange(), S.Context, ValidateAttributeIntArg(S, A), + A.getAttributeSpellingListIndex()); + break; case AttributeList::AT_HLSLUniform: declAttr = ::new (S.Context) HLSLUniformAttr( A.getRange(), S.Context, A.getAttributeSpellingListIndex()); diff --git a/tools/clang/test/DXC/dumpPSV_AS.hlsl b/tools/clang/test/DXC/dumpPSV_AS.hlsl index 41e5aeb4b6..eac57cc957 100644 --- a/tools/clang/test/DXC/dumpPSV_AS.hlsl +++ b/tools/clang/test/DXC/dumpPSV_AS.hlsl @@ -1,4 +1,4 @@ -// REQUIRES: dxil-1-8 +// REQUIRES: dxil-1-10 // RUN: %dxc -E main -T as_6_8 %s -Fo %t // RUN: %dxa %t -dumppsv | FileCheck %s @@ -6,6 +6,7 @@ // CHECK-NEXT: PSVRuntimeInfo: // CHECK-NEXT: Amplification Shader // CHECK-NEXT: NumThreads=(32,1,1) +// CHECK-NEXT: NumBytesGroupSharedMemory: 0 // CHECK-NEXT: MinimumExpectedWaveLaneCount: 0 // CHECK-NEXT: MaximumExpectedWaveLaneCount: 4294967295 // CHECK-NEXT: UsesViewID: false diff --git a/tools/clang/test/DXC/dumpPSV_AS_GroupShared.hlsl b/tools/clang/test/DXC/dumpPSV_AS_GroupShared.hlsl new file mode 100644 index 0000000000..54b2c516eb --- /dev/null +++ b/tools/clang/test/DXC/dumpPSV_AS_GroupShared.hlsl @@ -0,0 +1,33 @@ +// REQUIRES: dxil-1-10 +// RUN: %dxc -E main -T as_6_8 %s -Fo %t +// RUN: %dxa %t -dumppsv | FileCheck %s + +// CHECK:DxilPipelineStateValidation: +// CHECK-NEXT: PSVRuntimeInfo: +// CHECK-NEXT: Amplification Shader +// CHECK-NEXT: NumThreads=(32,1,1) +// CHECK-NEXT: NumBytesGroupSharedMemory: 128 +// CHECK-NEXT: MinimumExpectedWaveLaneCount: 0 +// CHECK-NEXT: MaximumExpectedWaveLaneCount: 4294967295 + +// Test that NumBytesGroupSharedMemory is calculated correctly for AS with groupshared. +// groupshared float4 sharedData[8] = 8 * 16 bytes = 128 bytes + +struct Payload { + float4 data[4]; +}; + +groupshared float4 sharedData[8]; + +[numthreads(32, 1, 1)] +void main(uint gtid : SV_GroupIndex) { + sharedData[gtid % 8] = float4(gtid, 0, 0, 0); + GroupMemoryBarrierWithGroupSync(); + + Payload pld; + pld.data[0] = sharedData[0]; + pld.data[1] = sharedData[1]; + pld.data[2] = sharedData[2]; + pld.data[3] = sharedData[3]; + DispatchMesh(1, 1, 1, pld); +} diff --git a/tools/clang/test/DXC/dumpPSV_CS.hlsl b/tools/clang/test/DXC/dumpPSV_CS.hlsl index 3764f17008..a70b5595ab 100644 --- a/tools/clang/test/DXC/dumpPSV_CS.hlsl +++ b/tools/clang/test/DXC/dumpPSV_CS.hlsl @@ -1,4 +1,4 @@ -// REQUIRES: dxil-1-8 +// REQUIRES: dxil-1-10 // RUN: %dxc -E main -T cs_6_8 %s -Fo %t // RUN: %dxa %t -dumppsv | FileCheck %s @@ -6,6 +6,7 @@ // CHECK-NEXT: PSVRuntimeInfo: // CHECK-NEXT: Compute Shader // CHECK-NEXT: NumThreads=(128,1,1) +// CHECK-NEXT: NumBytesGroupSharedMemory: 2048 // CHECK-NEXT: MinimumExpectedWaveLaneCount: 0 // CHECK-NEXT: MaximumExpectedWaveLaneCount: 4294967295 // CHECK-NEXT: UsesViewID: false diff --git a/tools/clang/test/DXC/dumpPSV_CS_NoGroupShared.hlsl b/tools/clang/test/DXC/dumpPSV_CS_NoGroupShared.hlsl new file mode 100644 index 0000000000..6da3469e79 --- /dev/null +++ b/tools/clang/test/DXC/dumpPSV_CS_NoGroupShared.hlsl @@ -0,0 +1,20 @@ +// REQUIRES: dxil-1-10 +// RUN: %dxc -E main -T cs_6_8 %s -Fo %t +// RUN: %dxa %t -dumppsv | FileCheck %s + +// CHECK: DxilPipelineStateValidation: +// CHECK-NEXT: PSVRuntimeInfo: +// CHECK-NEXT: Compute Shader +// CHECK-NEXT: NumThreads=(64,1,1) +// CHECK-NEXT: NumBytesGroupSharedMemory: 0 +// CHECK-NEXT: MinimumExpectedWaveLaneCount: 0 +// CHECK-NEXT: MaximumExpectedWaveLaneCount: 4294967295 + +// Test that NumBytesGroupSharedMemory is 0 when there is no groupshared memory. + +RWBuffer output : register(u0); + +[numthreads(64, 1, 1)] +void main(uint3 tid : SV_DispatchThreadID) { + output[tid.x] = tid.x * 2; +} diff --git a/tools/clang/test/DXC/dumpPSV_MS.hlsl b/tools/clang/test/DXC/dumpPSV_MS.hlsl index 82e1293bc4..e34c5c46ab 100644 --- a/tools/clang/test/DXC/dumpPSV_MS.hlsl +++ b/tools/clang/test/DXC/dumpPSV_MS.hlsl @@ -1,4 +1,4 @@ -// REQUIRES: dxil-1-8 +// REQUIRES: dxil-1-10 // RUN: %dxc -E main -T ms_6_8 %s -Fo %t // RUN: %dxa %t -dumppsv | FileCheck %s @@ -7,6 +7,7 @@ // CHECK-NEXT: Mesh Shader // CHECK-NEXT: MeshOutputTopology=triangle // CHECK-NEXT: NumThreads=(32,1,1) +// CHECK-NEXT: NumBytesGroupSharedMemory: 64 // CHECK-NEXT: MinimumExpectedWaveLaneCount: 0 // CHECK-NEXT: MaximumExpectedWaveLaneCount: 4294967295 // CHECK-NEXT: UsesViewID: true diff --git a/tools/clang/test/HLSLFileCheckLit/hlsl/entry/attributes/GroupSharedLimit.hlsl b/tools/clang/test/HLSLFileCheckLit/hlsl/entry/attributes/GroupSharedLimit.hlsl new file mode 100644 index 0000000000..defb95a9af --- /dev/null +++ b/tools/clang/test/HLSLFileCheckLit/hlsl/entry/attributes/GroupSharedLimit.hlsl @@ -0,0 +1,86 @@ +// REQUIRES: dxil-1-10 +// RUN: %dxc -E MainPass -T cs_6_10 %s | FileCheck %s + +#define NUM_BYTES_OF_SHARED_MEM (32*1024) +#define NUM_DWORDS_SHARED_MEM (NUM_BYTES_OF_SHARED_MEM / 4) +#define THREAD_GROUP_SIZE_X 1024 + +groupshared uint g_testBufferPASS[NUM_DWORDS_SHARED_MEM]; + +RWStructuredBuffer g_output : register(u0); + +// CHECK: @MainPass + +[GroupSharedLimit(NUM_BYTES_OF_SHARED_MEM)] +[numthreads(THREAD_GROUP_SIZE_X, 1, 1)] +void MainPass( uint3 DTid : SV_DispatchThreadID ) +{ + uint iterations = NUM_DWORDS_SHARED_MEM / THREAD_GROUP_SIZE_X; + + for (uint i = 0; i < iterations; i++) + { + uint index = DTid.x + i * THREAD_GROUP_SIZE_X; + g_testBufferPASS[index] = index; + } + + // synchronize all threads in the group + GroupMemoryBarrierWithGroupSync(); + + // write the shared data to the output buffer + for (uint i = 0; i < iterations; i++) + { + uint index = DTid.x + i * THREAD_GROUP_SIZE_X; + g_output[index] = g_testBufferPASS[index]; + } +} + +// RUN: not %dxc -E MainFail -T cs_6_10 %s 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR +// CHECK-ERROR: Total Thread Group Shared Memory storage is 32772, exceeded 32768. + +groupshared uint g_testBufferFAIL[NUM_DWORDS_SHARED_MEM + 1]; + +[GroupSharedLimit(NUM_BYTES_OF_SHARED_MEM)] +[numthreads(THREAD_GROUP_SIZE_X, 1, 1)] +void MainFail( uint3 DTid : SV_DispatchThreadID ) +{ + uint iterations = NUM_DWORDS_SHARED_MEM / THREAD_GROUP_SIZE_X; + + for (uint i = 0; i < iterations; i++) + { + uint index = DTid.x + i * THREAD_GROUP_SIZE_X; + g_testBufferFAIL[index] = index; + } + + // synchronize all threads in the group + GroupMemoryBarrierWithGroupSync(); + + // write the shared data to the output buffer + for (uint i = 0; i < iterations; i++) + { + uint index = DTid.x + i * THREAD_GROUP_SIZE_X; + g_output[index] = g_testBufferFAIL[index]; + } +} + +// RUN: not %dxc -E MainFail2 -T cs_6_10 %s 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR +[numthreads(THREAD_GROUP_SIZE_X, 1, 1)] +void MainFail2( uint3 DTid : SV_DispatchThreadID ) +{ + uint iterations = NUM_DWORDS_SHARED_MEM / THREAD_GROUP_SIZE_X; + + for (uint i = 0; i < iterations; i++) + { + uint index = DTid.x + i * THREAD_GROUP_SIZE_X; + g_testBufferFAIL[index] = index; + } + + // synchronize all threads in the group + GroupMemoryBarrierWithGroupSync(); + + // write the shared data to the output buffer + for (uint i = 0; i < iterations; i++) + { + uint index = DTid.x + i * THREAD_GROUP_SIZE_X; + g_output[index] = g_testBufferFAIL[index]; + } +} diff --git a/tools/clang/unittests/HLSL/ValidationTest.cpp b/tools/clang/unittests/HLSL/ValidationTest.cpp index 5c4f013835..b1b79bbae9 100644 --- a/tools/clang/unittests/HLSL/ValidationTest.cpp +++ b/tools/clang/unittests/HLSL/ValidationTest.cpp @@ -4916,9 +4916,9 @@ TEST_F(ValidationTest, PSVStringTableReorder) { const uint32_t *PSVPtr = (const uint32_t *)GetDxilPartData(pPSVPart); uint32_t PSVRuntimeInfo_size = *(PSVPtr++); - VERIFY_ARE_EQUAL(sizeof(PSVRuntimeInfo3), PSVRuntimeInfo_size); - PSVRuntimeInfo3 *PSVInfo = - const_cast((const PSVRuntimeInfo3 *)PSVPtr); + VERIFY_ARE_EQUAL(sizeof(PSVRuntimeInfo4), PSVRuntimeInfo_size); + PSVRuntimeInfo4 *PSVInfo = + const_cast((const PSVRuntimeInfo4 *)PSVPtr); VERIFY_ARE_EQUAL(2u, PSVInfo->SigInputElements); PSVPtr += PSVRuntimeInfo_size / 4; uint32_t ResourceCount = *(PSVPtr++); @@ -5108,9 +5108,9 @@ TEST_F(ValidationTest, PSVSemanticIndexTableReorder) { const uint32_t *PSVPtr = (const uint32_t *)GetDxilPartData(pPSVPart); uint32_t PSVRuntimeInfo_size = *(PSVPtr++); - VERIFY_ARE_EQUAL(sizeof(PSVRuntimeInfo3), PSVRuntimeInfo_size); - PSVRuntimeInfo3 *PSVInfo = - const_cast((const PSVRuntimeInfo3 *)PSVPtr); + VERIFY_ARE_EQUAL(sizeof(PSVRuntimeInfo4), PSVRuntimeInfo_size); + PSVRuntimeInfo4 *PSVInfo = + const_cast((const PSVRuntimeInfo4 *)PSVPtr); VERIFY_ARE_EQUAL(PSVInfo->SigInputElements, 3u); VERIFY_ARE_EQUAL(PSVInfo->SigOutputElements, 3u); VERIFY_ARE_EQUAL(PSVInfo->SigPatchConstOrPrimElements, 2u); @@ -5453,10 +5453,11 @@ SimplePSV::SimplePSV(const DxilPartHeader *pPSVPart) { const uint32_t *PSVPtrEnd = PSVPtr + PartSize / 4; uint32_t PSVRuntimeInfoSize = *(PSVPtr++); - VERIFY_ARE_EQUAL(sizeof(PSVRuntimeInfo3), PSVRuntimeInfoSize); - PSVRuntimeInfo3 *PSVInfo3 = - const_cast((const PSVRuntimeInfo3 *)PSVPtr); - PSVInfo = PSVInfo3; + VERIFY_ARE_EQUAL(sizeof(PSVRuntimeInfo4), PSVRuntimeInfoSize); + PSVRuntimeInfo4 *PSVInfo4 = + const_cast((const PSVRuntimeInfo4 *)PSVPtr); + PSVInfo = PSVInfo4; + PSVRuntimeInfo3 *PSVInfo3 = reinterpret_cast(PSVInfo4); PSVPtr += PSVRuntimeInfoSize / 4; uint32_t ResourceCount = *(PSVPtr++); @@ -6559,9 +6560,9 @@ TEST_F(ValidationTest, WrongPSVSizeOnZeros) { const uint32_t *PSVPtr = (const uint32_t *)GetDxilPartData(pPSVPart); uint32_t PSVRuntimeInfo_size = *(PSVPtr++); - VERIFY_ARE_EQUAL(sizeof(PSVRuntimeInfo3), PSVRuntimeInfo_size); - PSVRuntimeInfo3 *PSVInfo = - const_cast((const PSVRuntimeInfo3 *)PSVPtr); + VERIFY_ARE_EQUAL(sizeof(PSVRuntimeInfo4), PSVRuntimeInfo_size); + PSVRuntimeInfo4 *PSVInfo = + const_cast((const PSVRuntimeInfo4 *)PSVPtr); VERIFY_ARE_EQUAL(2u, PSVInfo->SigInputElements); PSVPtr += PSVRuntimeInfo_size / 4; uint32_t *ResourceCountPtr = const_cast(PSVPtr++); @@ -6793,7 +6794,7 @@ TEST_F(ValidationTest, WrongPSVVersion) { CheckOperationResultMsgs( p60WithPSV68Result, {"DXIL container mismatch for 'PSVRuntimeInfoSize' between 'PSV0' " - "part:('52') and DXIL module:('24')"}, + "part:('56') and DXIL module:('24')"}, /*maySucceedAnyway*/ false, /*bRegex*/ false); // Create a new Blob. @@ -6811,6 +6812,6 @@ TEST_F(ValidationTest, WrongPSVVersion) { CheckOperationResultMsgs( p68WithPSV60Result, {"DXIL container mismatch for 'PSVRuntimeInfoSize' between 'PSV0' " - "part:('24') and DXIL module:('52')"}, + "part:('24') and DXIL module:('56')"}, /*maySucceedAnyway*/ false, /*bRegex*/ false); }