Skip to content

Commit f2fc2ac

Browse files
Fixed issues with ray tracing shader type conversion
1 parent b9d8f66 commit f2fc2ac

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

Graphics/GraphicsEngineD3D12/src/PipelineStateD3D12Impl.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -151,12 +151,13 @@ void BuildRTPipelineDescription(const RayTracingPipelineStateCreateInfo& CreateI
151151
auto& ShaderIndex = ShaderIndices[StageIdx];
152152

153153
// shaders must be in same order as in ExtractShaders()
154-
VERIFY_EXPR(Stage.Shaders[ShaderIndex] == pShader);
154+
RefCntAutoPtr<ShaderD3D12Impl> pShaderD3D12{pShader, ShaderD3D12Impl::IID_InternalImpl};
155+
VERIFY(pShaderD3D12, "Unexpected shader object implementation");
156+
VERIFY_EXPR(Stage.Shaders[ShaderIndex] == pShaderD3D12);
155157

156-
auto& LibDesc = *TempPool.Construct<D3D12_DXIL_LIBRARY_DESC>();
157-
auto& ExportDesc = *TempPool.Construct<D3D12_EXPORT_DESC>();
158-
const auto* pShaderD3D12 = ClassPtrCast<ShaderD3D12Impl>(pShader);
159-
const auto& pBlob = Stage.ByteCodes[ShaderIndex];
158+
auto& LibDesc = *TempPool.Construct<D3D12_DXIL_LIBRARY_DESC>();
159+
auto& ExportDesc = *TempPool.Construct<D3D12_EXPORT_DESC>();
160+
const auto& pBlob = Stage.ByteCodes[ShaderIndex];
160161
++ShaderIndex;
161162

162163
LibDesc.DXILLibrary.BytecodeLength = pBlob->GetBufferSize();

Graphics/GraphicsEngineVulkan/src/PipelineStateVkImpl.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,10 @@ std::vector<VkRayTracingShaderGroupCreateInfoKHR> BuildRTShaderGroupDescription(
370370
if (pShader == nullptr)
371371
return VK_SHADER_UNUSED_KHR;
372372

373-
const auto ShaderType = pShader->GetDesc().ShaderType;
373+
RefCntAutoPtr<ShaderVkImpl> pShaderVk{const_cast<IShader*>(pShader), ShaderVkImpl::IID_InternalImpl};
374+
VERIFY(pShaderVk, "Unexpected shader object implementation");
375+
376+
const auto ShaderType = pShaderVk->GetDesc().ShaderType;
374377
// Shader modules are initialized in the same order by InitPipelineShaderStages().
375378
uint32_t idx = 0;
376379
for (const auto& Stage : ShaderStages)
@@ -379,18 +382,18 @@ std::vector<VkRayTracingShaderGroupCreateInfoKHR> BuildRTShaderGroupDescription(
379382
{
380383
for (Uint32 i = 0; i < Stage.Shaders.size(); ++i, ++idx)
381384
{
382-
if (Stage.Shaders[i] == pShader)
385+
if (Stage.Shaders[i] == pShaderVk)
383386
return idx;
384387
}
385-
UNEXPECTED("Unable to find shader '", pShader->GetDesc().Name, "' in the shader stage. This should never happen and is a bug.");
388+
UNEXPECTED("Unable to find shader '", pShaderVk->GetDesc().Name, "' in the shader stage. This should never happen and is a bug.");
386389
return VK_SHADER_UNUSED_KHR;
387390
}
388391
else
389392
{
390393
idx += static_cast<Uint32>(Stage.Count());
391394
}
392395
}
393-
UNEXPECTED("Unable to find corresponding shader stage for shader '", pShader->GetDesc().Name, "'. This should never happen and is a bug.");
396+
UNEXPECTED("Unable to find corresponding shader stage for shader '", pShaderVk->GetDesc().Name, "'. This should never happen and is a bug.");
394397
return VK_SHADER_UNUSED_KHR;
395398
};
396399

0 commit comments

Comments
 (0)