Skip to content

Commit f932fc2

Browse files
authored
spirv-opt: Handle id overflow in MergeReturnPass (#6340)
This CL adds error handling to the MergeReturnPass to gracefully handle cases where the pass runs out of IDs. The following functions were modified to return a boolean indicating success or failure: - AddNewPhiNodes - AddReturnFlag - AddReturnValue - BranchToBlock - CreatePhiNodesForInst - ProcessStructuredBlock - RecordReturned - UpdatePhiNodes The callers of these functions were updated to check the return value and propagate the failure. This prevents the pass from crashing when it runs out of IDs.
1 parent 01628d6 commit f932fc2

File tree

2 files changed

+149
-68
lines changed

2 files changed

+149
-68
lines changed

source/opt/merge_return_pass.cpp

Lines changed: 119 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,9 @@ bool MergeReturnPass::ProcessStructured(
134134
state_.pop_back();
135135
}
136136

137-
ProcessStructuredBlock(block);
137+
if (!ProcessStructuredBlock(block)) {
138+
return false;
139+
}
138140

139141
// Generate state for next block if warranted
140142
GenerateState(block);
@@ -169,7 +171,9 @@ bool MergeReturnPass::ProcessStructured(
169171
// We have not kept the dominator tree up-to-date.
170172
// Invalidate it at this point to make sure it will be rebuilt.
171173
context()->RemoveDominatorAnalysis(function);
172-
AddNewPhiNodes();
174+
if (!AddNewPhiNodes()) {
175+
return false;
176+
}
173177
return true;
174178
}
175179

@@ -196,7 +200,9 @@ bool MergeReturnPass::CreateReturnBlock() {
196200
}
197201

198202
bool MergeReturnPass::CreateReturn(BasicBlock* block) {
199-
AddReturnValue();
203+
if (!AddReturnValue()) {
204+
return false;
205+
}
200206

201207
if (return_value_) {
202208
// Load and return the final return value
@@ -229,12 +235,18 @@ bool MergeReturnPass::CreateReturn(BasicBlock* block) {
229235
return true;
230236
}
231237

232-
void MergeReturnPass::ProcessStructuredBlock(BasicBlock* block) {
238+
bool MergeReturnPass::ProcessStructuredBlock(BasicBlock* block) {
239+
if (block->tail() == block->end()) {
240+
return true;
241+
}
242+
233243
spv::Op tail_opcode = block->tail()->opcode();
234244
if (tail_opcode == spv::Op::OpReturn ||
235245
tail_opcode == spv::Op::OpReturnValue) {
236246
if (!return_flag_) {
237-
AddReturnFlag();
247+
if (!AddReturnFlag()) {
248+
return false;
249+
}
238250
}
239251
}
240252

@@ -243,43 +255,57 @@ void MergeReturnPass::ProcessStructuredBlock(BasicBlock* block) {
243255
tail_opcode == spv::Op::OpUnreachable) {
244256
assert(CurrentState().InBreakable() &&
245257
"Should be in the placeholder construct.");
246-
BranchToBlock(block, CurrentState().BreakMergeId());
258+
if (!BranchToBlock(block, CurrentState().BreakMergeId())) {
259+
return false;
260+
}
247261
return_blocks_.insert(block->id());
248262
}
263+
return true;
249264
}
250265

251-
void MergeReturnPass::BranchToBlock(BasicBlock* block, uint32_t target) {
266+
bool MergeReturnPass::BranchToBlock(BasicBlock* block, uint32_t target) {
252267
if (block->tail()->opcode() == spv::Op::OpReturn ||
253268
block->tail()->opcode() == spv::Op::OpReturnValue) {
254-
RecordReturned(block);
269+
if (!RecordReturned(block)) {
270+
return false;
271+
}
255272
RecordReturnValue(block);
256273
}
257274

258275
BasicBlock* target_block = context()->get_instr_block(target);
259276
if (target_block->GetLoopMergeInst()) {
260277
cfg()->SplitLoopHeader(target_block);
261278
}
262-
UpdatePhiNodes(block, target_block);
279+
if (!UpdatePhiNodes(block, target_block)) {
280+
return false;
281+
}
263282

264283
Instruction* return_inst = block->terminator();
265284
return_inst->SetOpcode(spv::Op::OpBranch);
266285
return_inst->ReplaceOperands({{SPV_OPERAND_TYPE_ID, {target}}});
267286
context()->get_def_use_mgr()->AnalyzeInstDefUse(return_inst);
268287
new_edges_[target_block].insert(block->id());
269288
cfg()->AddEdge(block->id(), target);
289+
return true;
270290
}
271291

272-
void MergeReturnPass::UpdatePhiNodes(BasicBlock* new_source,
292+
bool MergeReturnPass::UpdatePhiNodes(BasicBlock* new_source,
273293
BasicBlock* target) {
274-
target->ForEachPhiInst([this, new_source](Instruction* inst) {
294+
bool succeeded = true;
295+
target->ForEachPhiInst([this, new_source, &succeeded](Instruction* inst) {
275296
uint32_t undefId = Type2Undef(inst->type_id());
297+
if (undefId == 0) {
298+
succeeded = false;
299+
return;
300+
}
276301
inst->AddOperand({SPV_OPERAND_TYPE_ID, {undefId}});
277302
inst->AddOperand({SPV_OPERAND_TYPE_ID, {new_source->id()}});
278303
context()->UpdateDefUse(inst);
279304
});
305+
return succeeded;
280306
}
281307

282-
void MergeReturnPass::CreatePhiNodesForInst(BasicBlock* merge_block,
308+
bool MergeReturnPass::CreatePhiNodesForInst(BasicBlock* merge_block,
283309
Instruction& inst) {
284310
DominatorAnalysis* dom_tree =
285311
context()->GetDominatorAnalysis(merge_block->GetParent());
@@ -313,7 +339,7 @@ void MergeReturnPass::CreatePhiNodesForInst(BasicBlock* merge_block,
313339
});
314340

315341
if (users_to_update.empty()) {
316-
return;
342+
return true;
317343
}
318344

319345
// There is at least one values that needs to be replaced.
@@ -357,6 +383,9 @@ void MergeReturnPass::CreatePhiNodesForInst(BasicBlock* merge_block,
357383
if (regenerateInstruction) {
358384
std::unique_ptr<Instruction> regen_inst(inst.Clone(context()));
359385
uint32_t new_id = TakeNextId();
386+
if (new_id == 0) {
387+
return false;
388+
}
360389
regen_inst->SetResultId(new_id);
361390
Instruction* insert_pos = &*merge_block->begin();
362391
while (insert_pos->opcode() == spv::Op::OpPhi) {
@@ -366,19 +395,31 @@ void MergeReturnPass::CreatePhiNodesForInst(BasicBlock* merge_block,
366395
get_def_use_mgr()->AnalyzeInstDefUse(new_phi);
367396
context()->set_instr_block(new_phi, merge_block);
368397

369-
new_phi->ForEachInId([dom_tree, merge_block, this](uint32_t* use_id) {
398+
bool succeeded = true;
399+
new_phi->ForEachInId([dom_tree, merge_block, this,
400+
&succeeded](uint32_t* use_id) {
401+
if (!succeeded) {
402+
return;
403+
}
370404
Instruction* use = get_def_use_mgr()->GetDef(*use_id);
371405
BasicBlock* use_bb = context()->get_instr_block(use);
372406
if (use_bb != nullptr && !dom_tree->Dominates(use_bb, merge_block)) {
373-
CreatePhiNodesForInst(merge_block, *use);
407+
if (!CreatePhiNodesForInst(merge_block, *use)) {
408+
succeeded = false;
409+
}
374410
}
375411
});
412+
if (!succeeded) {
413+
return false;
414+
}
376415
} else {
377416
InstructionBuilder builder(
378417
context(), &*merge_block->begin(),
379418
IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
380-
// TODO(1841): Handle id overflow.
381419
new_phi = builder.AddPhi(inst.type_id(), phi_operands);
420+
if (new_phi == nullptr) {
421+
return false;
422+
}
382423
}
383424
uint32_t result_of_phi = new_phi->result_id();
384425

@@ -392,6 +433,7 @@ void MergeReturnPass::CreatePhiNodesForInst(BasicBlock* merge_block,
392433
context()->AnalyzeUses(user);
393434
}
394435
}
436+
return true;
395437
}
396438

397439
bool MergeReturnPass::PredicateBlocks(
@@ -484,6 +526,9 @@ bool MergeReturnPass::BreakFromConstruct(
484526
cfg()->RemoveSuccessorEdges(block);
485527

486528
auto old_body_id = TakeNextId();
529+
if (old_body_id == 0) {
530+
return false;
531+
}
487532
BasicBlock* old_body = block->SplitBasicBlock(context(), old_body_id, iter);
488533
predicated->insert(old_body);
489534

@@ -520,9 +565,11 @@ bool MergeReturnPass::BreakFromConstruct(
520565
analysis::Bool bool_type;
521566
uint32_t bool_id = context()->get_type_mgr()->GetId(&bool_type);
522567
assert(bool_id != 0);
523-
// TODO(1841): Handle id overflow.
524-
uint32_t load_id =
525-
builder.AddLoad(bool_id, return_flag_->result_id())->result_id();
568+
Instruction* load_inst = builder.AddLoad(bool_id, return_flag_->result_id());
569+
if (load_inst == nullptr) {
570+
return false;
571+
}
572+
uint32_t load_id = load_inst->result_id();
526573

527574
// 2. Branch to |merge_block| (true) or |old_body| (false)
528575
builder.AddConditionalBranch(load_id, merge_block->id(), old_body->id(),
@@ -535,7 +582,9 @@ bool MergeReturnPass::BreakFromConstruct(
535582
}
536583

537584
// 3. Update OpPhi instructions in |merge_block|.
538-
UpdatePhiNodes(block, merge_block);
585+
if (!UpdatePhiNodes(block, merge_block)) {
586+
return false;
587+
}
539588

540589
// 4. Update the CFG. We do this after updating the OpPhi instructions
541590
// because |UpdatePhiNodes| assumes the edge from |block| has not been added
@@ -548,10 +597,10 @@ bool MergeReturnPass::BreakFromConstruct(
548597
return true;
549598
}
550599

551-
void MergeReturnPass::RecordReturned(BasicBlock* block) {
600+
bool MergeReturnPass::RecordReturned(BasicBlock* block) {
552601
if (block->tail()->opcode() != spv::Op::OpReturn &&
553602
block->tail()->opcode() != spv::Op::OpReturnValue)
554-
return;
603+
return true;
555604

556605
assert(return_flag_ && "Did not generate the return flag variable.");
557606

@@ -564,6 +613,9 @@ void MergeReturnPass::RecordReturned(BasicBlock* block) {
564613
const analysis::Constant* true_const =
565614
const_mgr->GetConstant(bool_type, {true});
566615
constant_true_ = const_mgr->GetDefiningInstruction(true_const);
616+
if (!constant_true_) {
617+
return false;
618+
}
567619
context()->UpdateDefUse(constant_true_);
568620
}
569621

@@ -577,6 +629,7 @@ void MergeReturnPass::RecordReturned(BasicBlock* block) {
577629
&*block->tail().InsertBefore(std::move(return_store));
578630
context()->set_instr_block(store_inst, block);
579631
context()->AnalyzeDefUse(store_inst);
632+
return true;
580633
}
581634

582635
void MergeReturnPass::RecordReturnValue(BasicBlock* block) {
@@ -600,18 +653,21 @@ void MergeReturnPass::RecordReturnValue(BasicBlock* block) {
600653
context()->AnalyzeDefUse(store_inst);
601654
}
602655

603-
void MergeReturnPass::AddReturnValue() {
604-
if (return_value_) return;
656+
bool MergeReturnPass::AddReturnValue() {
657+
if (return_value_) return true;
605658

606659
uint32_t return_type_id = function_->type_id();
607660
if (get_def_use_mgr()->GetDef(return_type_id)->opcode() ==
608661
spv::Op::OpTypeVoid)
609-
return;
662+
return true;
610663

611664
uint32_t return_ptr_type = context()->get_type_mgr()->FindPointerToType(
612665
return_type_id, spv::StorageClass::Function);
613666

614667
uint32_t var_id = TakeNextId();
668+
if (var_id == 0) {
669+
return false;
670+
}
615671
std::unique_ptr<Instruction> returnValue(
616672
new Instruction(context(), spv::Op::OpVariable, return_ptr_type, var_id,
617673
std::initializer_list<Operand>{
@@ -627,27 +683,44 @@ void MergeReturnPass::AddReturnValue() {
627683

628684
context()->get_decoration_mgr()->CloneDecorations(
629685
function_->result_id(), var_id, {spv::Decoration::RelaxedPrecision});
686+
return true;
630687
}
631688

632-
void MergeReturnPass::AddReturnFlag() {
633-
if (return_flag_) return;
689+
bool MergeReturnPass::AddReturnFlag() {
690+
if (return_flag_) return true;
634691

635692
analysis::TypeManager* type_mgr = context()->get_type_mgr();
636693
analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
637694

638695
analysis::Bool temp;
639696
uint32_t bool_id = type_mgr->GetTypeInstruction(&temp);
697+
if (bool_id == 0) {
698+
return false;
699+
}
640700
analysis::Bool* bool_type = type_mgr->GetType(bool_id)->AsBool();
641701

642702
const analysis::Constant* false_const =
643703
const_mgr->GetConstant(bool_type, {false});
644-
uint32_t const_false_id =
645-
const_mgr->GetDefiningInstruction(false_const)->result_id();
704+
Instruction* false_inst = const_mgr->GetDefiningInstruction(false_const);
705+
if (false_inst == nullptr) {
706+
return false;
707+
}
708+
uint32_t const_false_id = false_inst->result_id();
646709

647710
uint32_t bool_ptr_id =
648711
type_mgr->FindPointerToType(bool_id, spv::StorageClass::Function);
649712

713+
if (bool_ptr_id == 0) {
714+
return false;
715+
;
716+
}
717+
650718
uint32_t var_id = TakeNextId();
719+
720+
if (var_id == 0) {
721+
return false;
722+
}
723+
651724
std::unique_ptr<Instruction> returnFlag(new Instruction(
652725
context(), spv::Op::OpVariable, bool_ptr_id, var_id,
653726
std::initializer_list<Operand>{{SPV_OPERAND_TYPE_STORAGE_CLASS,
@@ -661,6 +734,7 @@ void MergeReturnPass::AddReturnFlag() {
661734
return_flag_ = &*entry_block->begin();
662735
context()->AnalyzeDefUse(return_flag_);
663736
context()->set_instr_block(return_flag_, entry_block);
737+
return true;
664738
}
665739

666740
std::vector<BasicBlock*> MergeReturnPass::CollectReturnBlocks(
@@ -739,16 +813,19 @@ bool MergeReturnPass::MergeReturnBlocks(
739813
return true;
740814
}
741815

742-
void MergeReturnPass::AddNewPhiNodes() {
816+
bool MergeReturnPass::AddNewPhiNodes() {
743817
std::list<BasicBlock*> order;
744818
cfg()->ComputeStructuredOrder(function_, &*function_->begin(), &order);
745819

746820
for (BasicBlock* bb : order) {
747-
AddNewPhiNodes(bb);
821+
if (!AddNewPhiNodes(bb)) {
822+
return false;
823+
}
748824
}
825+
return true;
749826
}
750827

751-
void MergeReturnPass::AddNewPhiNodes(BasicBlock* bb) {
828+
bool MergeReturnPass::AddNewPhiNodes(BasicBlock* bb) {
752829
// New phi nodes are needed for any id whose definition used to dominate |bb|,
753830
// but no longer dominates |bb|. These are found by walking the dominator
754831
// tree starting at the original immediate dominator of |bb| and ending at its
@@ -766,16 +843,19 @@ void MergeReturnPass::AddNewPhiNodes(BasicBlock* bb) {
766843

767844
BasicBlock* dominator = dom_tree->ImmediateDominator(bb);
768845
if (dominator == nullptr) {
769-
return;
846+
return true;
770847
}
771848

772849
BasicBlock* current_bb = context()->get_instr_block(original_dominator_[bb]);
773850
while (current_bb != nullptr && current_bb != dominator) {
774851
for (Instruction& inst : *current_bb) {
775-
CreatePhiNodesForInst(bb, inst);
852+
if (!CreatePhiNodesForInst(bb, inst)) {
853+
return false;
854+
}
776855
}
777856
current_bb = dom_tree->ImmediateDominator(current_bb);
778857
}
858+
return true;
779859
}
780860

781861
void MergeReturnPass::RecordImmediateDominators(Function* function) {
@@ -859,8 +939,12 @@ bool MergeReturnPass::CreateSingleCaseSwitch(BasicBlock* merge_target) {
859939
++split_pos;
860940
}
861941

942+
uint32_t new_block_id = TakeNextId();
943+
if (new_block_id == 0) {
944+
return false;
945+
}
862946
BasicBlock* old_block =
863-
start_block->SplitBasicBlock(context(), TakeNextId(), split_pos);
947+
start_block->SplitBasicBlock(context(), new_block_id, split_pos);
864948

865949
// Find DebugFunctionDefinition inst in the old block, and if we can find it,
866950
// move it to the entry block. Since DebugFunctionDefinition is not necessary

0 commit comments

Comments
 (0)