Skip to content

Commit 7870aa1

Browse files
ReloadablePipeline: use PSO create info wrappers from GraphicsTypesX
1 parent eb7bd97 commit 7870aa1

File tree

6 files changed

+161
-200
lines changed

6 files changed

+161
-200
lines changed

Graphics/GraphicsEngine/interface/GraphicsTypesX.hpp

Lines changed: 50 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1756,7 +1756,7 @@ struct RayTracingPipelineStateCreateInfoX : PipelineStateCreateInfoX<RayTracingP
17561756
if (ProceduralHitShaderCount != 0)
17571757
ProceduralHitShaders.assign(pProceduralHitShaders, pProceduralHitShaders + ProceduralHitShaderCount);
17581758

1759-
SyncDesc(true);
1759+
SyncDesc(SYNC_FLAG_ALL);
17601760
}
17611761

17621762
RayTracingPipelineStateCreateInfoX(const std::initializer_list<RayTracingGeneralShaderGroup>& _GeneralShaders,
@@ -1766,7 +1766,7 @@ struct RayTracingPipelineStateCreateInfoX : PipelineStateCreateInfoX<RayTracingP
17661766
TriangleHitShaders{_TriangleHitShaders},
17671767
ProceduralHitShaders{_ProceduralHitShaders}
17681768
{
1769-
SyncDesc(true);
1769+
SyncDesc(SYNC_FLAG_ALL);
17701770
}
17711771

17721772
RayTracingPipelineStateCreateInfoX(const RayTracingPipelineStateCreateInfoX& _DescX) :
@@ -1796,7 +1796,7 @@ struct RayTracingPipelineStateCreateInfoX : PipelineStateCreateInfoX<RayTracingP
17961796
{
17971797
GeneralShaders.push_back(GenShader);
17981798
GeneralShaders.back().Name = StringPool.emplace(GenShader.Name).first->c_str();
1799-
return SyncDesc();
1799+
return SyncDesc(SYNC_FLAG_UPDATE_SHADERS);
18001800
}
18011801

18021802
template <typename... ArgsType>
@@ -1810,7 +1810,7 @@ struct RayTracingPipelineStateCreateInfoX : PipelineStateCreateInfoX<RayTracingP
18101810
{
18111811
TriangleHitShaders.push_back(TriHitShader);
18121812
TriangleHitShaders.back().Name = StringPool.emplace(TriHitShader.Name).first->c_str();
1813-
return SyncDesc();
1813+
return SyncDesc(SYNC_FLAG_UPDATE_SHADERS);
18141814
}
18151815

18161816
template <typename... ArgsType>
@@ -1824,7 +1824,7 @@ struct RayTracingPipelineStateCreateInfoX : PipelineStateCreateInfoX<RayTracingP
18241824
{
18251825
ProceduralHitShaders.push_back(ProcHitShader);
18261826
ProceduralHitShaders.back().Name = StringPool.emplace(ProcHitShader.Name).first->c_str();
1827-
return SyncDesc();
1827+
return SyncDesc(SYNC_FLAG_UPDATE_SHADERS);
18281828
}
18291829

18301830
template <typename... ArgsType>
@@ -1860,19 +1860,19 @@ struct RayTracingPipelineStateCreateInfoX : PipelineStateCreateInfoX<RayTracingP
18601860
RayTracingPipelineStateCreateInfoX& ClearGeneralShaders()
18611861
{
18621862
GeneralShaders.clear();
1863-
return SyncDesc();
1863+
return SyncDesc(SYNC_FLAG_UPDATE_SHADERS);
18641864
}
18651865

18661866
RayTracingPipelineStateCreateInfoX& ClearTriangleHitShaders()
18671867
{
18681868
TriangleHitShaders.clear();
1869-
return SyncDesc();
1869+
return SyncDesc(SYNC_FLAG_UPDATE_SHADERS);
18701870
}
18711871

18721872
RayTracingPipelineStateCreateInfoX& ClearProceduralHitShaders()
18731873
{
18741874
ProceduralHitShaders.clear();
1875-
return SyncDesc();
1875+
return SyncDesc(SYNC_FLAG_UPDATE_SHADERS);
18761876
}
18771877

18781878
RayTracingPipelineStateCreateInfoX& Clear()
@@ -1883,7 +1883,14 @@ struct RayTracingPipelineStateCreateInfoX : PipelineStateCreateInfoX<RayTracingP
18831883
}
18841884

