Skip to content

Commit 3b72ff3

Browse files
ConvertUBOToPushConstantPass: better handle edge cases
1 parent 973c29e commit 3b72ff3

File tree

2 files changed

+45
-33
lines changed

2 files changed

+45
-33
lines changed

Graphics/ShaderTools/src/ConvertUBOToPushConstant.cpp

Lines changed: 39 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ class ConvertUBOToPushConstantPass : public spvtools::opt::Pass
131131

132132
// Get the pointee type ID and verify it has Block decoration
133133
uint32_t pointee_type_id = ptr_type_inst->GetSingleWordInOperand(1);
134-
if (HasBlockDecoration(pointee_type_id))
134+
if (IsUBOBlockType(pointee_type_id))
135135
{
136136
// Found a UniformBuffer!
137137
target_var = var_inst;
@@ -172,7 +172,7 @@ class ConvertUBOToPushConstantPass : public spvtools::opt::Pass
172172
if (pointee_type_id == struct_type_id)
173173
{
174174
// Verify it has Block decoration
175-
if (HasBlockDecoration(pointee_type_id))
175+
if (IsUBOBlockType(pointee_type_id))
176176
{
177177
// Found a UniformBuffer!
178178
target_var = &inst;
@@ -239,10 +239,10 @@ class ConvertUBOToPushConstantPass : public spvtools::opt::Pass
239239
users.push_back(user);
240240
});
241241

242-
std::unordered_set<uint32_t> seen;
242+
std::unordered_set<uint32_t> visited;
243243
for (spvtools::opt::Instruction* user : users)
244244
{
245-
modified |= PropagateStorageClass(*user, seen);
245+
modified |= PropagateStorageClass(*user, visited);
246246
}
247247

248248
// Remove Binding and DescriptorSet decorations from the variable
@@ -270,39 +270,36 @@ class ConvertUBOToPushConstantPass : public spvtools::opt::Pass
270270
private:
271271
// Recursively updates the storage class of pointer types used by instructions
272272
// that reference the target variable.
273-
bool PropagateStorageClass(spvtools::opt::Instruction& inst, std::unordered_set<uint32_t>& seen) const
273+
bool PropagateStorageClass(spvtools::opt::Instruction& inst, std::unordered_set<uint32_t>& visited)
274274
{
275275
if (!IsPointerResultType(inst))
276276
{
277277
return false;
278278
}
279279

280+
// Use a "visited" set keyed by result_id for ANY pointer-producing instruction.
281+
// This avoids infinite recursion in pointer SSA loops.
282+
if (inst.result_id() != 0)
283+
{
284+
if (!visited.insert(inst.result_id()).second)
285+
return false;
286+
}
287+
280288
// Already has the correct storage class
281289
if (IsPointerToStorageClass(inst, spv::StorageClass::PushConstant))
282290
{
283-
if (inst.opcode() == spv::Op::OpPhi)
284-
{
285-
if (!seen.insert(inst.result_id()).second)
286-
{
287-
return false;
288-
}
289-
}
290-
291-
bool modified = false;
292291
std::vector<spvtools::opt::Instruction*> users;
293292
get_def_use_mgr()->ForEachUser(&inst, [&users](spvtools::opt::Instruction* user) {
294293
users.push_back(user);
295294
});
295+
296+
bool modified = false;
296297
for (spvtools::opt::Instruction* user : users)
297298
{
298-
if (PropagateStorageClass(*user, seen))
299+
if (PropagateStorageClass(*user, visited))
299300
modified = true;
300301
}
301302

302-
if (inst.opcode() == spv::Op::OpPhi)
303-
{
304-
seen.erase(inst.result_id());
305-
}
306303
return modified;
307304
}
308305

@@ -318,6 +315,9 @@ class ConvertUBOToPushConstantPass : public spvtools::opt::Pass
318315
case spv::Op::OpCopyObject:
319316
case spv::Op::OpPhi:
320317
case spv::Op::OpSelect:
318+
case spv::Op::OpBitcast:
319+
case spv::Op::OpUndef:
320+
case spv::Op::OpConstantNull:
321321
ChangeResultStorageClass(inst);
322322
{
323323
std::vector<spvtools::opt::Instruction*> users;
@@ -326,7 +326,7 @@ class ConvertUBOToPushConstantPass : public spvtools::opt::Pass
326326
});
327327
for (spvtools::opt::Instruction* user : users)
328328
{
329-
PropagateStorageClass(*user, seen);
329+
PropagateStorageClass(*user, visited);
330330
}
331331
}
332332
return true;
@@ -343,7 +343,6 @@ class ConvertUBOToPushConstantPass : public spvtools::opt::Pass
343343
case spv::Op::OpCopyMemory:
344344
case spv::Op::OpCopyMemorySized:
345345
case spv::Op::OpImageTexelPointer:
346-
case spv::Op::OpBitcast:
347346
case spv::Op::OpVariable:
348347
// These don't produce pointer results that need updating,
349348
// or the result type is independent of the operand's storage class.
@@ -358,20 +357,21 @@ class ConvertUBOToPushConstantPass : public spvtools::opt::Pass
358357
}
359358

