@@ -118,24 +118,25 @@ class PrimitiveTopology_To_D3D12_PRIMITIVE_TOPOLOGY_TYPE
118118 std::array<D3D12_PRIMITIVE_TOPOLOGY_TYPE, PRIMITIVE_TOPOLOGY_NUM_TOPOLOGIES> m_Map;
119119};
120120
121- template <typename TShaderStages>
122121void BuildRTPipelineDescription (const RayTracingPipelineStateCreateInfo& CreateInfo,
123122 std::vector<D3D12_STATE_SUBOBJECT>& Subobjects,
124123 DynamicLinearAllocator& TempPool,
125- TShaderStages& ShaderStages) noexcept (false )
124+ PipelineStateD3D12Impl:: TShaderStages& ShaderStages) noexcept (false )
126125{
127126#define LOG_PSO_ERROR_AND_THROW (...) LOG_ERROR_AND_THROW(" Description of ray tracing PSO '" , (CreateInfo.PSODesc.Name ? CreateInfo.PSODesc.Name : " " ), " ' is invalid: " , ##__VA_ARGS__)
128127
129128 Uint32 UnnamedExportIndex = 0 ;
130129
131130 std::unordered_map<IShader*, LPCWSTR> UniqueShaders;
132131
133- std::array<typename TShaderStages::value_type*, MAX_SHADERS_IN_PIPELINE> StagesPtr = {};
134- std::array<Uint32, MAX_SHADERS_IN_PIPELINE> ShaderIndices = {};
132+ using ShaderStageInfo = PipelineStateD3D12Impl::ShaderStageInfo;
135133
136- for (auto & Stage : ShaderStages)
134+ std::array<ShaderStageInfo*, MAX_SHADERS_IN_PIPELINE> StagesPtr = {};
135+ std::array<Uint32, MAX_SHADERS_IN_PIPELINE> ShaderIndices = {};
136+
137+ for (ShaderStageInfo& Stage : ShaderStages)
137138 {
138- const auto Idx = GetShaderTypePipelineIndex (Stage.Type , PIPELINE_TYPE_RAY_TRACING);
139+ const Int32 Idx = GetShaderTypePipelineIndex (Stage.Type , PIPELINE_TYPE_RAY_TRACING);
139140 VERIFY_EXPR (StagesPtr[Idx] == nullptr );
140141 StagesPtr[Idx] = &Stage;
141142 }
@@ -147,18 +148,18 @@ void BuildRTPipelineDescription(const RayTracingPipelineStateCreateInfo& CreateI
147148 auto it_inserted = UniqueShaders.emplace (pShader, nullptr );
148149 if (it_inserted.second )
149150 {
150- const auto StageIdx = GetShaderTypePipelineIndex (pShader->GetDesc ().ShaderType , PIPELINE_TYPE_RAY_TRACING);
151- const auto & Stage = *StagesPtr[StageIdx];
152- auto & ShaderIndex = ShaderIndices[StageIdx];
151+ const Int32 StageIdx = GetShaderTypePipelineIndex (pShader->GetDesc ().ShaderType , PIPELINE_TYPE_RAY_TRACING);
152+ const ShaderStageInfo & Stage = *StagesPtr[StageIdx];
153+ Uint32& ShaderIndex = ShaderIndices[StageIdx];
153154
154155 // shaders must be in same order as in ExtractShaders()
155156 RefCntAutoPtr<ShaderD3D12Impl> pShaderD3D12{pShader, ShaderD3D12Impl::IID_InternalImpl};
156157 VERIFY (pShaderD3D12, " Unexpected shader object implementation" );
157158 VERIFY_EXPR (Stage.Shaders [ShaderIndex] == pShaderD3D12);
158159
159- auto & LibDesc = *TempPool.Construct <D3D12_DXIL_LIBRARY_DESC>();
160- auto & ExportDesc = *TempPool.Construct <D3D12_EXPORT_DESC>();
161- const IDataBlob* pByteCode = Stage.ByteCodes [ShaderIndex];
160+ D3D12_DXIL_LIBRARY_DESC& LibDesc = *TempPool.Construct <D3D12_DXIL_LIBRARY_DESC>();
161+ D3D12_EXPORT_DESC& ExportDesc = *TempPool.Construct <D3D12_EXPORT_DESC>();
162+ const IDataBlob* pByteCode = Stage.ByteCodes [ShaderIndex];
162163 ++ShaderIndex;
163164
164165 LibDesc.DXILLibrary .BytecodeLength = pByteCode->GetSize ();
@@ -189,15 +190,15 @@ void BuildRTPipelineDescription(const RayTracingPipelineStateCreateInfo& CreateI
189190
190191 for (Uint32 i = 0 ; i < CreateInfo.GeneralShaderCount ; ++i)
191192 {
192- const auto & GeneralShader = CreateInfo.pGeneralShaders [i];
193+ const RayTracingGeneralShaderGroup & GeneralShader = CreateInfo.pGeneralShaders [i];
193194 AddDxilLib (GeneralShader.pShader , GeneralShader.Name );
194195 }
195196
196197 for (Uint32 i = 0 ; i < CreateInfo.TriangleHitShaderCount ; ++i)
197198 {
198- const auto & TriHitShader = CreateInfo.pTriangleHitShaders [i];
199+ const RayTracingTriangleHitShaderGroup & TriHitShader = CreateInfo.pTriangleHitShaders [i];
199200
200- auto & HitGroupDesc = *TempPool.Construct <D3D12_HIT_GROUP_DESC>();
201+ D3D12_HIT_GROUP_DESC & HitGroupDesc = *TempPool.Construct <D3D12_HIT_GROUP_DESC>();
201202 HitGroupDesc.HitGroupExport = TempPool.CopyWString (TriHitShader.Name );
202203 HitGroupDesc.Type = D3D12_HIT_GROUP_TYPE_TRIANGLES;
203204 HitGroupDesc.ClosestHitShaderImport = AddDxilLib (TriHitShader.pClosestHitShader , nullptr );
@@ -209,9 +210,9 @@ void BuildRTPipelineDescription(const RayTracingPipelineStateCreateInfo& CreateI
209210
210211 for (Uint32 i = 0 ; i < CreateInfo.ProceduralHitShaderCount ; ++i)
211212 {
212- const auto & ProcHitShader = CreateInfo.pProceduralHitShaders [i];
213+ const RayTracingProceduralHitShaderGroup & ProcHitShader = CreateInfo.pProceduralHitShaders [i];
213214
214- auto & HitGroupDesc = *TempPool.Construct <D3D12_HIT_GROUP_DESC>();
215+ D3D12_HIT_GROUP_DESC & HitGroupDesc = *TempPool.Construct <D3D12_HIT_GROUP_DESC>();
215216 HitGroupDesc.HitGroupExport = TempPool.CopyWString (ProcHitShader.Name );
216217 HitGroupDesc.Type = D3D12_HIT_GROUP_TYPE_PROCEDURAL_PRIMITIVE;
217218 HitGroupDesc.ClosestHitShaderImport = AddDxilLib (ProcHitShader.pClosestHitShader , nullptr );
@@ -223,14 +224,14 @@ void BuildRTPipelineDescription(const RayTracingPipelineStateCreateInfo& CreateI
223224
224225 constexpr Uint32 DefaultPayloadSize = sizeof (float ) * 8 ;
225226
226- auto & PipelineConfig = *TempPool.Construct <D3D12_RAYTRACING_PIPELINE_CONFIG>();
227+ D3D12_RAYTRACING_PIPELINE_CONFIG & PipelineConfig = *TempPool.Construct <D3D12_RAYTRACING_PIPELINE_CONFIG>();
227228
228229 PipelineConfig.MaxTraceRecursionDepth = CreateInfo.RayTracingPipeline .MaxRecursionDepth ;
229230 Subobjects.push_back ({D3D12_STATE_SUBOBJECT_TYPE_RAYTRACING_PIPELINE_CONFIG, &PipelineConfig});
230231
231- auto & ShaderConfig = *TempPool.Construct <D3D12_RAYTRACING_SHADER_CONFIG>();
232- ShaderConfig.MaxAttributeSizeInBytes = CreateInfo.MaxAttributeSize == 0 ? D3D12_RAYTRACING_MAX_ATTRIBUTE_SIZE_IN_BYTES : CreateInfo.MaxAttributeSize ;
233- ShaderConfig.MaxPayloadSizeInBytes = CreateInfo.MaxPayloadSize == 0 ? DefaultPayloadSize : CreateInfo.MaxPayloadSize ;
232+ D3D12_RAYTRACING_SHADER_CONFIG & ShaderConfig = *TempPool.Construct <D3D12_RAYTRACING_SHADER_CONFIG>();
233+ ShaderConfig.MaxAttributeSizeInBytes = CreateInfo.MaxAttributeSize == 0 ? D3D12_RAYTRACING_MAX_ATTRIBUTE_SIZE_IN_BYTES : CreateInfo.MaxAttributeSize ;
234+ ShaderConfig.MaxPayloadSizeInBytes = CreateInfo.MaxPayloadSize == 0 ? DefaultPayloadSize : CreateInfo.MaxPayloadSize ;
234235 Subobjects.push_back ({D3D12_STATE_SUBOBJECT_TYPE_RAYTRACING_SHADER_CONFIG, &ShaderConfig});
235236#undef LOG_PSO_ERROR_AND_THROW
236237}
0 commit comments