@@ -555,7 +555,8 @@ void VerifyResourceMerge(const char* PSOName,
555555PipelineStateVkImpl::ShaderStageInfo::ShaderStageInfo (const ShaderVkImpl* pShader) :
556556 Type{pShader->GetDesc ().ShaderType },
557557 Shaders{pShader},
558- SPIRVs{pShader->GetSPIRV ()}
558+ SPIRVs{pShader->GetSPIRV ()},
559+ ShaderResources{pShader->GetShaderResources ()}
559560{}
560561
561562void PipelineStateVkImpl::ShaderStageInfo::Append (const ShaderVkImpl* pShader)
@@ -578,6 +579,7 @@ void PipelineStateVkImpl::ShaderStageInfo::Append(const ShaderVkImpl* pShader)
578579 }
579580 Shaders.push_back (pShader);
580581 SPIRVs.push_back (pShader->GetSPIRV ());
582+ ShaderResources.push_back (pShader->GetShaderResources ());
581583}
582584
583585size_t PipelineStateVkImpl::ShaderStageInfo::Count () const
@@ -691,16 +693,19 @@ void PipelineStateVkImpl::RemapOrVerifyShaderResources(
691693 {
692694 const std::vector<const ShaderVkImpl*>& Shaders = ShaderStages[s].Shaders ;
693695 std::vector<std::vector<uint32_t >>& SPIRVs = ShaderStages[s].SPIRVs ;
694- const SHADER_TYPE ShaderType = ShaderStages[s].Type ;
696+ const SHADER_TYPE ShaderType = ShaderStages[s].Type ;
697+ std::vector<std::shared_ptr<const SPIRVShaderResources>>& ShaderResources = ShaderStages[s].ShaderResources ;
695698
696699 VERIFY_EXPR (Shaders.size () == SPIRVs.size ());
700+ VERIFY_EXPR (Shaders.size () == ShaderResources.size ());
697701
698702 for (size_t i = 0 ; i < Shaders.size (); ++i)
699703 {
700704 const ShaderVkImpl* pShader = Shaders[i];
701705 std::vector<uint32_t >& SPIRV = SPIRVs[i];
706+ const std::shared_ptr<const SPIRVShaderResources>& pShaderResources = ShaderResources[i];
702707
703- const auto & pShaderResources = pShader->GetShaderResources ();
708+ // const auto& pShaderResources = pShader->GetShaderResources();
704709 VERIFY_EXPR (pShaderResources);
705710
706711 if (pDvpShaderResources)
@@ -867,7 +872,7 @@ bool PipelineStateVkImpl::InitPushConstantInfoFromSignatures(PushConstantInfoVk&
867872 PushConstant.StageFlags |= ShaderTypeToVkShaderStageFlagBit (ShaderType);
868873 }
869874
870- // Found push constant from SPIR-V, we're done
875+ // Found push constant from SPIR-V
871876 return true ;
872877 }
873878 }
@@ -904,7 +909,7 @@ bool PipelineStateVkImpl::InitPushConstantInfoFromSignatures(PushConstantInfoVk&
904909 PushConstant.StageFlags |= ShaderTypeToVkShaderStageFlagBit (ShaderType);
905910 }
906911
907- // Found first inline constant, we're done
912+ // Found first inline constant
908913 return true ;
909914 }
910915 }
@@ -937,7 +942,7 @@ void PipelineStateVkImpl::InitPipelineLayout(const PipelineStateCreateInfo& Crea
937942
938943 // If we promoted an inline constant as push constant (not an existing SPIR-V push constant),
939944 // convert the uniform buffer to push constant in SPIRV bytecode.
940- PatchShaderConvertUniformBufferToPushConstant (ShaderStages, PushConstant );
945+ PatchShaderConvertUniformBufferToPushConstant (PushConstant, ShaderStages );
941946
942947 const bool RemapResources = (CreateInfo.Flags & PSO_CREATE_FLAG_DONT_REMAP_SHADER_RESOURCES) == 0 ;
943948 const bool VerifyBindings = !RemapResources && ((InternalFlags & PSO_CREATE_INTERNAL_FLAG_NO_SHADER_REFLECTION) == 0 );
@@ -965,8 +970,8 @@ void PipelineStateVkImpl::InitPipelineLayout(const PipelineStateCreateInfo& Crea
965970 }
966971}
967972
968- void PipelineStateVkImpl::PatchShaderConvertUniformBufferToPushConstant (TShaderStages& ShaderStages,
969- const PushConstantInfoVk& PushConstantInfo ) const noexcept (false )
973+ void PipelineStateVkImpl::PatchShaderConvertUniformBufferToPushConstant (const PushConstantInfoVk& PushConstantInfo,
974+ TShaderStages& ShaderStages ) const noexcept (false )
970975{
971976 // If no push constant was selected, no patching needed
972977 if (PushConstantInfo.SignatureIndex == INVALID_PUSH_CONSTANT_INDEX ||
@@ -987,43 +992,50 @@ void PipelineStateVkImpl::PatchShaderConvertUniformBufferToPushConstant(TShaderS
987992 for (size_t i = 0 ; i < Stage.Shaders .size (); ++i)
988993 {
989994 ShaderVkImpl* pShader = const_cast <ShaderVkImpl*>(Stage.Shaders [i]);
990- const SPIRVShaderResources* pRes = pShader->GetShaderResources ().get ();
991-
992- if (pRes == nullptr )
993- continue ;
994995
995996 // First check if the shader already has this as push constant
996997 bool AlreadyPushConstant = false ;
997- for (Uint32 pc = 0 ; pc < pRes->GetNumPushConstants (); ++pc)
998+
999+ // Check if this shader has a uniform buffer with the push constant name
1000+ bool ShouldPatchUniformBuffer = false ;
9981001 {
999- const SPIRVShaderResourceAttribs& PCAttribs = pRes->GetPushConstant (pc);
1000- if (PCAttribs.Name == PushConstantName)
1002+ const SPIRVShaderResources* pShaderRes = pShader->GetShaderResources ().get ();
1003+
1004+ if (pShaderRes == nullptr )
1005+ continue ;
1006+
1007+ for (Uint32 pc = 0 ; pc < pShaderRes->GetNumPushConstants (); ++pc)
1008+ {
1009+ const SPIRVShaderResourceAttribs& PCAttribs = pShaderRes->GetPushConstant (pc);
1010+ if (PCAttribs.Name == PushConstantName)
1011+ {
1012+ AlreadyPushConstant = true ;
1013+ break ;
1014+ }
1015+ }
1016+
1017+ if (!AlreadyPushConstant)
10011018 {
1002- AlreadyPushConstant = true ;
1003- break ;
1019+ for (Uint32 ub = 0 ; ub < pShaderRes->GetNumUBs (); ++ub)
1020+ {
1021+ const SPIRVShaderResourceAttribs& UBAttribs = pShaderRes->GetUB (ub);
1022+ if (UBAttribs.Name == PushConstantName)
1023+ {
1024+ ShouldPatchUniformBuffer = true ;
1025+ break ;
1026+ }
1027+ }
10041028 }
10051029 }
10061030
10071031 // If already push constant, no conversion needed
10081032 if (AlreadyPushConstant)
10091033 continue ;
10101034
1011- // Check if this shader has a uniform buffer with the push constant name
1012- bool ShouldPatchUniformBuffer = false ;
1013-
1014- for (Uint32 ub = 0 ; ub < pRes->GetNumUBs (); ++ub)
1015- {
1016- const SPIRVShaderResourceAttribs& UBAttribs = pRes->GetUB (ub);
1017- if (UBAttribs.Name == PushConstantName)
1018- {
1019- ShouldPatchUniformBuffer = true ;
1020- break ;
1021- }
1022- }
1023-
10241035 if (ShouldPatchUniformBuffer)
10251036 {
10261037 const std::vector<uint32_t >& SPIRV = Stage.SPIRVs [i];
1038+
10271039 std::vector<uint32_t > PatchedSPIRV = PatchSPIRVConvertUniformBufferToPushConstant (
10281040 SPIRV,
10291041 SPV_ENV_MAX, // Auto-detect target environment
@@ -1032,11 +1044,7 @@ void PipelineStateVkImpl::PatchShaderConvertUniformBufferToPushConstant(TShaderS
10321044 if (!PatchedSPIRV.empty ())
10331045 {
10341046 Stage.SPIRVs [i] = PatchedSPIRV;
1035-
1036- pShader->SetSPIRV (PatchedSPIRV);
1037-
1038- // Reconstruct shader resources from SPIRV
1039- pShader->CreateSPIRVShaderResources ();
1047+ Stage.ShaderResources [i] = pShader->CreateSPIRVShaderResources (PatchedSPIRV);
10401048 }
10411049 else
10421050 {
0 commit comments