diff --git a/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp b/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp index dc3db668077a5..3ee61aeb2e500 100644 --- a/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp +++ b/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp @@ -111,24 +111,23 @@ class RISCVExpandPseudo : public MachineFunctionPass { * function is / may be exported from this compartment but, at this call site, * should be treated as a library call. */ - MachineBasicBlock *insertLoadOfImportTable(MachineBasicBlock &MBB, - MachineBasicBlock::iterator MBBI, - const Function *Fn, - Register DestReg, - bool TreatAsLibrary = false, - bool CallImportTarget = false); + MachineBasicBlock * + insertLoadOfImportTable(MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI, const Function *Fn, + Register DestReg, bool TreatAsLibrary = false, + bool CallImportTarget = false, + const MachineInstr *OriginalCall = nullptr); /** * Helper that inserts a load from the import table identified by an import * and export table entry symbol. * * Calls the result if `CallImportTarget` is true. */ - MachineBasicBlock * - insertLoadOfImportTable(MachineBasicBlock &MBB, - MachineBasicBlock::iterator MBBI, - MCSymbol *ImportSymbol, MCSymbol *ExportSymbol, - const StringRef ImportName, Register DestReg, - bool IsLibrary, bool IsPublic, bool CallImportTarget); + MachineBasicBlock *insertLoadOfImportTable( + MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, + MCSymbol *ImportSymbol, MCSymbol *ExportSymbol, + const StringRef ImportName, Register DestReg, bool IsLibrary, + bool IsPublic, bool CallImportTarget, const MachineInstr *OriginalCall); #ifndef NDEBUG unsigned getInstSizeInBytes(const MachineFunction &MF) const { @@ -268,7 +267,7 @@ bool RISCVExpandPseudo::expandMI(MachineBasicBlock &MBB, MachineBasicBlock *RISCVExpandPseudo::insertLoadOfImportTable( MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, const Function *Fn, Register DestReg, bool TreatAsLibrary, - bool CallImportTarget) { + bool CallImportTarget, const MachineInstr *OriginalCall) { auto *MF = MBB.getParent(); const StringRef ImportName = Fn->getName(); // We can hit this code path if we need to do a library-style import @@ -296,7 +295,8 @@ MachineBasicBlock *RISCVExpandPseudo::insertLoadOfImportTable( MCSymbol *ExportSymbol = MF->getContext().getOrCreateSymbol(ExportEntryName); return insertLoadOfImportTable( MBB, MBBI, ImportSymbol, ExportSymbol, ImportName, DestReg, - IsLibrary || TreatAsLibrary, Fn->hasExternalLinkage(), CallImportTarget); + IsLibrary || TreatAsLibrary, Fn->hasExternalLinkage(), CallImportTarget, + OriginalCall); } static const GlobalValue *resolveGlobalAlias(const GlobalValue *GV) { @@ -309,7 +309,8 @@ static const GlobalValue *resolveGlobalAlias(const GlobalValue *GV) { MachineBasicBlock *RISCVExpandPseudo::insertLoadOfImportTable( MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, MCSymbol *ImportSymbol, MCSymbol *ExportSymbol, const StringRef ImportName, - Register DestReg, bool IsLibrary, bool IsPublic, bool CallImportTarget) { + Register DestReg, bool IsLibrary, bool IsPublic, bool CallImportTarget, + const MachineInstr *OriginalCall) { auto *MF = MBB.getParent(); const DebugLoc DL = MBBI->getDebugLoc(); MachineBasicBlock *NewMBB = @@ -328,9 +329,12 @@ MachineBasicBlock *RISCVExpandPseudo::insertLoadOfImportTable( .addReg(DestReg, RegState::Kill) .addMBB(NewMBB, RISCVII::MO_CHERIOT_COMPARTMENT_LO_I); - if (CallImportTarget) - BuildMI(NewMBB, DL, TII->get(RISCV::C_CJALR)) - .addReg(DestReg, RegState::Kill); + if (CallImportTarget) { + auto NewCallMI = BuildMI(NewMBB, DL, TII->get(RISCV::C_CJALR)) + .addReg(DestReg, RegState::Kill); + if (OriginalCall && OriginalCall->shouldUpdateAdditionalCallInfo()) + MF->moveAdditionalCallInfo(OriginalCall, NewCallMI); + } NewMBB->splice(NewMBB->end(), &MBB, std::next(MBBI), MBB.end()); // Update machine-CFG edges. @@ -408,8 +412,10 @@ bool RISCVExpandPseudo::expandCompartmentCall(MachineBasicBlock &MBB, BuildMI(NewMBB, DL, TII->get(RISCV::CLC_64), RISCV::X7_Y) .addReg(RISCV::X7_Y, RegState::Kill) .addMBB(NewMBB, RISCVII::MO_CHERIOT_COMPARTMENT_LO_I); - BuildMI(NewMBB, DL, TII->get(RISCV::C_CJALR)) - .addReg(RISCV::X7_Y, RegState::Kill); + auto NewCallMI = BuildMI(NewMBB, DL, TII->get(RISCV::C_CJALR)) + .addReg(RISCV::X7_Y, RegState::Kill); + if (MI.shouldUpdateAdditionalCallInfo()) + MF->moveAdditionalCallInfo(&MI, NewCallMI); // Move all the rest of the instructions to NewMBB. NewMBB->splice(NewMBB->end(), &MBB, std::next(MBBI), MBB.end()); @@ -455,7 +461,7 @@ bool RISCVExpandPseudo::expandLibraryCall( MI.setDesc(TII->get(RISCV::PseudoCCALL)); return true; } - insertLoadOfImportTable(MBB, MBBI, Fn, RISCV::X7_Y, true, true); + insertLoadOfImportTable(MBB, MBBI, Fn, RISCV::X7_Y, true, true, &MI); NextMBBI = MBB.end(); } else if (Callee.isSymbol()) { @@ -486,14 +492,18 @@ bool RISCVExpandPseudo::expandLibraryCall( MF->getContext().getOrCreateSymbol(ExportEntryName); insertLoadOfImportTable(MBB, MBBI, ImportSymbol, ExportSymbol, Callee.getSymbolName(), RISCV::X7_Y, true, true, - true); + true, &MI); NextMBBI = MBB.end(); } else { assert(Callee.isReg() && "Expected register operand"); // Indirect library calls are just cjalr instructions. - BuildMI(&MBB, MI.getDebugLoc(), TII->get(RISCV::C_CJALR)).add(Callee); + auto NewCallMI = + BuildMI(&MBB, MI.getDebugLoc(), TII->get(RISCV::C_CJALR)).add(Callee); + if (MI.shouldUpdateAdditionalCallInfo()) + MF->moveAdditionalCallInfo(NewCallMI, &MI); } + MI.eraseFromParent(); return true; }