18851885
private:
1886-
RayTracingPipelineStateCreateInfoX& SyncDesc(bool UpdateStrings = false)
1886+
enum SYNC_FLAGS : Uint32
1887+
{
1888+
SYNC_FLAG_NONE = 0u,
1889+
SYNC_FLAG_UPDATE_STRINGS = 1u << 0u,
1890+
SYNC_FLAG_UPDATE_SHADERS = 1u << 1u,
1891+
SYNC_FLAG_ALL = SYNC_FLAG_UPDATE_STRINGS | SYNC_FLAG_UPDATE_SHADERS
1892+
};
1893+
RayTracingPipelineStateCreateInfoX& SyncDesc(SYNC_FLAGS UpdateFlags)
18871894
{
18881895
GeneralShaderCount = static_cast<Uint32>(GeneralShaders.size());
18891896
pGeneralShaders = GeneralShaderCount > 0 ? GeneralShaders.data() : nullptr;
@@ -1894,7 +1901,7 @@ struct RayTracingPipelineStateCreateInfoX : PipelineStateCreateInfoX<RayTracingP
18941901
ProceduralHitShaderCount = static_cast<Uint32>(ProceduralHitShaders.size());
18951902
pProceduralHitShaders = ProceduralHitShaderCount > 0 ? ProceduralHitShaders.data() : nullptr;
18961903

1897-
if (UpdateStrings)
1904+
if ((UpdateFlags & SYNC_FLAG_UPDATE_STRINGS) != 0)
18981905
{
18991906
for (auto& Shader : GeneralShaders)
19001907
Shader.Name = StringPool.emplace(Shader.Name).first->c_str();
@@ -1909,6 +1916,37 @@ struct RayTracingPipelineStateCreateInfoX : PipelineStateCreateInfoX<RayTracingP
19091916
pShaderRecordName = StringPool.emplace(pShaderRecordName).first->c_str();
19101917
}
19111918

1919+
if ((UpdateFlags & SYNC_FLAG_UPDATE_SHADERS) != 0)
1920+
{
1921+
// Keep current shader objects alive
1922+
std::vector<RefCntAutoPtr<IShader>> OldShaderObjects = std::move(ShaderObjects);
1923+
1924+
ShaderObjects.clear();
1925+
for (auto& Shader : GeneralShaders)
1926+
{
1927+
if (Shader.pShader != nullptr)
1928+
ShaderObjects.emplace_back(Shader.pShader);
1929+
}
1930+
1931+
for (auto& Shader : TriangleHitShaders)
1932+
{
1933+
if (Shader.pClosestHitShader != nullptr)
1934+
ShaderObjects.emplace_back(Shader.pClosestHitShader);
1935+
if (Shader.pAnyHitShader != nullptr)
1936+
ShaderObjects.emplace_back(Shader.pAnyHitShader);
1937+
}
1938+
1939+
for (auto& Shader : ProceduralHitShaders)
1940+
{
1941+
if (Shader.pIntersectionShader != nullptr)
1942+
ShaderObjects.emplace_back(Shader.pIntersectionShader);
1943+
if (Shader.pClosestHitShader != nullptr)
1944+
ShaderObjects.emplace_back(Shader.pClosestHitShader);
1945+
if (Shader.pAnyHitShader != nullptr)
1946+
ShaderObjects.emplace_back(Shader.pAnyHitShader);
1947+
}
1948+
}
1949+
19121950
return *this;
19131951
}
19141952

@@ -1923,12 +1961,13 @@ struct RayTracingPipelineStateCreateInfoX : PipelineStateCreateInfoX<RayTracingP
19231961
else
19241962
++it;
19251963
}
1926-
return SyncDesc();
1964+
return SyncDesc(SYNC_FLAG_UPDATE_SHADERS);
19271965
}
19281966

19291967
std::vector<RayTracingGeneralShaderGroup> GeneralShaders;
19301968
std::vector<RayTracingTriangleHitShaderGroup> TriangleHitShaders;
19311969
std::vector<RayTracingProceduralHitShaderGroup> ProceduralHitShaders;
1970+
std::vector<RefCntAutoPtr<IShader>> ShaderObjects;
19321971
};
19331972

