diff --git a/debug.log b/debug.log new file mode 100644 index 0000000000000..b0a65a1bd5c97 --- /dev/null +++ b/debug.log @@ -0,0 +1,2 @@ +clang++: error: no such file or directory: 'test_jump_table.cpp' +clang++: error: no input files diff --git a/llvm/include/llvm/Transforms/IPO/JumpTableFinder.h b/llvm/include/llvm/Transforms/IPO/JumpTableFinder.h new file mode 100644 index 0000000000000..9d2fb3b405f8c --- /dev/null +++ b/llvm/include/llvm/Transforms/IPO/JumpTableFinder.h @@ -0,0 +1,62 @@ +// #ifndef LLVM_ANALYSIS_JUMPTABLEFINDERPASS_H +// #define LLVM_ANALYSIS_JUMPTABLEFINDERPASS_H + +// #include "llvm/IR/PassManager.h" +// #include "llvm/IR/Module.h" +// #include "llvm/Support/raw_ostream.h" +// #include + +// namespace llvm { + +// class JumptableFinderPass : public PassInfoMixin { +// public: +// /// Main entry point for the pass. Analyzes the module to find and analyze jump tables. +// PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM); +// /// Implementation of the jump table finder. +// void jumptableFinderImpl(Module &M); + +// /// Analyze a SwitchInst for potential jump table patterns. +// void findJumpTableFromSwitch(SwitchInst *SI); + +// /// Analyze a GetElementPtrInst for jump table patterns. +// void analyzeJumpTable(GetElementPtrInst *GEP); + +// /// Analyze the index computation of a jump table. +// void analyzeIndex(Value *Index); + +// /// Find all potential targets for a jump table. +// void findTargets(GetElementPtrInst *GEP, std::set &Targets); + +// /// Check the density of a SwitchInst's cases to determine if it forms a jump table. +// bool checkDensity(SwitchInst *SI); + +// /// Check if a GetElementPtrInst leads to an indirect branch. +// bool leadsToIndirectBranch(GetElementPtrInst *GEP); +// }; + +// } // namespace llvm + +// #endif // LLVM_ANALYSIS_JUMPTABLEFINDERPASS_H + +#ifndef LLVM_TRANSFORMS_IPO_JUMPTABLEFINDER_H +#define LLVM_TRANSFORMS_IPO_JUMPTABLEFINDER_H + +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" // For PassInfoMixin and PreservedAnalyses +#include "llvm/IR/Function.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Constants.h" +#include "llvm/Support/raw_ostream.h" +#include + +namespace llvm { + +class JumptableFinderPass : public PassInfoMixin { +public: + // Entry point for the pass + PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM); +}; + +} // namespace llvm + +#endif // LLVM_TRANSFORMS_IPO_JUMPTABLEFINDER_H diff --git a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp index 3ba45900e4569..3f56c5aca5be5 100644 --- a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp +++ b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp @@ -2862,6 +2862,7 @@ void AsmPrinter::emitJumpTableInfo() { // Pick the directive to use to print the jump table entries, and switch to // the appropriate section. const Function &F = MF->getFunction(); + // errs() << "F" << F.getName() << "\n"; const TargetLoweringObjectFile &TLOF = getObjFileLowering(); bool JTInDiffSection = !TLOF.shouldPutJumpTableInFunctionSection( MJTI->getEntryKind() == MachineJumpTableInfo::EK_LabelDifference32 || @@ -2882,27 +2883,34 @@ void AsmPrinter::emitJumpTableInfo() { for (unsigned JTI = 0, e = JT.size(); JTI != e; ++JTI) { const std::vector &JTBBs = JT[JTI].MBBs; + // If this jump table was deleted, ignore it. if (JTBBs.empty()) continue; + // errs() <<"test0" << "\n"; // For the EK_LabelDifference32 entry, if using .set avoids a relocation, /// emit a .set directive for each unique entry. if (MJTI->getEntryKind() == MachineJumpTableInfo::EK_LabelDifference32 && MAI->doesSetDirectiveSuppressReloc()) { + // errs() << "test1" <<"\n"; SmallPtrSet EmittedSets; const TargetLowering *TLI = MF->getSubtarget().getTargetLowering(); const MCExpr *Base = TLI->getPICJumpTableRelocBaseExpr(MF,JTI,OutContext); for (const MachineBasicBlock *MBB : JTBBs) { if (!EmittedSets.insert(MBB).second) continue; - + // errs() <<"test2" <<"\n"; // .set LJTSet, LBB32-base + MCSymbol *MCsy = MBB->getSymbol(); + MCSymbol *JTsetsy = GetJTSetSymbol(JTI, MBB->getNumber()); const MCExpr *LHS = - MCSymbolRefExpr::create(MBB->getSymbol(), OutContext); - OutStreamer->emitAssignment(GetJTSetSymbol(JTI, MBB->getNumber()), + MCSymbolRefExpr::create(MCsy, OutContext); + OutStreamer->emitAssignment(JTsetsy, MCBinaryExpr::createSub(LHS, Base, OutContext)); + // errs() << "Symbol" << MCsy->getName() << "\n"; + // errs() << "JTSymbol" << JTsetsy->getName() << "\n"; } } @@ -2921,8 +2929,10 @@ void AsmPrinter::emitJumpTableInfo() { // Defer MCAssembler based constant folding due to a performance issue. The // label differences will be evaluated at write time. - for (const MachineBasicBlock *MBB : JTBBs) + for (const MachineBasicBlock *MBB : JTBBs){ + // errs() <<"test6" << "\n"; emitJumpTableEntry(MJTI, MBB, JTI); + } } if (EmitJumpTableSizesSection) @@ -2987,6 +2997,7 @@ void AsmPrinter::emitJumpTableEntry(const MachineJumpTableInfo *MJTI, unsigned UID) const { assert(MBB && MBB->getNumber() >= 0 && "Invalid basic block"); const MCExpr *Value = nullptr; + // errs() << "EntryKind:" << MJTI->getEntryKind() <<"\n"; switch (MJTI->getEntryKind()) { case MachineJumpTableInfo::EK_Inline: llvm_unreachable("Cannot emit EK_Inline jump table entry"); @@ -3026,9 +3037,10 @@ void AsmPrinter::emitJumpTableEntry(const MachineJumpTableInfo *MJTI, // If the .set directive avoids relocations, this is emitted as: // .set L4_5_set_123, LBB123 - LJTI1_2 // .word L4_5_set_123 + MCSymbol *JTsetsy = GetJTSetSymbol(UID, MBB->getNumber()); if (MJTI->getEntryKind() == MachineJumpTableInfo::EK_LabelDifference32 && MAI->doesSetDirectiveSuppressReloc()) { - Value = MCSymbolRefExpr::create(GetJTSetSymbol(UID, MBB->getNumber()), + Value = MCSymbolRefExpr::create(JTsetsy, OutContext); break; } diff --git a/llvm/lib/Target/X86/CMakeLists.txt b/llvm/lib/Target/X86/CMakeLists.txt index 9553a8619feb5..f5dcf13e47ed2 100644 --- a/llvm/lib/Target/X86/CMakeLists.txt +++ b/llvm/lib/Target/X86/CMakeLists.txt @@ -23,6 +23,7 @@ tablegen(LLVM X86GenFoldTables.inc -gen-x86-fold-tables -asmwriternum=1) add_public_tablegen_target(X86CommonTableGen) set(sources + X86MatchJumptablePass.cpp X86ArgumentStackSlotRebase.cpp X86AsmPrinter.cpp X86AvoidTrailingCall.cpp diff --git a/llvm/lib/Target/X86/X86.h b/llvm/lib/Target/X86/X86.h index 48a3fe1934a96..5d4882b921a23 100644 --- a/llvm/lib/Target/X86/X86.h +++ b/llvm/lib/Target/X86/X86.h @@ -54,6 +54,8 @@ FunctionPass *createX86IndirectBranchTrackingPass(); /// This will prevent a stall when returning on the Atom. FunctionPass *createX86PadShortFunctions(); +FunctionPass *createX86MatchJumptablePass(); + /// Return a pass that selectively replaces certain instructions (like add, /// sub, inc, dec, some shifts, and some multiplies) by equivalent LEA /// instructions, in order to eliminate execution delays in some processors. diff --git a/llvm/lib/Target/X86/X86MatchJumptablePass.cpp b/llvm/lib/Target/X86/X86MatchJumptablePass.cpp new file mode 100644 index 0000000000000..1498bee37a9fa --- /dev/null +++ b/llvm/lib/Target/X86/X86MatchJumptablePass.cpp @@ -0,0 +1,276 @@ +// X86MatchJumptablePass.cpp +#include "X86MatchJumptablePass.h" +#include "X86.h" +#include "X86InstrInfo.h" +#include "llvm/CodeGen/MachineFunctionPass.h" +#include "llvm/CodeGen/TargetInstrInfo.h" +#include "llvm/PassRegistry.h" +#include "llvm/Pass.h" +#include "llvm/MC/MCContext.h" +#include "llvm/MC/MCSymbol.h" +#include "llvm/ADT/Statistic.h" + +#define DEBUG_TYPE "match-jump-table" + STATISTIC(RunCount, "Number of RunCount"); + STATISTIC(MaxEntrySize, "Size of MaxEntrySize"); +namespace llvm { + +char X86MatchJumptablePass::ID = 0; + +void initializeX86MatchJumptablePassPass(PassRegistry &Registry) { + RegisterPass X("match-jump-table", "Match Jump Table Pass", false, false); +} + +bool X86MatchJumptablePass::runOnMachineFunction(MachineFunction &MF) { + + Function &F = MF.getFunction(); + + // LLVM_DEBUG(dbgs() << "Function address: " << FuncAddr << "\n"); + + // Process jump tables + MachineJumpTableInfo *JumpTableInfo = MF.getJumpTableInfo(); + if (!JumpTableInfo) { + // LLVM_DEBUG(dbgs() << "No jump tables in this function.\n"); + return false; + } + + bool Modified = false; + // LLVM_DEBUG(dbgs() << "Jump Table Size#" << JumpTableInfo->getJumpTables().size() << "\n"); + + for (unsigned JTIndex = 0; JTIndex < JumpTableInfo->getJumpTables().size(); ++JTIndex) { + const MachineJumpTableEntry &JTEntry = JumpTableInfo->getJumpTables()[JTIndex]; + // LLVM_DEBUG(dbgs() << "FuncAddr:" << FuncAddr << "Jump Table #" << JTIndex << " contains " + // << JTEntry.MBBs.size() << " entries.\n"); + + // Handle indirect jump instruction + MachineInstr *indirectJumpInstr = traceIndirectJumps(MF, JTIndex, JumpTableInfo); + if (indirectJumpInstr) { + // Create label for indirect jump + std::string LabelName = std::to_string(RunCount) + "_IJUMP_" + std::to_string(JTIndex); + MCSymbol *Label = MF.getContext().getOrCreateSymbol(LabelName); + indirectJumpInstr->setPreInstrSymbol(MF, Label); + if (MaxEntrySize < JTEntry.MBBs.size()) { + MaxEntrySize = JTEntry.MBBs.size(); + } + for (unsigned EntryIndex = 0; EntryIndex < JTEntry.MBBs.size(); ++EntryIndex) { + MachineBasicBlock *TargetMBB = JTEntry.MBBs[EntryIndex]; + if (!TargetMBB->empty()) { + std::string EntryLabelName = std::to_string(RunCount) + "_JTENTRY_" + std::to_string(JTIndex) + "_" + std::to_string(EntryIndex); + MCSymbol *EntryLabel = MF.getContext().getOrCreateSymbol(EntryLabelName); + + // Set label only on first instruction + MachineInstr &FirstInstr = TargetMBB->front(); + FirstInstr.setPreInstrSymbol(MF, EntryLabel); + + // LLVM_DEBUG(dbgs() << "Created label for jump table entry: " << EntryLabelName << "\n"); + } + } + RunCount ++; + } + Modified = true; + } + + + return Modified; +} + + +MachineInstr* X86MatchJumptablePass::traceIndirectJumps(MachineFunction &MF, + unsigned JTIndex, + MachineJumpTableInfo *JumpTableInfo) { + const MachineJumpTableEntry &JTEntry = JumpTableInfo->getJumpTables()[JTIndex]; + // LLVM_DEBUG(dbgs() << "Tracing indirect jumps:\n"); + for (auto &MBB : MF) { + // LLVM_DEBUG(dbgs() << " Checking BB: " << MBB.getName() << "\n"); + for (auto &MI : MBB) { + // LLVM_DEBUG(dbgs() << " Checking instruction: " << MI << "\n"); + if (MI.isIndirectBranch()) { + // LLVM_DEBUG(dbgs() << " Found indirect jump: " << MI << "\n"); + + if (isJumpTableRelated(MI, JTEntry, MF)) { + // LLVM_DEBUG(dbgs() << " This indirect jump is related to Jump Table #" + // << JTIndex << "\n"); + return &MI; + } else { + // LLVM_DEBUG(dbgs() << " Jump is not related to this jump table\n"); + } + } + } + } + + // LLVM_DEBUG(dbgs() << " No related indirect jump found\n"); + return nullptr; +} + +bool X86MatchJumptablePass::isJumpTableLoad(MachineInstr &MI, const MachineJumpTableEntry &JTEntry) { + // LLVM_DEBUG(dbgs() << "\nAnalyzing potential jump table load instruction: " << MI << "\n"); + + // First check memory operands for jump table metadata + for (const MachineMemOperand *MMO : MI.memoperands()) { + // LLVM_DEBUG(dbgs() << " Checking memory operand flags: " << MMO->getFlags() << "\n"); + if (MMO->getValue()) { + StringRef ValueName = MMO->getValue()->getName(); + // LLVM_DEBUG(dbgs() << " Memory value name: '" << ValueName << "'\n"); + if (ValueName.contains("jump-table")) { + // LLVM_DEBUG(dbgs() << " Found jump table in memory value name\n"); + return true; + } + } + + // Check if this is a jump table load directly from memory operand comments + if (MI.getDesc().mayLoad() && MI.hasOneMemOperand()) { + // Look for jump table reference in the instruction's debug info or comments + if (MI.getDebugLoc()) { + std::string Comment; + raw_string_ostream OS(Comment); + MI.print(OS); + if (Comment.find("jump-table") != std::string::npos) { + // LLVM_DEBUG(dbgs() << " Found jump table reference in instruction comment\n"); + return true; + } + } + } + } + + // Check for the MOVSX pattern + if (MI.getOpcode() == X86::MOVSX64rm32) { + // LLVM_DEBUG(dbgs() << " Found MOVSX64rm32 instruction\n"); + Register BaseReg; + + // Find base register + for (const MachineOperand &MO : MI.operands()) { + if (MO.isReg() && MO.isUse()) { + BaseReg = MO.getReg(); + // LLVM_DEBUG(dbgs() << " Found base register: " << printReg(BaseReg, nullptr) << "\n"); + break; + } + } + + if (BaseReg) { + // Look for preceding LEA + MachineBasicBlock::iterator MBBI = MI; + const MachineBasicBlock *MBB = MI.getParent(); + + // LLVM_DEBUG(dbgs() << " Looking for LEA defining register: " << printReg(BaseReg, nullptr) << "\n"); + + while (MBBI != MBB->begin()) { + --MBBI; + // LLVM_DEBUG(dbgs() << " Checking: " << *MBBI << "\n"); + + if (MBBI->getOpcode() == X86::LEA64r) { + // LLVM_DEBUG(dbgs() << " Found LEA64r\n"); + + // Verify this LEA defines our base register + const MachineOperand &DefReg = MBBI->getOperand(0); + if (!DefReg.isReg() || DefReg.getReg() != BaseReg) { + // LLVM_DEBUG(dbgs() << " LEA defines different register\n"); + continue; + } + + // Check for jump table symbol + for (const MachineOperand &MO : MBBI->operands()) { + if (MO.isSymbol()) { + StringRef SymName = MO.getSymbolName(); + // LLVM_DEBUG(dbgs() << " Checking symbol: '" << SymName << "'\n"); + if (SymName.contains("jump-table")) { + // LLVM_DEBUG(dbgs() << " Found jump table symbol!\n"); + return true; + } + } + } + } + } + // LLVM_DEBUG(dbgs() << " No matching LEA found\n"); + } + } + + return false; +} + +bool X86MatchJumptablePass::isJumpTableRelated(MachineInstr &MI, + const MachineJumpTableEntry &JTEntry, + MachineFunction &MF) { + if (!MI.isIndirectBranch()) { + // LLVM_DEBUG(dbgs() << "Not an indirect branch, skipping\n"); + return false; + } + + // LLVM_DEBUG(dbgs() << "\nAnalyzing indirect jump: " << MI << "\n"); + + // Get jump register + Register JumpReg; + for (const MachineOperand &MO : MI.operands()) { + if (MO.isReg() && MO.isUse()) { + JumpReg = MO.getReg(); + // LLVM_DEBUG(dbgs() << "Found jump register: " << printReg(JumpReg, nullptr) << "\n"); + break; + } + } + + if (!JumpReg) { + // LLVM_DEBUG(dbgs() << "No jump register found\n"); + return false; + } + + SmallVector Worklist; + SmallPtrSet Visited; + + // LLVM_DEBUG(dbgs() << "Starting backward analysis from register " << printReg(JumpReg, nullptr) << "\n"); + + for (MachineInstr &DefMI : MF.getRegInfo().def_instructions(JumpReg)) { + Worklist.push_back(&DefMI); + // LLVM_DEBUG(dbgs() << "Added to worklist: " << DefMI << "\n"); + } + + while (!Worklist.empty()) { + MachineInstr *CurrMI = Worklist.pop_back_val(); + if (!Visited.insert(CurrMI).second) { + // LLVM_DEBUG(dbgs() << "Already visited: " << *CurrMI << "\n"); + continue; + } + + // LLVM_DEBUG(dbgs() << "Analyzing instruction: " << *CurrMI << "\n"); + + if (isJumpTableLoad(*CurrMI, JTEntry)) { + // LLVM_DEBUG(dbgs() << "Found jump table load!\n"); + return true; + } + + if (CurrMI->getOpcode() == X86::ADD64rr) { + // LLVM_DEBUG(dbgs() << "Found ADD64rr, checking operands\n"); + for (const MachineOperand &MO : CurrMI->operands()) { + if (MO.isReg() && MO.isUse()) { + // LLVM_DEBUG(dbgs() << "Checking register operand: " << printReg(MO.getReg(), nullptr) << "\n"); + for (MachineInstr &DefMI : MF.getRegInfo().def_instructions(MO.getReg())) { + if (isJumpTableLoad(DefMI, JTEntry)) { + // LLVM_DEBUG(dbgs() << "Found jump table load via ADD operand!\n"); + return true; + } + } + } + } + } + + // Add uses to worklist + for (const MachineOperand &MO : CurrMI->operands()) { + if (MO.isReg() && MO.isUse()) { + // LLVM_DEBUG(dbgs() << "Adding definitions of register " << printReg(MO.getReg(), nullptr) << " to worklist\n"); + for (MachineInstr &DefMI : MF.getRegInfo().def_instructions(MO.getReg())) { + if (!Visited.count(&DefMI)) { + Worklist.push_back(&DefMI); + // LLVM_DEBUG(dbgs() << "Added to worklist: " << DefMI << "\n"); + } + } + } + } + } + + // LLVM_DEBUG(dbgs() << "No jump table relation found\n"); + return false; +} + +FunctionPass *createX86MatchJumptablePass() { + return new X86MatchJumptablePass(); +} + +} // end namespace llvm \ No newline at end of file diff --git a/llvm/lib/Target/X86/X86MatchJumptablePass.h b/llvm/lib/Target/X86/X86MatchJumptablePass.h new file mode 100644 index 0000000000000..bec6ec9cb5972 --- /dev/null +++ b/llvm/lib/Target/X86/X86MatchJumptablePass.h @@ -0,0 +1,48 @@ +// X86MatchJumptablePass.h +#ifndef X86_MATCH_JUMPTABLE_PASS_H +#define X86_MATCH_JUMPTABLE_PASS_H + +#include "MCTargetDesc/X86MCTargetDesc.h" +#include "X86.h" +#include "X86InstrInfo.h" +#include "llvm/CodeGen/MachineFunctionPass.h" +#include "llvm/CodeGen/MachineInstr.h" +#include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/CodeGen/MachineJumpTableInfo.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Constants.h" +#include "llvm/CodeGen/TargetInstrInfo.h" + +namespace llvm { + +class X86MatchJumptablePass : public MachineFunctionPass { +private: + MachineInstr* traceIndirectJumps(MachineFunction &MF, unsigned JTIndex, + MachineJumpTableInfo *JumpTableInfo); + bool isJumpTableRelated(MachineInstr &MI, const MachineJumpTableEntry &JTEntry, + MachineFunction &MF); + bool isJumpTableLoad(MachineInstr &MI, const MachineJumpTableEntry &JTEntry); + bool isRegUsedInJumpTableLoad(Register Reg,MachineFunction &MF, + const MachineJumpTableEntry &JTEntry); + +public: + static char ID; + + X86MatchJumptablePass() : MachineFunctionPass(ID) {} + + bool runOnMachineFunction(MachineFunction &MF) override; + StringRef getPassName() const override { return "Match Jump Table Pass"; } +}; + +FunctionPass *createX86MatchJumptablePass(); + +// Pass initialization declaration +void initializeX86MatchJumptablePassPass(PassRegistry &Registry); + +} // namespace llvm + +#endif // X86_MATCH_JUMPTABLE_PASS_H + diff --git a/llvm/lib/Target/X86/X86TargetMachine.cpp b/llvm/lib/Target/X86/X86TargetMachine.cpp index 20dfdd27b33df..7faeeb2af64e8 100644 --- a/llvm/lib/Target/X86/X86TargetMachine.cpp +++ b/llvm/lib/Target/X86/X86TargetMachine.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "X86TargetMachine.h" +#include "X86MatchJumptablePass.h" #include "MCTargetDesc/X86MCTargetDesc.h" #include "TargetInfo/X86TargetInfo.h" #include "X86.h" @@ -596,6 +597,7 @@ void X86PassConfig::addPreEmitPass() { addPass(createBreakFalseDeps()); } + addPass(createX86MatchJumptablePass()); addPass(createX86IndirectBranchTrackingPass()); addPass(createX86IssueVZeroUpperPass());