Skip to content

Commit d1e2ce3

Browse files
committed
Port ConvertUBOToPushConstantPass to SPIRVTools.cpp
1 parent 63411b2 commit d1e2ce3

File tree

2 files changed

+353
-4
lines changed

2 files changed

+353
-4
lines changed

Graphics/ShaderTools/CMakeLists.txt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,17 @@ if(ENABLE_SPIRV)
180180
PRIVATE
181181
SPIRV-Tools-opt
182182
)
183+
# Add SPIRV-Tools internal headers path for custom pass implementation
184+
# We need both the source directory (for internal headers like pass.h)
185+
# and the binary directory (for generated headers like NonSemanticShaderDebugInfo100.h)
186+
get_target_property(SPIRV_TOOLS_SOURCE_DIR SPIRV-Tools-opt SOURCE_DIR)
187+
get_target_property(SPIRV_TOOLS_BINARY_DIR SPIRV-Tools-opt BINARY_DIR)
188+
get_filename_component(SPIRV_TOOLS_ROOT_DIR "${SPIRV_TOOLS_SOURCE_DIR}/../.." ABSOLUTE)
189+
get_filename_component(SPIRV_TOOLS_ROOT_BINARY_DIR "${SPIRV_TOOLS_BINARY_DIR}/../.." ABSOLUTE)
190+
target_include_directories(Diligent-ShaderTools PRIVATE
191+
${SPIRV_TOOLS_ROOT_DIR}
192+
${SPIRV_TOOLS_ROOT_BINARY_DIR}
193+
)
183194
target_compile_definitions(Diligent-ShaderTools PRIVATE USE_SPIRV_TOOLS=1)
184195
endif()
185196

Graphics/ShaderTools/src/SPIRVTools.cpp

Lines changed: 342 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@
2929

3030
#include "spirv-tools/optimizer.hpp"
3131