19341973

Graphics/GraphicsTools/include/ReloadablePipelineState.hpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,6 @@ class ReloadablePipelineState final : public ObjectBase<IPipelineState>
168168
virtual ~DynamicHeapObjectBase() {}
169169
};
170170

171-
template <typename CreateInfoType>
172-
struct CreateInfoWrapperBase;
173-
174171
template <typename CreateInfoType>
175172
struct CreateInfoWrapper;
176173

Graphics/GraphicsTools/include/RenderStateCacheImpl.hpp

Lines changed: 0 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -152,59 +152,4 @@ class RenderStateCacheImpl final : public ObjectBase<IRenderStateCache>
152152
std::unordered_map<UniqueIdentifier, RefCntWeakPtr<IPipelineState>> m_ReloadablePipelines;
153153
};
154154

155-
template <typename HandlerType>
156-
void ProcessPsoCreateInfoShaders(GraphicsPipelineStateCreateInfo& CI, HandlerType&& Handler)
157-
{
158-
Handler(CI.pVS);
159-
Handler(CI.pPS);
160-
Handler(CI.pDS);
161-
Handler(CI.pHS);
162-
Handler(CI.pGS);
163-
Handler(CI.pAS);
164-
Handler(CI.pMS);
165-
}
166-
167-
template <typename HandlerType>
168-
void ProcessPsoCreateInfoShaders(ComputePipelineStateCreateInfo& CI, HandlerType&& Handler)
169-
{
170-
Handler(CI.pCS);
171-
}
172-
173-
template <typename HandlerType>
174-
void ProcessPsoCreateInfoShaders(TilePipelineStateCreateInfo& CI, HandlerType&& Handler)
175-
{
176-
Handler(CI.pTS);
177-
}
178-
179-
template <typename HandlerType>
180-
void ProcessPsoCreateInfoShaders(RayTracingPipelineStateCreateInfo& CI, HandlerType&& Handler)
181-
{
182-
}
183-
184-
template <typename HandlerType>
185-
void ProcessRtPsoCreateInfoShaders(
186-
std::vector<RayTracingGeneralShaderGroup>& pGeneralShaders,
187-
std::vector<RayTracingTriangleHitShaderGroup>& pTriangleHitShaders,
188-
std::vector<RayTracingProceduralHitShaderGroup>& pProceduralHitShaders,
189-
HandlerType&& Handler)
190-
{
191-
for (auto& GeneralShader : pGeneralShaders)
192-
{
193-
Handler(GeneralShader.pShader);
194-
}
195-
196-
for (auto& TriHitShader : pTriangleHitShaders)
197-
{
198-
Handler(TriHitShader.pAnyHitShader);
199-
Handler(TriHitShader.pClosestHitShader);
200-
}
201-
202-
for (auto& ProcHitShader : pProceduralHitShaders)
203-
{
204-
Handler(ProcHitShader.pAnyHitShader);
205-
Handler(ProcHitShader.pClosestHitShader);
206-
Handler(ProcHitShader.pIntersectionShader);
207-
}
208-
}
209-
210155
} // namespace Diligent

Graphics/GraphicsTools/src/ReloadablePipelineState.cpp

Lines changed: 15 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
#include "RenderStateCacheImpl.hpp"
3333
#include "ReloadableShader.hpp"
34+
#include "GraphicsTypesX.hpp"
3435

