Skip to content

Commit b882f3e

Browse files
committed
Add PatchShaderConvertUniformBufferToPushConstant
1 parent 8d57cdf commit b882f3e

File tree

5 files changed

+167
-65
lines changed

5 files changed

+167
-65
lines changed

Graphics/GraphicsEngineVulkan/include/PipelineStateVkImpl.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,9 @@ class PipelineStateVkImpl final : public PipelineStateBase<EngineVkImplTraits>
133133
void InitializePipeline(const ComputePipelineStateCreateInfo& CreateInfo);
134134
void InitializePipeline(const RayTracingPipelineStateCreateInfo& CreateInfo);
135135

136-
void InitPushConstantInfo(const TShaderStages& ShaderStages, PushConstantInfoVk& PushConstant) noexcept(false);
136+
void InitPushConstantInfoFromSignatures(PushConstantInfoVk& PushConstant) const noexcept(false);
137+
138+
void PatchShaderConvertUniformBufferToPushConstant(TShaderStages& ShaderStages) const noexcept(false);
137139

138140
// TPipelineStateBase::Construct needs access to InitializePipeline
139141
friend TPipelineStateBase;

Graphics/GraphicsEngineVulkan/include/ShaderVkImpl.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ class ShaderVkImpl final : public ShaderBase<EngineVkImplTraits>
110110
Size = m_SPIRV.size() * sizeof(m_SPIRV[0]);
111111
}
112112

113+
void SetSPIRV(const std::vector<uint32_t>& SPIRV) noexcept(false);
114+
115+
void CreateSPIRVShaderResources() noexcept(false);
116+
113117
private:
114118
void Initialize(const ShaderCreateInfo& ShaderCI,
115119
const CreateInfo& VkShaderCI) noexcept(false);
@@ -119,6 +123,8 @@ class ShaderVkImpl final : public ShaderBase<EngineVkImplTraits>
119123

120124
std::string m_EntryPoint;
121125
std::vector<uint32_t> m_SPIRV;
126+
127+
bool m_LoadConstantBufferReflection = false;
122128
};
123129

124130
} // namespace Diligent

Graphics/GraphicsEngineVulkan/src/PipelineStateVkImpl.cpp

Lines changed: 114 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,7 @@
4343
#include "EngineMemory.h"
4444
#include "StringTools.hpp"
4545

46-
#if !DILIGENT_NO_HLSL
47-
# include "SPIRVTools.hpp"
48-
#endif
46+
#include "SPIRVTools.hpp"
4947

5048
namespace Diligent
5149
{
@@ -854,33 +852,43 @@ void PipelineStateVkImpl::RemapOrVerifyShaderResources(
854852
}
855853
}
856854