32+
// SPIRV-Tools internal headers for custom pass implementation
33+
#include "source/opt/pass.h"
34+
#include "source/opt/ir_context.h"
35+
#include "source/opt/type_manager.h"
36+
#include "source/opt/decoration_manager.h"
37+
3238
namespace Diligent
3339
{
3440

@@ -108,6 +114,339 @@ spv_target_env SpvTargetEnvFromSPIRV(const std::vector<uint32_t>& SPIRV)
108114

109115
#undef SPV_SPIRV_VERSION_WORD
110116

117+
// A pass that converts a uniform buffer variable to a push constant.
118+
// This pass:
119+
// 1. Finds the variable with the specified block name
120+
// 2. Changes its storage class from Uniform to PushConstant
121+
// 3. Updates all pointer types that reference this variable
122+
// 4. Removes Binding and DescriptorSet decorations
123+
class ConvertUBOToPushConstantPass : public spvtools::opt::Pass
124+
{
125+
public:
126+
explicit ConvertUBOToPushConstantPass(const std::string& block_name) :
127+
m_BlockName{block_name}
128+
{}
129+
130+
const char* name() const override { return "convert-ubo-to-push-constant"; }
131+
132+
Status Process() override
133+
{
134+
bool modified = false;
135+
136+
// Find the ID that matches the block name by searching OpName instructions
137+
// This could be either a variable ID or a type ID (struct type)
138+
uint32_t named_id = 0;
139+
for (auto& debug_inst : context()->module()->debugs2())
140+
{
141+
if (debug_inst.opcode() == spv::Op::OpName &&
142+
debug_inst.GetOperand(1).AsString() == m_BlockName)
143+
{
144+
named_id = debug_inst.GetOperand(0).AsId();
145+
break;
146+
}
147+
}
148+
149+
if (named_id == 0)
150+
{
151+
// Block name not found
152+
return Status::SuccessWithoutChange;
153+
}
154+
155+
// Check if the named_id is a variable or a type
156+
spvtools::opt::Instruction* target_var = nullptr;
157+
spvtools::opt::Instruction* named_inst = get_def_use_mgr()->GetDef(named_id);
158+
159+
if (named_inst == nullptr)
160+
{
161+
return Status::SuccessWithoutChange;
162+
}
163+
164+
if (named_inst->opcode() == spv::Op::OpVariable)
165+
{
166+
// The name refers directly to a variable
167+
target_var = named_inst;
168+
}
169+
else if (named_inst->opcode() == spv::Op::OpTypeStruct)
170+
{
171+
// The name refers to a struct type, we need to find the variable
172+
// that uses a pointer to this struct type with Uniform storage class
173+
uint32_t struct_type_id = named_id;
174+
175+
// Search for a variable that points to this struct type with Uniform storage class
176+
for (auto& inst : context()->types_values())
177+
{
178+
if (inst.opcode() != spv::Op::OpVariable)
179+
{
180+
continue;
181+
}
182+
183+
// Get the pointer type of this variable
184+
spvtools::opt::Instruction* ptr_type = get_def_use_mgr()->GetDef(inst.type_id());
185+
if (ptr_type == nullptr || ptr_type->opcode() != spv::Op::OpTypePointer)
186+
{
187+
continue;
188+
}
189+
190+
// Check storage class is Uniform
191+
spv::StorageClass sc = static_cast<spv::StorageClass>(
192+
ptr_type->GetSingleWordInOperand(0));
193+
if (sc != spv::StorageClass::Uniform)
194+
{
195+
continue;
196+
}
197+
198+
// Check if the pointee type is our struct type
199+
uint32_t pointee_type_id = ptr_type->GetSingleWordInOperand(1);
200+
if (pointee_type_id == struct_type_id)
201+
{
202+
target_var = &inst;
203+
break;
204+
}
205+
}
206+
}
207+
208+
if (target_var == nullptr)
209+
{
210+
// Variable not found
211+
return Status::SuccessWithoutChange;
212+
}
213+
214+
uint32_t target_var_id = target_var->result_id();
215+
216+
// Get the pointer type of the variable
217+
spvtools::opt::Instruction* ptr_type_inst = get_def_use_mgr()->GetDef(target_var->type_id());
218+
if (ptr_type_inst == nullptr || ptr_type_inst->opcode() != spv::Op::OpTypePointer)
219+
{
220+
return Status::SuccessWithoutChange;
221+
}
222+
223+
// Check if the storage class is Uniform
224+
spv::StorageClass storage_class =
225+
static_cast<spv::StorageClass>(ptr_type_inst->GetSingleWordInOperand(0));
226+
if (storage_class != spv::StorageClass::Uniform)
227+
{
228+
// Not a uniform buffer, nothing to do
229+
return Status::SuccessWithoutChange;
230+
}
231+
232+
// Get the pointee type ID
233+
uint32_t pointee_type_id = ptr_type_inst->GetSingleWordInOperand(1);
234+
235+
// Create or find a pointer type with PushConstant storage class
236+
spvtools::opt::analysis::TypeManager* type_mgr = context()->get_type_mgr();
237+
uint32_t new_ptr_type_id =
238+
type_mgr->FindPointerToType(pointee_type_id, spv::StorageClass::PushConstant);
239+
240+
if (new_ptr_type_id == 0)
241+
{
242+
// Failed to create new pointer type
243+
return Status::Failure;
244+
}
245+
246+
// Ensure the new pointer type is defined before the variable
247+
// FindPointerToType may have created it at the end, we need to move it
248+
spvtools::opt::Instruction* new_ptr_type_inst = get_def_use_mgr()->GetDef(new_ptr_type_id);
249+
if (new_ptr_type_inst != nullptr)
250+
{
251+
// Find the pointee type instruction to insert after it
252+
spvtools::opt::Instruction* pointee_type_inst = get_def_use_mgr()->GetDef(pointee_type_id);
253+
254+
// Check if new_ptr_type_inst is after target_var in the types_values list
255+
bool needs_move = false;
256+
for (auto& inst : context()->types_values())
257+
{
258+
if (&inst == target_var)
259+
{
260+
// Found target_var first, so new_ptr_type_inst is after it
261+
needs_move = true;
262+
break;
263+
}
264+
if (&inst == new_ptr_type_inst)
265+
{
266+
// Found new_ptr_type_inst first, it's in the right position
267+
needs_move = false;
268+
break;
269+
}
270+
}
271+
272+
if (needs_move && pointee_type_inst != nullptr)
273+
{
274+
// Move the new pointer type to right after the pointee type
275+
// InsertAfter will automatically remove it from its current position
276+
new_ptr_type_inst->InsertAfter(pointee_type_inst);
277+
}
278+
}
279+
280+
// Update the variable's type to the new pointer type
281+
target_var->SetResultType(new_ptr_type_id);
282+
283+
// Also update the storage class operand of OpVariable itself
284+
// OpVariable has the storage class as the first operand (index 0)
285+
target_var->SetInOperand(0, {static_cast<uint32_t>(spv::StorageClass::PushConstant)});
286+
287+
context()->UpdateDefUse(target_var);
288+
modified = true;
289+
290+
// Propagate storage class change to all users of this variable
291+
std::set<uint32_t> seen;
292+
std::vector<spvtools::opt::Instruction*> users;
293+
get_def_use_mgr()->ForEachUser(target_var, [&users](spvtools::opt::Instruction* user) {
294+
users.push_back(user);
295+
});
296+
297+
for (spvtools::opt::Instruction* user : users)
298+
{
299+
modified |= PropagateStorageClass(user, &seen);
300+
}
301+
302+
// Remove Binding and DescriptorSet decorations from the variable
303+
auto* deco_mgr = context()->get_decoration_mgr();
304+
deco_mgr->RemoveDecorationsFrom(target_var_id, [](const spvtools::opt::Instruction& inst) {
305+
if (inst.opcode() != spv::Op::OpDecorate)
306+
{
307+
return false;
308+
}
309+
spv::Decoration decoration =
310+
static_cast<spv::Decoration>(inst.GetSingleWordInOperand(1));
311+
return decoration == spv::Decoration::Binding ||
312+
decoration == spv::Decoration::DescriptorSet;
313+
});
314+
315+
return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
316+
}
317+
318+
spvtools::opt::IRContext::Analysis GetPreservedAnalyses() override
319+
{
320+
// This pass modifies types and decorations
321+
return spvtools::opt::IRContext::kAnalysisNone;
322+
}
323+
324+
private:
325+
// Recursively updates the storage class of pointer types used by instructions
326+
// that reference the target variable.
327+
bool PropagateStorageClass(spvtools::opt::Instruction* inst, std::set<uint32_t>* seen)
328+
{
329+
if (!IsPointerResultType(inst))
330+
{
331+
return false;
332+
}
333+
334+
// Already has the correct storage class
335+
if (IsPointerToStorageClass(inst, spv::StorageClass::PushConstant))
336+
{
337+
if (inst->opcode() == spv::Op::OpPhi)
338+
{
339+
if (!seen->insert(inst->result_id()).second)
340+
{
341+
return false;
342+
}
343+
}
344+
345+
bool modified = false;
346+
std::vector<spvtools::opt::Instruction*> users;
347+
get_def_use_mgr()->ForEachUser(inst, [&users](spvtools::opt::Instruction* user) {
348+
users.push_back(user);
349+
});
350+
for (spvtools::opt::Instruction* user : users)
351+
{
352+
modified |= PropagateStorageClass(user, seen);
353+
}
354+
355+
if (inst->opcode() == spv::Op::OpPhi)
356+
{
357+
seen->erase(inst->result_id());
358+
}
359+
return modified;
360+
}
361+
362+
// Handle instructions that produce pointer results
363+
switch (inst->opcode())
364+
{
365+
case spv::Op::OpAccessChain:
366+
case spv::Op::OpPtrAccessChain:
367+
case spv::Op::OpInBoundsAccessChain:
368+
case spv::Op::OpInBoundsPtrAccessChain:
369+
case spv::Op::OpCopyObject:
370+
case spv::Op::OpPhi:
371+
case spv::Op::OpSelect:
372+
ChangeResultStorageClass(inst);
373+
{
374+
std::vector<spvtools::opt::Instruction*> users;
375+
get_def_use_mgr()->ForEachUser(inst, [&users](spvtools::opt::Instruction* user) {
376+
users.push_back(user);
377+
});
378+
for (spvtools::opt::Instruction* user : users)
379+
{
380+
PropagateStorageClass(user, seen);
381+
}
382+
}
383+
return true;
384+
385+
case spv::Op::OpLoad:
386+
case spv::Op::OpStore:
387+
case spv::Op::OpCopyMemory:
388+
case spv::Op::OpCopyMemorySized:
389+
// These don't produce pointer results that need updating
390+
return false;
391+
392+
default:
393+
return false;
394+
}
395+
}
396+
397+
// Changes the result type of an instruction to use the new storage class.
398+
void ChangeResultStorageClass(spvtools::opt::Instruction* inst)
399+
{
400+
spvtools::opt::analysis::TypeManager* type_mgr = context()->get_type_mgr();
401+
spvtools::opt::Instruction* result_type_inst = get_def_use_mgr()->GetDef(inst->type_id());
402+
403+
if (result_type_inst->opcode() != spv::Op::OpTypePointer)
404+
{
405+
return;
406+
}
407+
408+
uint32_t pointee_type_id = result_type_inst->GetSingleWordInOperand(1);
409+
uint32_t new_result_type_id =
410+
type_mgr->FindPointerToType(pointee_type_id, spv::StorageClass::PushConstant);
411+
412+
inst->SetResultType(new_result_type_id);
413+
context()->UpdateDefUse(inst);
414+
}
415+
416+
// Checks if the instruction result type is a pointer.
417+
bool IsPointerResultType(spvtools::opt::Instruction* inst)
418+
{
419+
if (inst->type_id() == 0)
420+
{
421+
return false;
422+
}
423+
424+
spvtools::opt::Instruction* type_def = get_def_use_mgr()->GetDef(inst->type_id());
425+
return type_def != nullptr && type_def->opcode() == spv::Op::OpTypePointer;
426+
}
427+
428+
// Checks if the instruction result type is a pointer to the specified storage class.
429+
bool IsPointerToStorageClass(spvtools::opt::Instruction* inst, spv::StorageClass storage_class)
430+
{
431+
if (inst->type_id() == 0)
432+
{
433+
return false;
434+
}
435+
436+
spvtools::opt::Instruction* type_def = get_def_use_mgr()->GetDef(inst->type_id());
437+
if (type_def == nullptr || type_def->opcode() != spv::Op::OpTypePointer)
438+
{
439+
return false;
440+
}
441+
442+
spv::StorageClass pointer_storage_class =
443+
static_cast<spv::StorageClass>(type_def->GetSingleWordInOperand(0));
444+
return pointer_storage_class == storage_class;
445+
}
446+
447+
std::string m_BlockName;
448+
};
449+
111450
} // namespace
112451

113452
std::vector<uint32_t> OptimizeSPIRV(const std::vector<uint32_t>& SrcSPIRV, spv_target_env TargetEnv, SPIRV_OPTIMIZATION_FLAGS Passes)
@@ -168,10 +507,9 @@ std::vector<uint32_t> PatchSPIRVConvertUniformBufferToPushConstant(
168507

169508
optimizer.SetMessageConsumer(SpvOptimizerMessageConsumer);
170509

171-
// Register the pass to convert UBO to push constant
172-
optimizer.RegisterPass(spvtools::CreateConvertUBOToPushConstantPass(BlockName));
173-
//optimizer.RegisterPass(spvtools::CreateDeadVariableEliminationPass());
174-
//optimizer.RegisterPerformancePasses();
510+
// Register the pass to convert UBO to push constant using custom out-of-tree pass
511+
optimizer.RegisterPass(spvtools::Optimizer::PassToken(
512+
std::make_unique<ConvertUBOToPushConstantPass>(BlockName)));
175513

176514
spvtools::OptimizerOptions options;
177515
#ifndef DILIGENT_DEVELOPMENT

0 commit comments

Comments
 (0)