3536
namespace Diligent
3637
{
@@ -39,43 +40,24 @@ constexpr INTERFACE_ID ReloadablePipelineState::IID_InternalImpl;
3940

4041

4142
template <typename CreateInfoType>
42-
struct ReloadablePipelineState::CreateInfoWrapperBase : DynamicHeapObjectBase
43+
struct ReloadablePipelineState::CreateInfoWrapper : DynamicHeapObjectBase
4344
{
44-
CreateInfoWrapperBase(const CreateInfoType& CI) :
45-
m_CI{CI},
46-
m_Variables{CI.PSODesc.ResourceLayout.Variables, CI.PSODesc.ResourceLayout.Variables + CI.PSODesc.ResourceLayout.NumVariables},
47-
m_ImtblSamplers{CI.PSODesc.ResourceLayout.ImmutableSamplers, CI.PSODesc.ResourceLayout.ImmutableSamplers + CI.PSODesc.ResourceLayout.NumImmutableSamplers},
48-
m_ppSignatures{CI.ppResourceSignatures, CI.ppResourceSignatures + CI.ResourceSignaturesCount}
45+
CreateInfoWrapper(const CreateInfoType& CI) :
46+
m_CI{CI}
4947
{
50-
if (CI.PSODesc.Name != nullptr)
51-
m_CI.PSODesc.Name = m_Strings.emplace(CI.PSODesc.Name).first->c_str();
48+
ProcessPipelineStateCreateInfoShaders(static_cast<const CreateInfoType&>(m_CI), [](IShader* pShader) {
49+
if (pShader == nullptr)
50+
return;
5251

53-
for (auto& Var : m_Variables)
54-
Var.Name = m_Strings.emplace(Var.Name).first->c_str();
55-
for (auto& ImtblSam : m_ImtblSamplers)
56-
ImtblSam.SamplerOrTextureName = m_Strings.emplace(ImtblSam.SamplerOrTextureName).first->c_str();
57-
58-
m_CI.PSODesc.ResourceLayout.Variables = m_Variables.data();
59-
m_CI.PSODesc.ResourceLayout.ImmutableSamplers = m_ImtblSamplers.data();
60-
61-
m_CI.ppResourceSignatures = !m_ppSignatures.empty() ? m_ppSignatures.data() : nullptr;
62-
for (auto* pSign : m_ppSignatures)
63-
m_Objects.emplace_back(pSign);
64-
65-
m_Objects.emplace_back(m_CI.pPSOCache);
66-
67-
// Replace shaders with reloadable shaders
68-
ProcessPsoCreateInfoShaders(m_CI,
69-
[&](IShader*& pShader) {
70-
AddShader(pShader);
71-
});
52+
if (!RefCntAutoPtr<IShader>{pShader, ReloadableShader::IID_InternalImpl})
53+
{
54+
const auto* Name = pShader->GetDesc().Name;
55+
LOG_WARNING_MESSAGE("Shader '", (Name ? Name : "<unnamed>"),
56+
"' is not a reloadable shader. To enable hot pipeline state reload, all shaders must be created through the render state cache.");
57+
}
58+
});
7259
}
7360

74-
CreateInfoWrapperBase(const CreateInfoWrapperBase&) = delete;
75-
CreateInfoWrapperBase(CreateInfoWrapperBase&&) = delete;
76-
CreateInfoWrapperBase& operator=(const CreateInfoWrapperBase&) = delete;
77-
CreateInfoWrapperBase& operator=(CreateInfoWrapperBase&&) = delete;
78-
7961
const CreateInfoType& Get() const
8062
{
8163
return m_CI;
@@ -91,98 +73,8 @@ struct ReloadablePipelineState::CreateInfoWrapperBase : DynamicHeapObjectBase
9173
return m_CI;
9274
}
9375

94-
void AddShader(IShader* pShader)
95-
{
96-
if (pShader == nullptr)
97-
return;
98-
99-
if (!RefCntAutoPtr<IShader>{pShader, ReloadableShader::IID_InternalImpl})
100-
{
101-
const auto* Name = pShader->GetDesc().Name;
102-
LOG_WARNING_MESSAGE("Shader '", (Name ? Name : "<unnamed>"),
103-
"' is not a reloadable shader. To enable hot pipeline state reload, all shaders must be created through the render state cache.");
104-
}
105-
106-
m_Objects.emplace_back(pShader);
107-
}
108-
10976
protected:
110-
CreateInfoType m_CI;
111-
112-
std::unordered_set<std::string> m_Strings;
113-
std::vector<ShaderResourceVariableDesc> m_Variables;
114-
std::vector<ImmutableSamplerDesc> m_ImtblSamplers;
115-
std::vector<IPipelineResourceSignature*> m_ppSignatures;
116-
std::vector<RefCntAutoPtr<IObject>> m_Objects;
117-
};
118-
119-
120-
template <>
121-
struct ReloadablePipelineState::CreateInfoWrapper<GraphicsPipelineStateCreateInfo> : CreateInfoWrapperBase<GraphicsPipelineStateCreateInfo>
122-
{
123-
CreateInfoWrapper(const GraphicsPipelineStateCreateInfo& CI) :
124-
CreateInfoWrapperBase<GraphicsPipelineStateCreateInfo>{CI},
125-
m_LayoutElements{CI.GraphicsPipeline.InputLayout.LayoutElements, CI.GraphicsPipeline.InputLayout.LayoutElements + CI.GraphicsPipeline.InputLayout.NumElements}
126-
{
127-
m_Objects.emplace_back(CI.GraphicsPipeline.pRenderPass);
128-
129-
for (auto& Elem : m_LayoutElements)
130-
Elem.HLSLSemantic = m_Strings.emplace(Elem.HLSLSemantic != nullptr ? Elem.HLSLSemantic : LayoutElement{}.HLSLSemantic).first->c_str();
131-
132-
m_CI.GraphicsPipeline.InputLayout.LayoutElements = m_LayoutElements.data();
133-
}
134-
135-
private:
136-
std::vector<LayoutElement> m_LayoutElements;
137-
};
138-
139-
template <>
140-
struct ReloadablePipelineState::CreateInfoWrapper<ComputePipelineStateCreateInfo> : CreateInfoWrapperBase<ComputePipelineStateCreateInfo>
141-
{
142-
CreateInfoWrapper(const ComputePipelineStateCreateInfo& CI) :
143-
CreateInfoWrapperBase<ComputePipelineStateCreateInfo>{CI}
144-
{
145-
}
146-
};
147-
148-
template <>
149-
struct ReloadablePipelineState::CreateInfoWrapper<TilePipelineStateCreateInfo> : CreateInfoWrapperBase<TilePipelineStateCreateInfo>
150-
{
151-
CreateInfoWrapper(const TilePipelineStateCreateInfo& CI) :
152-
CreateInfoWrapperBase<TilePipelineStateCreateInfo>{CI}
153-
{
154-
}
155-
};
156-
157-
template <>
158-
struct ReloadablePipelineState::CreateInfoWrapper<RayTracingPipelineStateCreateInfo> : CreateInfoWrapperBase<RayTracingPipelineStateCreateInfo>
159-
{
160-
CreateInfoWrapper(const RayTracingPipelineStateCreateInfo& CI) :
161-
CreateInfoWrapperBase<RayTracingPipelineStateCreateInfo>{CI},
162-
// clang-format off
163-
m_pGeneralShaders {CI.pGeneralShaders, CI.pGeneralShaders + CI.GeneralShaderCount},
164-
m_pTriangleHitShaders {CI.pTriangleHitShaders, CI.pTriangleHitShaders + CI.TriangleHitShaderCount},
165-
m_pProceduralHitShaders{CI.pProceduralHitShaders, CI.pProceduralHitShaders + CI.ProceduralHitShaderCount}
166-
// clang-format on
167-
{
168-
m_CI.pGeneralShaders = m_pGeneralShaders.data();
169-
m_CI.pTriangleHitShaders = m_pTriangleHitShaders.data();
170-
m_CI.pProceduralHitShaders = m_pProceduralHitShaders.data();
171-
172-
if (m_CI.pShaderRecordName != nullptr)
173-
m_CI.pShaderRecordName = m_Strings.emplace(m_CI.pShaderRecordName).first->c_str();
174-
175-
// Replace shaders with reloadable shaders
176-
ProcessRtPsoCreateInfoShaders(m_pGeneralShaders, m_pTriangleHitShaders, m_pProceduralHitShaders,
177-
[&](IShader*& pShader) {
178-
AddShader(pShader);
179-
});
180-
}
181-
182-
private:
183-
std::vector<RayTracingGeneralShaderGroup> m_pGeneralShaders;
184-
std::vector<RayTracingTriangleHitShaderGroup> m_pTriangleHitShaders;
185-
std::vector<RayTracingProceduralHitShaderGroup> m_pProceduralHitShaders;
77+
typename PipelineStateCreateInfoXTraits<CreateInfoType>::CreateInfoXType m_CI;
18678
};
18779

18880
ReloadablePipelineState::ReloadablePipelineState(IReferenceCounters* pRefCounters,

0 commit comments

Comments
 (0)