857-
void PipelineStateVkImpl::InitPushConstantInfo(const TShaderStages& ShaderStages, PushConstantInfoVk& PushConstant) noexcept(false)
855+
void PipelineStateVkImpl::InitPushConstantInfoFromSignatures(PushConstantInfoVk& PushConstant) const noexcept(false)
858856
{
859-
for (const ShaderStageInfo& Stage : ShaderStages)
857+
// Iterate through all signatures to find resources marked as VULKAN_PUSH_CONSTANT
858+
for (Uint32 s = 0; s < m_SignatureCount; ++s)
860859
{
861-
for (const ShaderVkImpl* pShader : Stage.Shaders)
860+
const PipelineResourceSignatureVkImpl* pSignature = m_Signatures[s];
861+
if (pSignature == nullptr)
862+
continue;
863+
864+
for (Uint32 r = 0; r < pSignature->GetTotalResourceCount(); ++r)
862865
{
863-
const auto& pShaderResources = pShader->GetShaderResources();
864-
if (pShaderResources && pShaderResources->GetNumPushConstants() > 0)
866+
const PipelineResourceDesc& ResDesc = pSignature->GetResourceDesc(r);
867+
868+
// Check if this resource is marked as a Vulkan push constant
869+
if ((ResDesc.Flags & PIPELINE_RESOURCE_FLAG_VULKAN_PUSH_CONSTANT) != 0)
865870
{
866-
// There should be at most one push constant block per shader
867-
const SPIRVShaderResourceAttribs& PCAttribs = pShaderResources->GetPushConstant(0);
868-
const Uint32 PCSize = PCAttribs.BufferStaticSize;
871+
// For push constants, ArraySize contains the number of 32-bit constants
872+
const Uint32 PCSize = ResDesc.ArraySize * sizeof(Uint32);
869873

870874
if (PushConstant.Size == 0)
871875
{
872-
// First shader with push constants - record the size
876+
// First push constant resource - record the size
873877
PushConstant.Size = PCSize;
874878
}
875879
else if (PushConstant.Size != PCSize)
876880
{
877-
// Multiple shaders with different push constant sizes
878-
// This is allowed in Vulkan - take the maximum size
881+
// Multiple push constant resources with different sizes
882+
// Take the maximum size (Vulkan allows only one push constant block per pipeline)
879883
PushConstant.Size = std::max(PushConstant.Size, PCSize);
880884
}
881885

882-
// Add this shader stage to the stage flags
883-
PushConstant.StageFlags |= ShaderTypeToVkShaderStageFlagBit(Stage.Type);
886+
// Add shader stages to the stage flags
887+
for (SHADER_TYPE ShaderStages = ResDesc.ShaderStages; ShaderStages != SHADER_TYPE_UNKNOWN;)
888+
{
889+
const SHADER_TYPE ShaderType = ExtractLSB(ShaderStages);
890+
PushConstant.StageFlags |= ShaderTypeToVkShaderStageFlagBit(ShaderType);
891+
}
884892
}
885893
}
886894
}
@@ -900,13 +908,19 @@ void PipelineStateVkImpl::InitPipelineLayout(const PipelineStateCreateInfo& Crea
900908
DvpValidateResourceLimits();
901909
#endif
902910

903-
// Extract push constant information from shaders
904-
// Vulkan allows only one push constant block per pipeline, but it can be accessed from multiple shader stages
911+
// Extract push constant information from signatures
912+
// Vulkan allows only one push constant block per pipeline, but it can be accessed from multiple shader stages.
913+
// We use resource attributes from m_Signatures instead of from shader stages because we may patch SPIRV
914+
// later according to m_Signatures definitions (e.g., converting uniform buffers to push constants).
905915
PushConstantInfoVk PushConstant;
906-
InitPushConstantInfo(ShaderStages, PushConstant);
916+
InitPushConstantInfoFromSignatures(PushConstant);
907917

908918
m_PipelineLayout.Create(GetDevice(), m_Signatures, m_SignatureCount, PushConstant);
909919

920+
// Check if any resource in shader is a uniform buffer but marked as VULKAN_PUSH_CONSTANT in signatures.
921+
// If so, convert the uniform buffer to push constant in SPIRV bytecode.
922+
PatchShaderConvertUniformBufferToPushConstant(ShaderStages);
923+
910924
const bool RemapResources = (CreateInfo.Flags & PSO_CREATE_FLAG_DONT_REMAP_SHADER_RESOURCES) == 0;
911925
const bool VerifyBindings = !RemapResources && ((InternalFlags & PSO_CREATE_INTERNAL_FLAG_NO_SHADER_REFLECTION) == 0);
912926
if (RemapResources || VerifyBindings)
@@ -933,6 +947,87 @@ void PipelineStateVkImpl::InitPipelineLayout(const PipelineStateCreateInfo& Crea
933947
}
934948
}
935949

950+
void PipelineStateVkImpl::PatchShaderConvertUniformBufferToPushConstant(TShaderStages& ShaderStages) const noexcept(false)
951+
{
952+
// Build a set of resource names that need to be converted to push constants
953+
std::unordered_set<std::string> PushConstantNames;
954+
for (Uint32 s = 0; s < m_SignatureCount; ++s)
955+
{
956+
const PipelineResourceSignatureVkImpl* pSignature = m_Signatures[s];
957+
if (pSignature == nullptr)
958+
continue;
959+
960+
for (Uint32 r = 0; r < pSignature->GetTotalResourceCount(); ++r)
961+
{
962+
const PipelineResourceDesc& ResDesc = pSignature->GetResourceDesc(r);
963+
if ((ResDesc.Flags & PIPELINE_RESOURCE_FLAG_VULKAN_PUSH_CONSTANT) != 0)
964+
{
965+
PushConstantNames.insert(ResDesc.Name);
966+
}
967+
}
968+
}
969+
970+
if (PushConstantNames.empty())
971+
return;
972+
973+
// For each shader stage, check if any uniform buffer needs to be patched
974+
for (ShaderStageInfo& Stage : ShaderStages)
975+
{
976+
for (size_t i = 0; i < Stage.Shaders.size(); ++i)
977+
{
978+
ShaderVkImpl* pShader = const_cast<ShaderVkImpl*>(Stage.Shaders[i]);
979+
980+
bool ShouldPatchUniformBuffer = false;
981+
std::string PatchingUniformBufferName;
982+
{
983+
const SPIRVShaderResources* pRes = pShader->GetShaderResources().get();
984+
985+
if (pRes == nullptr)
986+
continue;
987+
988+
// Check each uniform buffer in the shader
989+
for (Uint32 ub = 0; ub < pRes->GetNumUBs(); ++ub)
990+
{
991+
const SPIRVShaderResourceAttribs& UBAttribs = pRes->GetUB(ub);
992+
993+
// If this uniform buffer is marked as push constant in the signature,
994+
// convert it in the SPIRV bytecode
995+
if (PushConstantNames.count(UBAttribs.Name) > 0)
996+
{
997+
PatchingUniformBufferName = UBAttribs.Name;
998+
ShouldPatchUniformBuffer = true;
999+
break;
1000+
}
1001+
}
1002+
}
1003+
1004+
if(ShouldPatchUniformBuffer)
1005+
{
1006+
const std::vector<uint32_t>& SPIRV = Stage.SPIRVs[i];
1007+
std::vector<uint32_t> ConvertedSPIRV = PatchSPIRVConvertUniformBufferToPushConstant(
1008+
SPIRV,
1009+
SPV_ENV_MAX, // Auto-detect target environment
1010+
PatchingUniformBufferName);
1011+
1012+
if (!ConvertedSPIRV.empty())
1013+
{
1014+
Stage.SPIRVs[i] = ConvertedSPIRV;
1015+
1016+
pShader->SetSPIRV(ConvertedSPIRV);
1017+
1018+
//Reconstruct shader resources from SPIRV
1019+
pShader->CreateSPIRVShaderResources();
1020+
}
1021+
else
1022+
{
1023+
LOG_ERROR_MESSAGE("Failed to convert uniform buffer '", PatchingUniformBufferName,
1024+
"' to push constant in shader '", pShader->GetDesc().Name, "'");
1025+
}
1026+
}
1027+
}
1028+
}
1029+
}
1030+
9361031
template <typename PSOCreateInfoType>
9371032
PipelineStateVkImpl::TShaderStages PipelineStateVkImpl::InitInternalObjects(
9381033
const PSOCreateInfoType& CreateInfo,

Graphics/GraphicsEngineVulkan/src/ShaderVariableManagerVk.cpp

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -667,30 +667,17 @@ void ShaderVariableManagerVk::SetInlineConstants(Uint32 ResIndex,
667667
const bool IsPushConstant = (ResDesc.Flags & PIPELINE_RESOURCE_FLAG_VULKAN_PUSH_CONSTANT) != 0;
668668
if (IsPushConstant)
669669
{
670-
const InlineConstantBufferAttribsVk* pInlineCBAttribs = m_pSignature->GetInlineConstantBufferAttribs();
671-
const Uint32 NumInlineCBAttribs = m_pSignature->GetNumInlineConstantBufferAttribs();
672-
673-
// For both Signature's static cache and SRB cache, push constant data is stored
674-
// in the resource cache via GetPushConstantDataPtr()
675-
for (Uint32 i = 0; i < NumInlineCBAttribs; ++i)
670+
// Get the data pointer from the resource cache
671+
void* pPushConstantData = m_ResourceCache.GetPushConstantDataPtr(ResIndex);
672+
if (pPushConstantData != nullptr)
676673
{
677-
const InlineConstantBufferAttribsVk& InlineCBAttr = pInlineCBAttribs[i];
678-
if (!InlineCBAttr.IsPushConstant)
679-
continue;
680-
681-
if (InlineCBAttr.ResIndex == ResIndex)
682-
{
683-
// Get the data pointer from the resource cache
684-
void* pPushConstantData = m_ResourceCache.GetPushConstantDataPtr(InlineCBAttr.ResIndex);
685-
if (pPushConstantData != nullptr)
686-
{
687-
Uint32* pDstConstants = reinterpret_cast<Uint32*>(pPushConstantData);
688-
memcpy(pDstConstants + FirstConstant, pConstants, NumConstants * sizeof(Uint32));
689-
}
690-
return;
691-
}
674+
Uint32* pDstConstants = reinterpret_cast<Uint32*>(pPushConstantData);
675+
memcpy(pDstConstants + FirstConstant, pConstants, NumConstants * sizeof(Uint32));
676+
}
677+
else
678+
{
679+
UNEXPECTED("Push constant buffer not found");
692680
}
693-
UNEXPECTED("Push constant buffer not found");
694681
}
695682
else
696683
{

Graphics/GraphicsEngineVulkan/src/ShaderVkImpl.cpp

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -229,37 +229,17 @@ void ShaderVkImpl::Initialize(const ShaderCreateInfo& ShaderCI,
229229

230230
// We cannot create shader module here because resource bindings are assigned when
231231
// pipeline state is created
232-
232+
m_LoadConstantBufferReflection = ShaderCI.LoadConstantBufferReflection;
233233
// Load shader resources
234234
if (!m_SPIRV.empty())
235235
{
236236
if ((ShaderCI.CompileFlags & SHADER_COMPILE_FLAG_SKIP_REFLECTION) == 0)
237237
{
238-
IMemoryAllocator& Allocator = GetRawAllocator();
239-
240-
std::unique_ptr<void, STDDeleterRawMem<void>> pRawMem{
241-
ALLOCATE(Allocator, "Memory for SPIRVShaderResources", SPIRVShaderResources, 1),
242-
STDDeleterRawMem<void>(Allocator),
243-
};
244-
const bool LoadShaderInputs = m_Desc.ShaderType == SHADER_TYPE_VERTEX;
245-
new (pRawMem.get()) SPIRVShaderResources // May throw
246-
{
247-
Allocator,
248-
m_SPIRV,
249-
m_Desc,
250-
m_Desc.UseCombinedTextureSamplers ? m_Desc.CombinedSamplerSuffix : nullptr,
251-
LoadShaderInputs,
252-
ShaderCI.LoadConstantBufferReflection,
253-
m_EntryPoint //
254-
};
238+
CreateSPIRVShaderResources();
239+
255240
VERIFY_EXPR(ShaderCI.ByteCode != nullptr || m_EntryPoint == ShaderCI.EntryPoint ||
256241
(m_EntryPoint == "main" && (ShaderCI.CompileFlags & SHADER_COMPILE_FLAG_HLSL_TO_SPIRV_VIA_GLSL) != 0));
257-
m_pShaderResources.reset(static_cast<SPIRVShaderResources*>(pRawMem.release()), STDDeleterRawMem<SPIRVShaderResources>(Allocator));
258242

259-
if (LoadShaderInputs && m_pShaderResources->IsHLSLSource())
260-
{
261-
m_pShaderResources->MapHLSLVertexShaderInputs(m_SPIRV);
262-
}
263243
}
264244
else
265245
{
@@ -270,6 +250,38 @@ void ShaderVkImpl::Initialize(const ShaderCreateInfo& ShaderCI,
270250
}
271251
}
272252

253+
void ShaderVkImpl::SetSPIRV(const std::vector<uint32_t>& SPIRV) noexcept(false)
254+
{
255+
m_SPIRV = SPIRV;
256+
}
257+
258+
void ShaderVkImpl::CreateSPIRVShaderResources() noexcept(false)
259+
{
260+
IMemoryAllocator& Allocator = GetRawAllocator();
261+
262+
std::unique_ptr<void, STDDeleterRawMem<void>> pRawMem{
263+
ALLOCATE(Allocator, "Memory for SPIRVShaderResources", SPIRVShaderResources, 1),
264+
STDDeleterRawMem<void>(Allocator),
265+
};
266+
const bool LoadShaderInputs = m_Desc.ShaderType == SHADER_TYPE_VERTEX;
267+
new (pRawMem.get()) SPIRVShaderResources // May throw
268+
{
269+
Allocator,
270+
m_SPIRV,
271+
m_Desc,
272+
m_Desc.UseCombinedTextureSamplers ? m_Desc.CombinedSamplerSuffix : nullptr,
273+
LoadShaderInputs,
274+
m_LoadConstantBufferReflection,
275+
m_EntryPoint //
276+
};
277+
278+
m_pShaderResources.reset(static_cast<SPIRVShaderResources*>(pRawMem.release()), STDDeleterRawMem<SPIRVShaderResources>(Allocator));
279+
280+
if (LoadShaderInputs && m_pShaderResources->IsHLSLSource())
281+
{
282+
m_pShaderResources->MapHLSLVertexShaderInputs(m_SPIRV);
283+
}
284+
}
273285

274286
ShaderVkImpl::ShaderVkImpl(IReferenceCounters* pRefCounters,
275287
RenderDeviceVkImpl* pRenderDeviceVk,

0 commit comments

Comments
 (0)