|
29 | 29 |
|
30 | 30 | #include "spirv-tools/optimizer.hpp" |
31 | 31 |
|
| 32 | +// SPIRV-Tools internal headers for custom pass implementation |
| 33 | +#include "source/opt/pass.h" |
| 34 | +#include "source/opt/ir_context.h" |
| 35 | +#include "source/opt/type_manager.h" |
| 36 | +#include "source/opt/decoration_manager.h" |
| 37 | + |
32 | 38 | namespace Diligent |
33 | 39 | { |
34 | 40 |
|
@@ -108,6 +114,339 @@ spv_target_env SpvTargetEnvFromSPIRV(const std::vector<uint32_t>& SPIRV) |
108 | 114 |
|
109 | 115 | #undef SPV_SPIRV_VERSION_WORD |
110 | 116 |
|
| 117 | +// A pass that converts a uniform buffer variable to a push constant. |
| 118 | +// This pass: |
| 119 | +// 1. Finds the variable with the specified block name |
| 120 | +// 2. Changes its storage class from Uniform to PushConstant |
| 121 | +// 3. Updates all pointer types that reference this variable |
| 122 | +// 4. Removes Binding and DescriptorSet decorations |
| 123 | +class ConvertUBOToPushConstantPass : public spvtools::opt::Pass |
| 124 | +{ |
| 125 | +public: |
| 126 | + explicit ConvertUBOToPushConstantPass(const std::string& block_name) : |
| 127 | + m_BlockName{block_name} |
| 128 | + {} |
| 129 | + |
| 130 | + const char* name() const override { return "convert-ubo-to-push-constant"; } |
| 131 | + |
| 132 | + Status Process() override |
| 133 | + { |
| 134 | + bool modified = false; |
| 135 | + |
| 136 | + // Find the ID that matches the block name by searching OpName instructions |
| 137 | + // This could be either a variable ID or a type ID (struct type) |
| 138 | + uint32_t named_id = 0; |
| 139 | + for (auto& debug_inst : context()->module()->debugs2()) |
| 140 | + { |
| 141 | + if (debug_inst.opcode() == spv::Op::OpName && |
| 142 | + debug_inst.GetOperand(1).AsString() == m_BlockName) |
| 143 | + { |
| 144 | + named_id = debug_inst.GetOperand(0).AsId(); |
| 145 | + break; |
| 146 | + } |
| 147 | + } |
| 148 | + |
| 149 | + if (named_id == 0) |
| 150 | + { |
| 151 | + // Block name not found |
| 152 | + return Status::SuccessWithoutChange; |
| 153 | + } |
| 154 | + |
| 155 | + // Check if the named_id is a variable or a type |
| 156 | + spvtools::opt::Instruction* target_var = nullptr; |
| 157 | + spvtools::opt::Instruction* named_inst = get_def_use_mgr()->GetDef(named_id); |
| 158 | + |
| 159 | + if (named_inst == nullptr) |
| 160 | + { |
| 161 | + return Status::SuccessWithoutChange; |
| 162 | + } |
| 163 | + |
| 164 | + if (named_inst->opcode() == spv::Op::OpVariable) |
| 165 | + { |
| 166 | + // The name refers directly to a variable |
| 167 | + target_var = named_inst; |
| 168 | + } |
| 169 | + else if (named_inst->opcode() == spv::Op::OpTypeStruct) |
| 170 | + { |
| 171 | + // The name refers to a struct type, we need to find the variable |
| 172 | + // that uses a pointer to this struct type with Uniform storage class |
| 173 | + uint32_t struct_type_id = named_id; |
| 174 | + |
| 175 | + // Search for a variable that points to this struct type with Uniform storage class |
| 176 | + for (auto& inst : context()->types_values()) |
| 177 | + { |
| 178 | + if (inst.opcode() != spv::Op::OpVariable) |
| 179 | + { |
| 180 | + continue; |
| 181 | + } |
| 182 | + |
| 183 | + // Get the pointer type of this variable |
| 184 | + spvtools::opt::Instruction* ptr_type = get_def_use_mgr()->GetDef(inst.type_id()); |
| 185 | + if (ptr_type == nullptr || ptr_type->opcode() != spv::Op::OpTypePointer) |
| 186 | + { |
| 187 | + continue; |
| 188 | + } |
| 189 | + |
| 190 | + // Check storage class is Uniform |
| 191 | + spv::StorageClass sc = static_cast<spv::StorageClass>( |
| 192 | + ptr_type->GetSingleWordInOperand(0)); |
| 193 | + if (sc != spv::StorageClass::Uniform) |
| 194 | + { |
| 195 | + continue; |
| 196 | + } |
| 197 | + |
| 198 | + // Check if the pointee type is our struct type |
| 199 | + uint32_t pointee_type_id = ptr_type->GetSingleWordInOperand(1); |
| 200 | + if (pointee_type_id == struct_type_id) |
| 201 | + { |
| 202 | + target_var = &inst; |
| 203 | + break; |
| 204 | + } |
| 205 | + } |
| 206 | + } |
| 207 | + |
| 208 | + if (target_var == nullptr) |
| 209 | + { |
| 210 | + // Variable not found |
| 211 | + return Status::SuccessWithoutChange; |
| 212 | + } |
| 213 | + |
| 214 | + uint32_t target_var_id = target_var->result_id(); |
| 215 | + |
| 216 | + // Get the pointer type of the variable |
| 217 | + spvtools::opt::Instruction* ptr_type_inst = get_def_use_mgr()->GetDef(target_var->type_id()); |
| 218 | + if (ptr_type_inst == nullptr || ptr_type_inst->opcode() != spv::Op::OpTypePointer) |
| 219 | + { |
| 220 | + return Status::SuccessWithoutChange; |
| 221 | + } |
| 222 | + |
| 223 | + // Check if the storage class is Uniform |
| 224 | + spv::StorageClass storage_class = |
| 225 | + static_cast<spv::StorageClass>(ptr_type_inst->GetSingleWordInOperand(0)); |
| 226 | + if (storage_class != spv::StorageClass::Uniform) |
| 227 | + { |
| 228 | + // Not a uniform buffer, nothing to do |
| 229 | + return Status::SuccessWithoutChange; |
| 230 | + } |
| 231 | + |
| 232 | + // Get the pointee type ID |
| 233 | + uint32_t pointee_type_id = ptr_type_inst->GetSingleWordInOperand(1); |
| 234 | + |
| 235 | + // Create or find a pointer type with PushConstant storage class |
| 236 | + spvtools::opt::analysis::TypeManager* type_mgr = context()->get_type_mgr(); |
| 237 | + uint32_t new_ptr_type_id = |
| 238 | + type_mgr->FindPointerToType(pointee_type_id, spv::StorageClass::PushConstant); |
| 239 | + |
| 240 | + if (new_ptr_type_id == 0) |
| 241 | + { |
| 242 | + // Failed to create new pointer type |
| 243 | + return Status::Failure; |
| 244 | + } |
| 245 | + |
| 246 | + // Ensure the new pointer type is defined before the variable |
| 247 | + // FindPointerToType may have created it at the end, we need to move it |
| 248 | + spvtools::opt::Instruction* new_ptr_type_inst = get_def_use_mgr()->GetDef(new_ptr_type_id); |
| 249 | + if (new_ptr_type_inst != nullptr) |
| 250 | + { |
| 251 | + // Find the pointee type instruction to insert after it |
| 252 | + spvtools::opt::Instruction* pointee_type_inst = get_def_use_mgr()->GetDef(pointee_type_id); |
| 253 | + |
| 254 | + // Check if new_ptr_type_inst is after target_var in the types_values list |
| 255 | + bool needs_move = false; |
| 256 | + for (auto& inst : context()->types_values()) |
| 257 | + { |
| 258 | + if (&inst == target_var) |
| 259 | + { |
| 260 | + // Found target_var first, so new_ptr_type_inst is after it |
| 261 | + needs_move = true; |
| 262 | + break; |
| 263 | + } |
| 264 | + if (&inst == new_ptr_type_inst) |
| 265 | + { |
| 266 | + // Found new_ptr_type_inst first, it's in the right position |
| 267 | + needs_move = false; |
| 268 | + break; |
| 269 | + } |
| 270 | + } |
| 271 | + |
| 272 | + if (needs_move && pointee_type_inst != nullptr) |
| 273 | + { |
| 274 | + // Move the new pointer type to right after the pointee type |
| 275 | + // InsertAfter will automatically remove it from its current position |
| 276 | + new_ptr_type_inst->InsertAfter(pointee_type_inst); |
| 277 | + } |
| 278 | + } |
| 279 | + |
| 280 | + // Update the variable's type to the new pointer type |
| 281 | + target_var->SetResultType(new_ptr_type_id); |
| 282 | + |
| 283 | + // Also update the storage class operand of OpVariable itself |
| 284 | + // OpVariable has the storage class as the first operand (index 0) |
| 285 | + target_var->SetInOperand(0, {static_cast<uint32_t>(spv::StorageClass::PushConstant)}); |
| 286 | + |
| 287 | + context()->UpdateDefUse(target_var); |
| 288 | + modified = true; |
| 289 | + |
| 290 | + // Propagate storage class change to all users of this variable |
| 291 | + std::set<uint32_t> seen; |
| 292 | + std::vector<spvtools::opt::Instruction*> users; |
| 293 | + get_def_use_mgr()->ForEachUser(target_var, [&users](spvtools::opt::Instruction* user) { |
| 294 | + users.push_back(user); |
| 295 | + }); |
| 296 | + |
| 297 | + for (spvtools::opt::Instruction* user : users) |
| 298 | + { |
| 299 | + modified |= PropagateStorageClass(user, &seen); |
| 300 | + } |
| 301 | + |
| 302 | + // Remove Binding and DescriptorSet decorations from the variable |
| 303 | + auto* deco_mgr = context()->get_decoration_mgr(); |
| 304 | + deco_mgr->RemoveDecorationsFrom(target_var_id, [](const spvtools::opt::Instruction& inst) { |
| 305 | + if (inst.opcode() != spv::Op::OpDecorate) |
| 306 | + { |
| 307 | + return false; |
| 308 | + } |
| 309 | + spv::Decoration decoration = |
| 310 | + static_cast<spv::Decoration>(inst.GetSingleWordInOperand(1)); |
| 311 | + return decoration == spv::Decoration::Binding || |
| 312 | + decoration == spv::Decoration::DescriptorSet; |
| 313 | + }); |
| 314 | + |
| 315 | + return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; |
| 316 | + } |
| 317 | + |
| 318 | + spvtools::opt::IRContext::Analysis GetPreservedAnalyses() override |
| 319 | + { |
| 320 | + // This pass modifies types and decorations |
| 321 | + return spvtools::opt::IRContext::kAnalysisNone; |
| 322 | + } |
| 323 | + |
| 324 | +private: |
| 325 | + // Recursively updates the storage class of pointer types used by instructions |
| 326 | + // that reference the target variable. |
| 327 | + bool PropagateStorageClass(spvtools::opt::Instruction* inst, std::set<uint32_t>* seen) |
| 328 | + { |
| 329 | + if (!IsPointerResultType(inst)) |
| 330 | + { |
| 331 | + return false; |
| 332 | + } |
| 333 | + |
| 334 | + // Already has the correct storage class |
| 335 | + if (IsPointerToStorageClass(inst, spv::StorageClass::PushConstant)) |
| 336 | + { |
| 337 | + if (inst->opcode() == spv::Op::OpPhi) |
| 338 | + { |
| 339 | + if (!seen->insert(inst->result_id()).second) |
| 340 | + { |
| 341 | + return false; |
| 342 | + } |
| 343 | + } |
| 344 | + |
| 345 | + bool modified = false; |
| 346 | + std::vector<spvtools::opt::Instruction*> users; |
| 347 | + get_def_use_mgr()->ForEachUser(inst, [&users](spvtools::opt::Instruction* user) { |
| 348 | + users.push_back(user); |
| 349 | + }); |
| 350 | + for (spvtools::opt::Instruction* user : users) |
| 351 | + { |
| 352 | + modified |= PropagateStorageClass(user, seen); |
| 353 | + } |
| 354 | + |
| 355 | + if (inst->opcode() == spv::Op::OpPhi) |
| 356 | + { |
| 357 | + seen->erase(inst->result_id()); |
| 358 | + } |
| 359 | + return modified; |
| 360 | + } |
| 361 | + |
| 362 | + // Handle instructions that produce pointer results |
| 363 | + switch (inst->opcode()) |
| 364 | + { |
| 365 | + case spv::Op::OpAccessChain: |
| 366 | + case spv::Op::OpPtrAccessChain: |
| 367 | + case spv::Op::OpInBoundsAccessChain: |
| 368 | + case spv::Op::OpInBoundsPtrAccessChain: |
| 369 | + case spv::Op::OpCopyObject: |
| 370 | + case spv::Op::OpPhi: |
| 371 | + case spv::Op::OpSelect: |
| 372 | + ChangeResultStorageClass(inst); |
| 373 | + { |
| 374 | + std::vector<spvtools::opt::Instruction*> users; |
| 375 | + get_def_use_mgr()->ForEachUser(inst, [&users](spvtools::opt::Instruction* user) { |
| 376 | + users.push_back(user); |
| 377 | + }); |
| 378 | + for (spvtools::opt::Instruction* user : users) |
| 379 | + { |
| 380 | + PropagateStorageClass(user, seen); |
| 381 | + } |
| 382 | + } |
| 383 | + return true; |
| 384 | + |
| 385 | + case spv::Op::OpLoad: |
| 386 | + case spv::Op::OpStore: |
| 387 | + case spv::Op::OpCopyMemory: |
| 388 | + case spv::Op::OpCopyMemorySized: |
| 389 | + // These don't produce pointer results that need updating |
| 390 | + return false; |
| 391 | + |
| 392 | + default: |
| 393 | + return false; |
| 394 | + } |
| 395 | + } |
| 396 | + |
| 397 | + // Changes the result type of an instruction to use the new storage class. |
| 398 | + void ChangeResultStorageClass(spvtools::opt::Instruction* inst) |
| 399 | + { |
| 400 | + spvtools::opt::analysis::TypeManager* type_mgr = context()->get_type_mgr(); |
| 401 | + spvtools::opt::Instruction* result_type_inst = get_def_use_mgr()->GetDef(inst->type_id()); |
| 402 | + |
| 403 | + if (result_type_inst->opcode() != spv::Op::OpTypePointer) |
| 404 | + { |
| 405 | + return; |
| 406 | + } |
| 407 | + |
| 408 | + uint32_t pointee_type_id = result_type_inst->GetSingleWordInOperand(1); |
| 409 | + uint32_t new_result_type_id = |
| 410 | + type_mgr->FindPointerToType(pointee_type_id, spv::StorageClass::PushConstant); |
| 411 | + |
| 412 | + inst->SetResultType(new_result_type_id); |
| 413 | + context()->UpdateDefUse(inst); |
| 414 | + } |
| 415 | + |
| 416 | + // Checks if the instruction result type is a pointer. |
| 417 | + bool IsPointerResultType(spvtools::opt::Instruction* inst) |
| 418 | + { |
| 419 | + if (inst->type_id() == 0) |
| 420 | + { |
| 421 | + return false; |
| 422 | + } |
| 423 | + |
| 424 | + spvtools::opt::Instruction* type_def = get_def_use_mgr()->GetDef(inst->type_id()); |
| 425 | + return type_def != nullptr && type_def->opcode() == spv::Op::OpTypePointer; |
| 426 | + } |
| 427 | + |
| 428 | + // Checks if the instruction result type is a pointer to the specified storage class. |
| 429 | + bool IsPointerToStorageClass(spvtools::opt::Instruction* inst, spv::StorageClass storage_class) |
| 430 | + { |
| 431 | + if (inst->type_id() == 0) |
| 432 | + { |
| 433 | + return false; |
| 434 | + } |
| 435 | + |
| 436 | + spvtools::opt::Instruction* type_def = get_def_use_mgr()->GetDef(inst->type_id()); |
| 437 | + if (type_def == nullptr || type_def->opcode() != spv::Op::OpTypePointer) |
| 438 | + { |
| 439 | + return false; |
| 440 | + } |
| 441 | + |
| 442 | + spv::StorageClass pointer_storage_class = |
| 443 | + static_cast<spv::StorageClass>(type_def->GetSingleWordInOperand(0)); |
| 444 | + return pointer_storage_class == storage_class; |
| 445 | + } |
| 446 | + |
| 447 | + std::string m_BlockName; |
| 448 | +}; |
| 449 | + |
111 | 450 | } // namespace |
112 | 451 |
|
113 | 452 | std::vector<uint32_t> OptimizeSPIRV(const std::vector<uint32_t>& SrcSPIRV, spv_target_env TargetEnv, SPIRV_OPTIMIZATION_FLAGS Passes) |
@@ -168,10 +507,9 @@ std::vector<uint32_t> PatchSPIRVConvertUniformBufferToPushConstant( |
168 | 507 |
|
169 | 508 | optimizer.SetMessageConsumer(SpvOptimizerMessageConsumer); |
170 | 509 |
|
171 | | - // Register the pass to convert UBO to push constant |
172 | | - optimizer.RegisterPass(spvtools::CreateConvertUBOToPushConstantPass(BlockName)); |
173 | | - //optimizer.RegisterPass(spvtools::CreateDeadVariableEliminationPass()); |
174 | | - //optimizer.RegisterPerformancePasses(); |
| 510 | + // Register the pass to convert UBO to push constant using custom out-of-tree pass |
| 511 | + optimizer.RegisterPass(spvtools::Optimizer::PassToken( |
| 512 | + std::make_unique<ConvertUBOToPushConstantPass>(BlockName))); |
175 | 513 |
|
176 | 514 | spvtools::OptimizerOptions options; |
177 | 515 | #ifndef DILIGENT_DEVELOPMENT |
|
0 commit comments