360359
// Changes the result type of an instruction to use the new storage class.
361-
void ChangeResultStorageClass(spvtools::opt::Instruction& inst) const
360+
void ChangeResultStorageClass(spvtools::opt::Instruction& inst)
362361
{
363362
spvtools::opt::analysis::TypeManager* type_mgr = context()->get_type_mgr();
364363
spvtools::opt::Instruction* result_type_inst = get_def_use_mgr()->GetDef(inst.type_id());
365364

366-
if (result_type_inst->opcode() != spv::Op::OpTypePointer)
367-
{
365+
if (result_type_inst == nullptr || result_type_inst->opcode() != spv::Op::OpTypePointer)
368366
return;
369-
}
370367

371368
uint32_t pointee_type_id = result_type_inst->GetSingleWordInOperand(1);
372369
uint32_t new_result_type_id =
373370
type_mgr->FindPointerToType(pointee_type_id, spv::StorageClass::PushConstant);
374371

372+
if (new_result_type_id == 0)
373+
return;
374+
375375
inst.SetResultType(new_result_type_id);
376376
context()->UpdateDefUse(&inst);
377377
}
@@ -407,16 +407,23 @@ class ConvertUBOToPushConstantPass : public spvtools::opt::Pass
407407
return pointer_storage_class == storage_class;
408408
}
409409

410-
// Checks if a type has the Block decoration, which identifies it as a UBO struct type.
411-
bool HasBlockDecoration(uint32_t type_id) const
410+
bool HasDecoration(uint32_t id, spv::Decoration deco) const
412411
{
413-
bool has_block = false;
412+
bool found = false;
414413
get_decoration_mgr()->ForEachDecoration(
415-
type_id, static_cast<uint32_t>(spv::Decoration::Block),
416-
[&has_block](const spvtools::opt::Instruction&) {
417-
has_block = true;
414+
id, static_cast<uint32_t>(deco),
415+
[&found](const spvtools::opt::Instruction&) {
416+
found = true;
418417
});
419-
return has_block;
418+
return found;
419+
}
420+
421+
// Checks if a type has the Block decoration (but not the BufferBlock),
422+
// which identifies it as a UBO struct type.
423+
bool IsUBOBlockType(uint32_t type_id) const
424+
{
425+
return HasDecoration(type_id, spv::Decoration::Block) &&
426+
!HasDecoration(type_id, spv::Decoration::BufferBlock);
420427
}
421428

422429
std::string m_BlockName;

Tests/DiligentCoreTest/assets/shaders/SPIRV/PushConstants.psh

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,15 @@ struct PushConstants_t
1111
//note that cbuffer PushConstants is not allowed in DXC, but ConstantBuffer<PushConstants_t> is allowed
1212
[[vk::push_constant]] ConstantBuffer<PushConstants_t> PushConstants;
1313

14+
float GetScale(ConstantBuffer<PushConstants_t> PC)
15+
{
16+
return PC.g_Scale;
17+
}
18+
1419
float4 main() : SV_Target
1520
{
1621
// Use push constant data through the structure instance
17-
float4 result = PushConstants.g_Color * PushConstants.g_Scale;
22+
float4 result = PushConstants.g_Color * GetScale(PushConstants);
1823

1924
// Apply offset to result (simplified example)
2025
result.xy += PushConstants.g_Offset;

0 commit comments

Comments
 (0)