Skip to content

Commit 9fb503a

Browse files
committed
[CHERI] Update CallSiteInfo handling in RISCVExpandPseudoInsts for improved assertions in upstream
1 parent 094f9b1 commit 9fb503a

File tree

1 file changed

+26
-12
lines changed

1 file changed

+26
-12
lines changed

llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,8 @@ class RISCVExpandPseudo : public MachineFunctionPass {
116116
const Function *Fn,
117117
Register DestReg,
118118
bool TreatAsLibrary = false,
119-
bool CallImportTarget = false);
119+
bool CallImportTarget = false,
120+
const MachineInstr *OriginalCall = nullptr);
120121
/**
121122
* Helper that inserts a load from the import table identified by an import
122123
* and export table entry symbol.
@@ -128,7 +129,8 @@ class RISCVExpandPseudo : public MachineFunctionPass {
128129
MachineBasicBlock::iterator MBBI,
129130
MCSymbol *ImportSymbol, MCSymbol *ExportSymbol,
130131
const StringRef ImportName, Register DestReg,
131-
bool IsLibrary, bool IsPublic, bool CallImportTarget);
132+
bool IsLibrary, bool IsPublic, bool CallImportTarget,
133+
const MachineInstr* OriginalCall);
132134

133135
#ifndef NDEBUG
134136
unsigned getInstSizeInBytes(const MachineFunction &MF) const {
@@ -267,8 +269,10 @@ bool RISCVExpandPseudo::expandMI(MachineBasicBlock &MBB,
267269

268270
MachineBasicBlock *RISCVExpandPseudo::insertLoadOfImportTable(
269271
MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
270-
const Function *Fn, Register DestReg, bool TreatAsLibrary,
271-
bool CallImportTarget) {
272+
const Function *Fn, Register DestReg,
273+
bool TreatAsLibrary,
274+
bool CallImportTarget,
275+
const MachineInstr *OriginalCall) {
272276
auto *MF = MBB.getParent();
273277
const StringRef ImportName = Fn->getName();
274278
// We can hit this code path if we need to do a library-style import
@@ -296,7 +300,7 @@ MachineBasicBlock *RISCVExpandPseudo::insertLoadOfImportTable(
296300
MCSymbol *ExportSymbol = MF->getContext().getOrCreateSymbol(ExportEntryName);
297301
return insertLoadOfImportTable(
298302
MBB, MBBI, ImportSymbol, ExportSymbol, ImportName, DestReg,
299-
IsLibrary || TreatAsLibrary, Fn->hasExternalLinkage(), CallImportTarget);
303+
IsLibrary || TreatAsLibrary, Fn->hasExternalLinkage(), CallImportTarget, OriginalCall);
300304
}
301305

302306
static const GlobalValue *resolveGlobalAlias(const GlobalValue *GV) {
@@ -309,7 +313,9 @@ static const GlobalValue *resolveGlobalAlias(const GlobalValue *GV) {
309313
MachineBasicBlock *RISCVExpandPseudo::insertLoadOfImportTable(
310314
MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
311315
MCSymbol *ImportSymbol, MCSymbol *ExportSymbol, const StringRef ImportName,
312-
Register DestReg, bool IsLibrary, bool IsPublic, bool CallImportTarget) {
316+
Register DestReg,
317+
bool IsLibrary, bool IsPublic, bool CallImportTarget,
318+
const MachineInstr* OriginalCall) {
313319
auto *MF = MBB.getParent();
314320
const DebugLoc DL = MBBI->getDebugLoc();
315321
MachineBasicBlock *NewMBB =
@@ -328,9 +334,12 @@ MachineBasicBlock *RISCVExpandPseudo::insertLoadOfImportTable(
328334
.addReg(DestReg, RegState::Kill)
329335
.addMBB(NewMBB, RISCVII::MO_CHERIOT_COMPARTMENT_LO_I);
330336

331-
if (CallImportTarget)
332-
BuildMI(NewMBB, DL, TII->get(RISCV::C_CJALR))
337+
if (CallImportTarget) {
338+
auto NewCallMI = BuildMI(NewMBB, DL, TII->get(RISCV::C_CJALR))
333339
.addReg(DestReg, RegState::Kill);
340+
if (OriginalCall && OriginalCall->shouldUpdateAdditionalCallInfo())
341+
MF->moveAdditionalCallInfo(OriginalCall, NewCallMI);
342+
}
334343

335344
NewMBB->splice(NewMBB->end(), &MBB, std::next(MBBI), MBB.end());
336345
// Update machine-CFG edges.
@@ -408,8 +417,10 @@ bool RISCVExpandPseudo::expandCompartmentCall(MachineBasicBlock &MBB,
408417
BuildMI(NewMBB, DL, TII->get(RISCV::CLC_64), RISCV::X7_Y)
409418
.addReg(RISCV::X7_Y, RegState::Kill)
410419
.addMBB(NewMBB, RISCVII::MO_CHERIOT_COMPARTMENT_LO_I);
411-
BuildMI(NewMBB, DL, TII->get(RISCV::C_CJALR))
420+
auto NewCallMI = BuildMI(NewMBB, DL, TII->get(RISCV::C_CJALR))
412421
.addReg(RISCV::X7_Y, RegState::Kill);
422+
if (MI.shouldUpdateAdditionalCallInfo())
423+
MF->moveAdditionalCallInfo(&MI, NewCallMI);
413424

414425
// Move all the rest of the instructions to NewMBB.
415426
NewMBB->splice(NewMBB->end(), &MBB, std::next(MBBI), MBB.end());
@@ -455,7 +466,7 @@ bool RISCVExpandPseudo::expandLibraryCall(
455466
MI.setDesc(TII->get(RISCV::PseudoCCALL));
456467
return true;
457468
}
458-
insertLoadOfImportTable(MBB, MBBI, Fn, RISCV::X7_Y, true, true);
469+
insertLoadOfImportTable(MBB, MBBI, Fn, RISCV::X7_Y, true, true, &MI);
459470

460471
NextMBBI = MBB.end();
461472
} else if (Callee.isSymbol()) {
@@ -486,14 +497,17 @@ bool RISCVExpandPseudo::expandLibraryCall(
486497
MF->getContext().getOrCreateSymbol(ExportEntryName);
487498
insertLoadOfImportTable(MBB, MBBI, ImportSymbol, ExportSymbol,
488499
Callee.getSymbolName(), RISCV::X7_Y, true, true,
489-
true);
500+
true, &MI);
490501

491502
NextMBBI = MBB.end();
492503
} else {
493504
assert(Callee.isReg() && "Expected register operand");
494505
// Indirect library calls are just cjalr instructions.
495-
BuildMI(&MBB, MI.getDebugLoc(), TII->get(RISCV::C_CJALR)).add(Callee);
506+
auto NewCallMI = BuildMI(&MBB, MI.getDebugLoc(), TII->get(RISCV::C_CJALR)).add(Callee);
507+
if (MI.shouldUpdateAdditionalCallInfo())
508+
MF->moveAdditionalCallInfo(NewCallMI, &MI);
496509
}
510+
497511
MI.eraseFromParent();
498512
return true;
499513
}

0 commit comments

Comments
 (0)