diff --git a/llvm/include/llvm/CodeGen/MachineLaneSSAUpdater.h b/llvm/include/llvm/CodeGen/MachineLaneSSAUpdater.h new file mode 100644 index 0000000000000..f0f8144904f86 --- /dev/null +++ b/llvm/include/llvm/CodeGen/MachineLaneSSAUpdater.h @@ -0,0 +1,199 @@ +//===- MachineLaneSSAUpdater.h - Lane-aware SSA repair for MIR -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// MachineLaneSSAUpdater repairs SSA form in Machine IR after a client inserts +// an instruction that redefines an existing virtual register, with full +// subregister-lane awareness. +// + +#ifndef LLVM_CODEGEN_MACHINELANESSAUPDATER_H +#define LLVM_CODEGEN_MACHINELANESSAUPDATER_H + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SmallVector.h" // SmallVector +#include "llvm/CodeGen/LiveInterval.h" // LiveRange +#include "llvm/CodeGen/Register.h" // Register +#include "llvm/CodeGen/SlotIndexes.h" // SlotIndex +#include "llvm/CodeGen/TargetRegisterInfo.h" // For inline function +#include "llvm/MC/LaneBitmask.h" // LaneBitmask + +namespace llvm { + +// Forward declarations to avoid heavy includes in the header. +class MachineFunction; +class MachineBasicBlock; +class MachineInstr; +class LiveIntervals; +class LiveRange; +class MachineDominatorTree; +class MachinePostDominatorTree; // not currently used + +//===----------------------------------------------------------------------===// +// MachineLaneSSAUpdater: universal SSA repair for Machine IR (lane-aware) +// +// Primary Use Case: repairSSAForNewDef() +// - Caller creates a new instruction that defines an existing vreg (violating SSA) +// - This function creates a new vreg (or uses a caller-provided one), +// replaces the operand, and repairs SSA +// - Example: Insert "OrigVReg = ADD ..." and call repairSSAForNewDef() +// - Works for full register and subregister definitions +// - Handles all scenarios including spill/reload +// +// Advanced Usage: Caller-provided NewVReg +// - By default, repairSSAForNewDef() creates a new virtual register automatically +// - For special cases (e.g., subregister reloads where the spiller already +// created a register of a specific class), the caller can provide NewVReg +// - This gives full control over register class selection when needed +//===----------------------------------------------------------------------===// +class MachineLaneSSAUpdater { +public: + MachineLaneSSAUpdater(MachineFunction &MF, LiveIntervals &LIS, MachineDominatorTree &MDT, const TargetRegisterInfo &TRI) : MF(MF), LIS(LIS), MDT(MDT), TRI(TRI) {} + + // Repair SSA for a new definition that violates SSA form + // + // Parameters: + // NewDefMI: Instruction with a def operand that currently defines OrigVReg (violating SSA) + // OrigVReg: The virtual register being redefined + // NewVReg: (Optional) Pre-allocated virtual register to use instead of auto-creating one + // + // This function will: + // 1. Find the def operand in NewDefMI that defines OrigVReg + // 2. Derive the lane mask from the operand's subreg index (if any) + // 3. Use NewVReg if provided, or create a new virtual register with appropriate class + // 4. Replace the operand in NewDefMI to define the new vreg + // 5. Perform SSA repair (insert PHIs, rewrite uses) + // + // When to provide NewVReg: + // - Leave it empty (default) for most cases - automatic class selection works well + // - Provide it when you need precise control over register class selection + // - Common use case: subregister spill/reload where target-specific constraints apply + // - Example: Reloading a 96-bit subregister requires the vreg_96 class (not vreg_128) + // + // Returns: The SSA-repaired virtual register (either NewVReg or auto-created) + Register repairSSAForNewDef(MachineInstr &NewDefMI, Register OrigVReg, Register NewVReg = Register()); + + // Optional knobs (fluent style); no-ops until implemented in the .cpp. + MachineLaneSSAUpdater &setUndefEdgePolicy(bool MaterializeImplicitDef) { UndefEdgeAsImplicitDef = MaterializeImplicitDef; return *this; } + MachineLaneSSAUpdater &setVerifyOnExit(bool Enable) { VerifyOnExit = Enable; return *this; } + +private: + // Common SSA repair logic + void performSSARepair(Register NewVReg, Register OrigVReg, LaneBitmask DefMask, MachineBasicBlock *DefBB); + + // --- Internal helpers --- + + // Index MI in SlotIndexes / LIS maps immediately after insertion. + // Returns the SlotIndex assigned to the instruction. + SlotIndex indexNewInstr(MachineInstr &MI); + + // Extend the main live range and the specific subranges at MI's index + // for the lanes actually used/defined. + void extendPreciselyAt(const Register VReg, const SmallVector<LaneBitmask, 4> &LaneMasks, const MachineInstr &AtMI); + + // Compute pruned IDF for a set of definition blocks (usually {block(NewDef)}), + // intersected with blocks where OrigVReg lanes specified by DefMask are live-in. + void computePrunedIDF(Register OrigVReg, LaneBitmask DefMask, ArrayRef<MachineBasicBlock *> NewDefBlocks, SmallVectorImpl<MachineBasicBlock *> &OutIDFBlocks); + + // Insert lane-aware Machine PHIs with iterative worklist processing. + // Seeds with the InitialVReg definition, computes IDF, places PHIs, repeats until convergence. + // Returns all PHI result registers created during the iteration. + SmallVector<Register, 4> insertLaneAwarePHI(Register InitialVReg, Register OrigVReg, LaneBitmask DefMask, MachineBasicBlock *InitialDefBB); + + // Helper: Create PHI in a specific block with per-edge lane analysis + Register createPHIInBlock(MachineBasicBlock &JoinMBB, Register OrigVReg, Register NewVReg, LaneBitmask DefMask); + + // Rewrite dominated uses of OrigVReg to NewSSA according to the + // exact/subset/super policy; create REG_SEQUENCE only when needed. + void rewriteDominatedUses(Register OrigVReg, Register NewSSA, LaneBitmask MaskToRewrite); + + // Internal helper methods for use rewriting + VNInfo *incomingOnEdge(LiveInterval &LI, MachineInstr *Phi, MachineOperand &PhiOp); + bool defReachesUse(MachineInstr *DefMI, MachineInstr *UseMI, MachineOperand &UseOp); + LaneBitmask operandLaneMask(const MachineOperand &MO); + Register buildRSForSuperUse(MachineInstr *UseMI, MachineOperand &MO, Register OldVR, Register NewVR, LaneBitmask MaskToRewrite, LiveInterval &LI, const TargetRegisterClass *OpRC, SlotIndex &OutIdx, SmallVectorImpl<LaneBitmask> &LanesToExtend); + void extendAt(LiveInterval &LI, SlotIndex Idx, ArrayRef<LaneBitmask> Lanes); + void updateDeadFlags(Register Reg); + + // --- Data members --- + MachineFunction &MF; + LiveIntervals &LIS; + MachineDominatorTree &MDT; + const TargetRegisterInfo &TRI; + + bool UndefEdgeAsImplicitDef = true; // policy hook + bool VerifyOnExit = true; // run MF.verify()/LI.verify() at end +};
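+
+// A minimal usage sketch (illustrative only; MF, LIS, MDT, TRI and the
+// just-inserted instruction NewDefMI are assumed to exist in the caller):
+//
+//   MachineLaneSSAUpdater Updater(MF, LIS, MDT, TRI);
+//   // NewDefMI currently redefines OrigVReg, temporarily violating SSA.
+//   Register Repaired = Updater.repairSSAForNewDef(NewDefMI, OrigVReg);
+//   // NewDefMI now defines Repaired; dominated uses of OrigVReg have been
+//   // rewritten, with lane-aware PHIs inserted at join points as needed.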
+ +/// Get the subregister index that corresponds to the given lane mask. +/// \param Mask The lane mask to convert to a subregister index +/// \param TRI The target register info (provides target-specific subregister mapping) +/// \return The subregister index, or 0 if no single subregister matches +inline unsigned getSubRegIndexForLaneMask(LaneBitmask Mask, const TargetRegisterInfo *TRI) { + if (Mask.none()) + return 0; // No subregister + + // Iterate through all subregister indices to find a match + for (unsigned SubIdx = 1; SubIdx < TRI->getNumSubRegIndices(); ++SubIdx) { + LaneBitmask SubMask = TRI->getSubRegIndexLaneMask(SubIdx); + if (SubMask == Mask) { + return SubIdx; + } + } + + // No exact match found - this might be a composite mask requiring REG_SEQUENCE + return 0; +}
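+
+// For illustration (lane-to-subreg numbering is target-specific, so the
+// values here are hypothetical): with a 64-bit class where sub1 has lane
+// mask 0xC, getSubRegIndexForLaneMask(LaneBitmask(0xC), &TRI) yields the
+// sub1 index, while a non-contiguous composite such as sub0+sub3 yields 0
+// and callers fall back to a REG_SEQUENCE (see getCoveringSubRegsForLaneMask
+// in the .cpp).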
+ +// DenseMapInfo specialization for LaneBitmask +template<> +struct DenseMapInfo<LaneBitmask> { + static inline LaneBitmask getEmptyKey() { + // Use a specific bit pattern for the empty key + return LaneBitmask(~0U - 1); + } + + static inline LaneBitmask getTombstoneKey() { + // Use a different bit pattern for the tombstone + return LaneBitmask(~0U); + } + + static unsigned getHashValue(const LaneBitmask &Val) { + return (unsigned)Val.getAsInteger(); + } + + static bool isEqual(const LaneBitmask &LHS, const LaneBitmask &RHS) { + return LHS == RHS; + } +}; + +} // end namespace llvm + +#endif // LLVM_CODEGEN_MACHINELANESSAUPDATER_H \ No newline at end of file diff --git a/llvm/lib/CodeGen/CMakeLists.txt b/llvm/lib/CodeGen/CMakeLists.txt index f8f9bbba53e43..68a57539fe255 100644 --- a/llvm/lib/CodeGen/CMakeLists.txt +++ b/llvm/lib/CodeGen/CMakeLists.txt @@ -148,6 +148,7 @@ add_llvm_component_library(LLVMCodeGen MachineSizeOpts.cpp MachineSSAContext.cpp MachineSSAUpdater.cpp + MachineLaneSSAUpdater.cpp MachineStripDebug.cpp MachineTraceMetrics.cpp MachineUniformityAnalysis.cpp diff --git a/llvm/lib/CodeGen/MachineLaneSSAUpdater.cpp b/llvm/lib/CodeGen/MachineLaneSSAUpdater.cpp new file mode 100644 index 0000000000000..0ffd489a93ce0 --- /dev/null +++ b/llvm/lib/CodeGen/MachineLaneSSAUpdater.cpp @@ -0,0 +1,940 @@ +//===- MachineLaneSSAUpdater.cpp - SSA repair for Machine IR (lane-aware) ===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Implementation of the MachineLaneSSAUpdater - a universal SSA repair utility +// for Machine IR that handles both regular new definitions and reload-after- +// spill scenarios with full subregister lane awareness. +// +// Key features: +// - A single entry point, repairSSAForNewDef: the caller creates an +// instruction that redefines an existing vreg (violating SSA); the updater +// renames it to a fresh vreg and repairs SSA. Spill/reload needs no +// separate entry point, since a reload is just another new definition. +// - Lane-aware PHI insertion with per-edge masks +// - Pruned IDF computation (NewDefBlocks ∩ LiveIn(OldVR)) +// - Precise LiveInterval extension at the new definition and rewritten uses +// - REG_SEQUENCE insertion only when necessary +// - Preservation of undef/dead flags on partial definitions +// +//===----------------------------------------------------------------------===//
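+
+// A before/after sketch of the primary flow (vreg numbers and register
+// classes are illustrative):
+//
+//   before:  %1:vreg_64      = ...            ; original def
+//            %1.sub0:vreg_64 = V_MOV_B32 ...  ; caller-inserted def, not SSA
+//   after:   %3:vgpr_32      = V_MOV_B32 ...  ; renamed to a fresh vreg
+//   ; dominated uses of %1.sub0 now read %3, with lane-aware PHIs (and
+//   ; REG_SEQUENCEs, where wider lanes are read) at join points.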
+ +#include "llvm/CodeGen/MachineLaneSSAUpdater.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/CodeGen/LiveIntervals.h" +#include "llvm/CodeGen/MachineBasicBlock.h" +#include "llvm/CodeGen/MachineDominators.h" +#include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineInstr.h" +#include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/CodeGen/SlotIndexes.h" +#include "llvm/CodeGen/TargetInstrInfo.h" +#include "llvm/CodeGen/TargetRegisterInfo.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/GenericIteratedDominanceFrontier.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "machine-lane-ssa-updater" + +using namespace llvm; + +//===----------------------------------------------------------------------===// +// MachineLaneSSAUpdater Implementation +//===----------------------------------------------------------------------===// + +Register MachineLaneSSAUpdater::repairSSAForNewDef(MachineInstr &NewDefMI, Register OrigVReg, Register NewVReg) { + LLVM_DEBUG(dbgs() << "MachineLaneSSAUpdater::repairSSAForNewDef VReg=" << OrigVReg); + if (NewVReg.isValid()) { + LLVM_DEBUG(dbgs() << ", caller-provided NewVReg=" << NewVReg); + } + LLVM_DEBUG(dbgs() << "\n"); + + MachineRegisterInfo &MRI = MF.getRegInfo(); + + // Step 1: Find the def operand that currently defines OrigVReg (violating SSA) + MachineOperand *DefOp = nullptr; + for (MachineOperand &MO : NewDefMI.defs()) { + if (MO.getReg() == OrigVReg) { + DefOp = &MO; + break; + } + } + + assert(DefOp && "NewDefMI should have a def operand for OrigVReg"); + assert(DefOp->isDef() && "Found operand should be a definition"); + + // Step 2: Derive DefMask from the operand's subreg index (if any) + unsigned SubRegIdx = DefOp->getSubReg(); + LaneBitmask DefMask; + + if (SubRegIdx) { + // Partial register definition - get the lane mask for this subreg + DefMask = TRI.getSubRegIndexLaneMask(SubRegIdx); + LLVM_DEBUG(dbgs() << " Partial def with subreg " << TRI.getSubRegIndexName(SubRegIdx) << ", DefMask=" << PrintLaneMask(DefMask) << "\n"); + } else { + // Full register definition - get all lanes for this register class + DefMask = MRI.getMaxLaneMaskForVReg(OrigVReg); + LLVM_DEBUG(dbgs() << " Full register def, DefMask=" << PrintLaneMask(DefMask) << "\n"); + } + + // Step 3: Create or use the provided new virtual register + Register NewSSAVReg; + if (NewVReg.isValid()) { + // Caller provided a register - use it + NewSSAVReg = NewVReg; + const TargetRegisterClass *RC = MRI.getRegClass(NewSSAVReg); + LLVM_DEBUG(dbgs() << " Using caller-provided SSA vreg " << NewSSAVReg << " with RC=" << TRI.getRegClassName(RC) << "\n"); + } else { + // Create a new virtual register with the appropriate register class. + // If this is a subreg def, we need the class for the subreg, not the full reg + const TargetRegisterClass *RC; + if (SubRegIdx) { + // For subreg defs, get the subreg class + const TargetRegisterClass *OrigRC = MRI.getRegClass(OrigVReg); + RC = TRI.getSubRegisterClass(OrigRC, SubRegIdx); + assert(RC && "Failed to get subregister class for subreg def - would create incorrect MIR"); + } else { + // For full reg defs, use the same class as OrigVReg + RC = MRI.getRegClass(OrigVReg); + } + + NewSSAVReg = MRI.createVirtualRegister(RC); + LLVM_DEBUG(dbgs() << " Created new SSA vreg " << NewSSAVReg << " with RC=" << TRI.getRegClassName(RC) << "\n"); + } + + // Step 4: Replace the operand in NewDefMI to define the new vreg + // If this was a subreg def, the new vreg is a full register of the subreg class + // so we clear the subreg index (e.g., %1.sub0:vreg_64 becomes %3:vgpr_32) + DefOp->setReg(NewSSAVReg); + if (SubRegIdx) { + DefOp->setSubReg(0); + LLVM_DEBUG(dbgs() << " Replaced operand: " << OrigVReg << "." << TRI.getSubRegIndexName(SubRegIdx) << " -> " << NewSSAVReg << " (full register)\n"); + } else { + LLVM_DEBUG(dbgs() << " Replaced operand: " << OrigVReg << " -> " << NewSSAVReg << "\n"); + } + + // Step 5: Index the new instruction in SlotIndexes/LIS + indexNewInstr(NewDefMI); + + // Step 6: Perform common SSA repair (PHI placement + use rewriting) + // The LiveInterval for NewSSAVReg will be created by getInterval() as needed + performSSARepair(NewSSAVReg, OrigVReg, DefMask, NewDefMI.getParent()); + + // Step 7: If SSA repair created subregister uses of OrigVReg (e.g., in PHIs or REG_SEQUENCEs), + // recompute its LiveInterval to create subranges + LaneBitmask AllLanes = MRI.getMaxLaneMaskForVReg(OrigVReg); + if (DefMask != AllLanes) { + LiveInterval &OrigLI = LIS.getInterval(OrigVReg); + if (!OrigLI.hasSubRanges()) { + // Check if any uses now access OrigVReg with subregister indices + bool HasSubregUses = false; + for (const MachineOperand &MO : MRI.use_operands(OrigVReg)) { + if (MO.getSubReg() != 0) { + HasSubregUses = true; + break; + } + } + + if (HasSubregUses) { + LLVM_DEBUG(dbgs() << " Recomputing LiveInterval for " << OrigVReg << " after SSA repair created subregister uses\n"); + LIS.removeInterval(OrigVReg); + LIS.createAndComputeVirtRegInterval(OrigVReg); + } + } + } + + LLVM_DEBUG(dbgs() << " repairSSAForNewDef complete, returning " << NewSSAVReg << "\n"); + return NewSSAVReg; +} + +//===----------------------------------------------------------------------===// +// Common SSA Repair Logic +//===----------------------------------------------------------------------===// + +void MachineLaneSSAUpdater::performSSARepair(Register NewVReg, Register OrigVReg, LaneBitmask DefMask, MachineBasicBlock *DefBB) { + LLVM_DEBUG(dbgs() << "MachineLaneSSAUpdater::performSSARepair NewVReg=" << NewVReg << " OrigVReg=" << OrigVReg << " DefMask=" << PrintLaneMask(DefMask) << "\n"); + + // Step 1: Use worklist-driven PHI placement + SmallVector<Register, 4> AllPHIVRegs = insertLaneAwarePHI(NewVReg, OrigVReg, DefMask, DefBB); + + // Step 2: Rewrite dominated uses once for each new register + // Note: getInterval() will automatically create LiveIntervals if needed + rewriteDominatedUses(OrigVReg, NewVReg, DefMask); + for (Register PHIVReg : AllPHIVRegs) { + rewriteDominatedUses(OrigVReg, PHIVReg, DefMask); + } + + // Step 3: Renumber values if needed + LiveInterval &NewLI = LIS.getInterval(NewVReg); + NewLI.RenumberValues(); + + // Also renumber PHI intervals + for (Register PHIVReg : AllPHIVRegs) {
LiveInterval &PHILI = LIS.getInterval(PHIVReg); + PHILI.RenumberValues(); + } + + // Recompute OrigVReg's LiveInterval to account for PHI operands. + // We do a full recomputation because PHI operands may reference subregisters + // that weren't previously live on those paths, and we need to extend liveness + // from the definition to the PHI use. + LIS.removeInterval(OrigVReg); + LIS.createAndComputeVirtRegInterval(OrigVReg); + + // Note: We do NOT call shrinkToUses on OrigVReg even after recomputation because + // shrinkToUses has a fundamental bug with PHI operands - it doesn't understand + // that PHI operands require their source lanes to be live at the END of + // predecessor blocks. When it sees a PHI operand like "%0.sub2_sub3" from BB3, + // it only considers the PHI location (start of the join block), not the predecessor + // end where the value must be available. This causes it to incorrectly shrink + // away lanes that ARE needed by PHI operands, leading to verification errors: + // "Not all lanes of PHI source live at use". createAndComputeVirtRegInterval + // already produces correct, minimal liveness that includes PHI uses properly. + + // Step 4: Update operand flags to match the LiveIntervals + updateDeadFlags(NewVReg); + for (Register PHIVReg : AllPHIVRegs) { + updateDeadFlags(PHIVReg); + } + + // Step 5: Verification if enabled + if (VerifyOnExit) { + LLVM_DEBUG(dbgs() << "  Verifying after SSA repair...\n"); + // TODO: Add verification calls + } + + LLVM_DEBUG(dbgs() << "  performSSARepair complete\n"); +} + +//===----------------------------------------------------------------------===// +// Internal Helper Methods +//===----------------------------------------------------------------------===// + +SlotIndex MachineLaneSSAUpdater::indexNewInstr(MachineInstr &MI) { + LLVM_DEBUG(dbgs() << "MachineLaneSSAUpdater::indexNewInstr: " << MI); + + // Register the instruction in SlotIndexes and LiveIntervals. + // This is typically done automatically when instructions are inserted, + // but we need to ensure it's properly indexed + SlotIndexes *SI = LIS.getSlotIndexes(); + + // Check if the instruction is already indexed + if (SI->hasIndex(MI)) { + SlotIndex Idx = SI->getInstructionIndex(MI); + LLVM_DEBUG(dbgs() << "  Already indexed at " << Idx << "\n"); + return Idx; + } + + // Not indexed yet - insert the instruction into the maps now + LIS.InsertMachineInstrInMaps(MI); + + SlotIndex Idx = SI->getInstructionIndex(MI); + LLVM_DEBUG(dbgs() << "  Indexed at " << Idx << "\n"); + return Idx; +} + +void MachineLaneSSAUpdater::extendPreciselyAt(const Register VReg, const SmallVector<LaneBitmask, 4> &LaneMasks, const MachineInstr &AtMI) { + LLVM_DEBUG(dbgs() << "MachineLaneSSAUpdater::extendPreciselyAt VReg=" << VReg << " at " << LIS.getInstructionIndex(AtMI) << "\n"); + + if (!VReg.isVirtual()) { + return; // Only handle virtual registers + } + + SlotIndex DefIdx = LIS.getInstructionIndex(AtMI).getRegSlot(); + + // Create or get the LiveInterval for this register + LiveInterval &LI = LIS.getInterval(VReg); + + // Extend the main live range to include the definition point + SmallVector<SlotIndex, 4> DefPoint = { DefIdx }; + LIS.extendToIndices(LI, DefPoint); + + // For each lane mask, ensure appropriate subranges exist and are extended. + // For now, assume all lanes are valid - we'll refine this later based on the register class + LaneBitmask RegCoverageMask = MF.getRegInfo().getMaxLaneMaskForVReg(VReg); + + for (LaneBitmask LaneMask : LaneMasks) {
if (LaneMask == MF.getRegInfo().getMaxLaneMaskForVReg(VReg) || LaneMask == LaneBitmask::getNone()) { + continue; // The main range handles getAll(); skip getNone() + } + + // Only process lanes that are valid for this register class + LaneBitmask ValidLanes = LaneMask & RegCoverageMask; + if (ValidLanes.none()) { + continue; + } + + // Find or create the appropriate subrange + LiveInterval::SubRange *SR = nullptr; + for (LiveInterval::SubRange &Sub : LI.subranges()) { + if (Sub.LaneMask == ValidLanes) { + SR = &Sub; + break; + } + } + if (!SR) { + SR = LI.createSubRange(LIS.getVNInfoAllocator(), ValidLanes); + } + + // Extend this subrange to include the definition point + LIS.extendToIndices(*SR, DefPoint); + + LLVM_DEBUG(dbgs() << "  Extended subrange " << PrintLaneMask(ValidLanes) << "\n"); + } + + LLVM_DEBUG(dbgs() << "  LiveInterval extension complete\n"); +} + +void MachineLaneSSAUpdater::computePrunedIDF(Register OrigVReg, LaneBitmask DefMask, ArrayRef<MachineBasicBlock *> NewDefBlocks, SmallVectorImpl<MachineBasicBlock *> &OutIDFBlocks) { + LLVM_DEBUG(dbgs() << "MachineLaneSSAUpdater::computePrunedIDF VReg=" << OrigVReg << " DefMask=" << PrintLaneMask(DefMask) << " with " << NewDefBlocks.size() << " new def blocks\n"); + + // Clear the output vector at entry + OutIDFBlocks.clear(); + + // Early bail-out checks for robustness + if (!OrigVReg.isVirtual()) { + LLVM_DEBUG(dbgs() << "  Skipping non-virtual register\n"); + return; + } + + if (!LIS.hasInterval(OrigVReg)) { + LLVM_DEBUG(dbgs() << "  OrigVReg not tracked by LiveIntervals, bailing out\n"); + return; + } + + // Get the main LiveInterval for OrigVReg + LiveInterval &LI = LIS.getInterval(OrigVReg); + + // Build the prune set: blocks where the specified lanes (DefMask) are live-in at entry + SmallPtrSet<MachineBasicBlock *, 16> LiveIn; + for (MachineBasicBlock &BB : MF) { + SlotIndex Start = LIS.getMBBStartIdx(&BB); + + // Collect live lanes at block entry + LaneBitmask LiveLanes = LaneBitmask::getNone(); + + if (DefMask == MF.getRegInfo().getMaxLaneMaskForVReg(OrigVReg)) { + // For the full register (e.g., the reload case), check the main interval + if (LI.liveAt(Start)) { + LiveLanes = MF.getRegInfo().getMaxLaneMaskForVReg(OrigVReg); + } + } else { + // For specific lanes, check the subranges + for (LiveInterval::SubRange &S : LI.subranges()) { + if (S.liveAt(Start)) { + LiveLanes |= S.LaneMask; + } + } + + // If no subranges were found but the main interval is live, + // assume all lanes are covered by the main interval + if (LiveLanes == LaneBitmask::getNone() && LI.liveAt(Start)) { + LiveLanes = MF.getRegInfo().getMaxLaneMaskForVReg(OrigVReg); + } + } + + // Check if any of the requested lanes (DefMask) are live + if ((LiveLanes & DefMask).any()) { + LiveIn.insert(&BB); + } + } + + // Seed set: the blocks where new defs exist (e.g., reload or prior PHIs) + SmallPtrSet<MachineBasicBlock *, 8> DefBlocks; + for (MachineBasicBlock *B : NewDefBlocks) { + if (B) { // Robust to null entries + DefBlocks.insert(B); + } + } + + // Early exit if either set is empty + if (DefBlocks.empty() || LiveIn.empty()) { + LLVM_DEBUG(dbgs() << "  DefBlocks=" << DefBlocks.size() << " LiveIn=" << LiveIn.size() << ", early exit\n"); + return; + } + + LLVM_DEBUG(dbgs() << "  DefBlocks=" << DefBlocks.size() << " LiveIn=" << LiveIn.size() << "\n"); + + // Use LLVM's IDFCalculatorBase for MachineBasicBlock with forward dominance + using NodeTy = MachineBasicBlock; + + // Access the underlying DomTreeBase from MachineDominatorTree; + // MachineDominatorTree inherits from DomTreeBase<MachineBasicBlock> + DomTreeBase<MachineBasicBlock> &DT = MDT; + + // Compute the pruned IDF (forward dominance, IsPostDom=false)
+ llvm::IDFCalculatorBase<NodeTy, false> IDF(DT); + IDF.setDefiningBlocks(DefBlocks); + IDF.setLiveInBlocks(LiveIn); + IDF.calculate(OutIDFBlocks); + + LLVM_DEBUG(dbgs() << "  Computed " << OutIDFBlocks.size() << " IDF blocks\n"); + + // Note: We do not place PHIs here; this function only computes candidate + // join blocks. The IDFCalculator handles deduplication automatically. +} + +SmallVector<Register, 4> MachineLaneSSAUpdater::insertLaneAwarePHI(Register InitialVReg, Register OrigVReg, LaneBitmask DefMask, MachineBasicBlock *InitialDefBB) { + LLVM_DEBUG(dbgs() << "MachineLaneSSAUpdater::insertLaneAwarePHI InitialVReg=" << InitialVReg << " OrigVReg=" << OrigVReg << " DefMask=" << PrintLaneMask(DefMask) << "\n"); + + SmallVector<Register, 4> AllCreatedPHIs; + + // Step 1: Compute the IDF (Iterated Dominance Frontier) for the initial definition. + // This gives us ALL blocks where PHI nodes need to be inserted + SmallVector<MachineBasicBlock *, 4> DefBlocks = {InitialDefBB}; + SmallVector<MachineBasicBlock *, 8> IDFBlocks; + computePrunedIDF(OrigVReg, DefMask, DefBlocks, IDFBlocks); + + LLVM_DEBUG(dbgs() << "  Computed IDF: found " << IDFBlocks.size() << " blocks needing PHIs\n"); + for (MachineBasicBlock *MBB : IDFBlocks) { + LLVM_DEBUG(dbgs() << "    BB#" << MBB->getNumber() << "\n"); + } + + // Step 2: Iterate through the IDF blocks sequentially, creating PHIs. + // Key insight: After creating a PHI, update NewVReg to the PHI result + // so subsequent PHIs use the correct register + Register CurrentNewVReg = InitialVReg; + + for (MachineBasicBlock *JoinMBB : IDFBlocks) { + LLVM_DEBUG(dbgs() << "  Creating PHI in BB#" << JoinMBB->getNumber() << " with CurrentNewVReg=" << CurrentNewVReg << "\n"); + + // Create PHI: merges OrigVReg and CurrentNewVReg based on dominance + Register PHIResult = createPHIInBlock(*JoinMBB, OrigVReg, CurrentNewVReg, DefMask); + + if (PHIResult.isValid()) { + AllCreatedPHIs.push_back(PHIResult); + + // Update CurrentNewVReg to be the PHI result. + // This ensures the next PHI (if any) uses this PHI's result, not the original InitialVReg + CurrentNewVReg = PHIResult; + + LLVM_DEBUG(dbgs() << "    Created PHI result VReg=" << PHIResult << ", will use this for subsequent PHIs\n"); + } + } + + LLVM_DEBUG(dbgs() << "  PHI insertion complete. Created " << AllCreatedPHIs.size() << " PHI registers total.\n"); + + return AllCreatedPHIs; +} + +// Helper: Create a lane-specific PHI in a join block +Register MachineLaneSSAUpdater::createPHIInBlock(MachineBasicBlock &JoinMBB, Register OrigVReg, Register NewVReg, LaneBitmask DefMask) { + LLVM_DEBUG(dbgs() << "  createPHIInBlock in BB#" << JoinMBB.getNumber() << " OrigVReg=" << OrigVReg << " NewVReg=" << NewVReg << " DefMask=" << PrintLaneMask(DefMask) << "\n"); + + const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo(); + const LaneBitmask FullMask = MF.getRegInfo().getMaxLaneMaskForVReg(OrigVReg); + + // Check if this is a partial lane redefinition + const bool IsPartialReload = (DefMask != FullMask); + + // Collect PHI operands for the specific reload lanes + SmallVector<MachineOperand, 8> PHIOperands; + + LLVM_DEBUG(dbgs() << "    Creating PHI for " << (IsPartialReload ?
"partial reload" : "full reload") + << " DefMask=" << PrintLaneMask(DefMask) << "\n"); + + // Get the definition block of NewVReg for dominance checks + MachineRegisterInfo &MRI = MF.getRegInfo(); + MachineInstr *NewDefMI = MRI.getVRegDef(NewVReg); + MachineBasicBlock *NewDefBB = NewDefMI->getParent(); + + for (MachineBasicBlock *PredMBB : JoinMBB.predecessors()) { + // Use dominance check instead of liveness: if NewDefBB dominates PredMBB, + // then NewVReg is available at the end of PredMBB + bool UseNewReg = MDT.dominates(NewDefBB, PredMBB); + + if (UseNewReg) { + // This is the reload path - use NewVReg (always full register for its class) + LLVM_DEBUG(dbgs() << " Pred BB#" << PredMBB->getNumber() + << " contributes NewVReg (reload path)\n"); + + PHIOperands.push_back(MachineOperand::CreateReg(NewVReg, /*isDef*/ false)); + PHIOperands.push_back(MachineOperand::CreateMBB(PredMBB)); + + } else { + // This is the original path - use OrigVReg with appropriate subregister + LLVM_DEBUG(dbgs() << " Pred BB#" << PredMBB->getNumber() + << " contributes OrigVReg (original path)\n"); + + if (IsPartialReload) { + // Partial case: z = PHI(y, BB1, x.sub2_3, BB0) + // Use DefMask to find which subreg of OrigVReg was redefined + unsigned SubIdx = getSubRegIndexForLaneMask(DefMask, &TRI); + PHIOperands.push_back(MachineOperand::CreateReg(OrigVReg, /*isDef*/ false, + /*isImp*/ false, /*isKill*/ false, + /*isDead*/ false, /*isUndef*/ false, + /*isEarlyClobber*/ false, SubIdx)); + } else { + // Full register case: z = PHI(y, BB1, x, BB0) + PHIOperands.push_back(MachineOperand::CreateReg(OrigVReg, /*isDef*/ false)); + } + PHIOperands.push_back(MachineOperand::CreateMBB(PredMBB)); + } + } + + // Create the single lane-specific PHI + if (!PHIOperands.empty()) { + const TargetRegisterClass *RC = MF.getRegInfo().getRegClass(NewVReg); + Register PHIVReg = MF.getRegInfo().createVirtualRegister(RC); + + auto PHINode = BuildMI(JoinMBB, JoinMBB.begin(), DebugLoc(), + TII->get(TargetOpcode::PHI), PHIVReg); + for (const MachineOperand &Op : PHIOperands) { + PHINode.add(Op); + } + + MachineInstr *PHI = PHINode.getInstr(); + LIS.InsertMachineInstrInMaps(*PHI); + + LLVM_DEBUG(dbgs() << " Created lane-specific PHI: "); + LLVM_DEBUG(PHI->print(dbgs())); + + return PHIVReg; + } + + return Register(); +} + +void MachineLaneSSAUpdater::rewriteDominatedUses(Register OrigVReg, + Register NewSSA, + LaneBitmask MaskToRewrite) { + LLVM_DEBUG(dbgs() << "MachineLaneSSAUpdater::rewriteDominatedUses OrigVReg=" << OrigVReg + << " NewSSA=" << NewSSA << " Mask=" << PrintLaneMask(MaskToRewrite) << "\n"); + + const TargetRegisterInfo &TRI = *MF.getSubtarget().getRegisterInfo(); + MachineRegisterInfo &MRI = MF.getRegInfo(); + + // Find the definition instruction for NewSSA + MachineInstr *DefMI = MRI.getVRegDef(NewSSA); + if (!DefMI) { + LLVM_DEBUG(dbgs() << " No definition found for NewSSA, skipping\n"); + return; + } + + MachineBasicBlock *DefBB = DefMI->getParent(); + const TargetRegisterClass *NewRC = MRI.getRegClass(NewSSA); + + LLVM_DEBUG(dbgs() << " Rewriting uses dominated by definition in BB#" << DefBB->getNumber() << ": "); + LLVM_DEBUG(DefMI->print(dbgs())); + + // Get OrigVReg's LiveInterval for reference + LiveInterval &OrigLI = LIS.getInterval(OrigVReg); + + // Iterate through all uses of OrigVReg + for (MachineOperand &MO : llvm::make_early_inc_range(MRI.use_operands(OrigVReg))) { + MachineInstr *UseMI = MO.getParent(); + + // Skip the definition instruction itself + if (UseMI == DefMI) + continue; + + // Check if this use 
is reached by our definition + if (!defReachesUse(DefMI, UseMI, MO)) + continue; + + // Get the lane mask for this operand + LaneBitmask OpMask = operandLaneMask(MO); + if ((OpMask & MaskToRewrite).none()) + continue; + + LLVM_DEBUG(dbgs() << "  Processing use with OpMask=" << PrintLaneMask(OpMask) << ": "); + LLVM_DEBUG(UseMI->print(dbgs())); + + const TargetRegisterClass *OpRC = MRI.getRegClass(OrigVReg); + + // Case 1: Exact match - direct replacement + if (OpMask == MaskToRewrite) { + // Check register class compatibility: + // if the operand uses a subreg, NewRC should match the subreg class; + // if the operand uses the full register, NewRC should match OpRC + const TargetRegisterClass *ExpectedRC = MO.getSubReg() != 0 ? TRI.getSubRegisterClass(OpRC, MO.getSubReg()) : OpRC; + bool Compatible = (ExpectedRC == NewRC); + + if (Compatible) { + LLVM_DEBUG(dbgs() << "    Exact match -> direct replacement\n"); + MO.setReg(NewSSA); + MO.setSubReg(0); // Clear subregister (NewSSA is a full register of NewRC) + + // Extend NewSSA's live interval to cover this use + SlotIndex UseIdx = LIS.getInstructionIndex(*UseMI).getRegSlot(); + LiveInterval &NewLI = LIS.getInterval(NewSSA); + LIS.extendToIndices(NewLI, {UseIdx}); + + continue; + } + + // Incompatible register classes with the same lane mask indicates corrupted MIR + llvm_unreachable("Incompatible register classes with same lane mask - invalid MIR"); + } + + // Case 2: Super/Mixed - the use needs more lanes than we're rewriting + if ((OpMask & ~MaskToRewrite).any()) { + LLVM_DEBUG(dbgs() << "    Super/Mixed case -> building REG_SEQUENCE\n"); + + SmallVector<LaneBitmask, 4> LanesToExtend; + SlotIndex RSIdx; + Register RSReg = buildRSForSuperUse(UseMI, MO, OrigVReg, NewSSA, MaskToRewrite, OrigLI, OpRC, RSIdx, LanesToExtend); + extendAt(OrigLI, RSIdx, LanesToExtend); + MO.setReg(RSReg); + MO.setSubReg(0); + + // Extend RSReg's live interval to cover this use + SlotIndex UseIdx; + if (UseMI->isPHI()) { + // For a PHI, the value must be live at the end of the predecessor block + unsigned OpIdx = UseMI->getOperandNo(&MO); + MachineBasicBlock *Pred = UseMI->getOperand(OpIdx + 1).getMBB(); + UseIdx = LIS.getMBBEndIdx(Pred); + } else { + UseIdx = LIS.getInstructionIndex(*UseMI).getRegSlot(); + } + LiveInterval &RSLI = LIS.getInterval(RSReg); + LIS.extendToIndices(RSLI, {UseIdx}); + + // Update the dead flag on the REG_SEQUENCE result + updateDeadFlags(RSReg); + + } else { + // Case 3: Subset - the use needs fewer lanes than NewSSA provides. + // Need to remap the subregister index from OrigVReg's register class to NewSSA's register class. + // + // Example: OrigVReg is vreg_128, we redefine sub2_3 (64-bit), the use accesses sub3 (32-bit) + //   MaskToRewrite = 0xF0  // sub2_3: lanes 4-7 in vreg_128 space + //   OpMask        = 0xC0  // sub3: lanes 6-7 in vreg_128 space + //   NewSSA is vreg_64, has lanes 0-3 (but represents lanes 4-7 of OrigVReg) + // + // Algorithm: Shift OpMask down by the bit position of MaskToRewrite's LSB to map + // from OrigVReg's lane space into NewSSA's lane space, then find the subreg index. + // + // Why this works: + //   1. MaskToRewrite is contiguous (comes from a subreg definition) + //   2. OpMask ⊆ MaskToRewrite (we're in the subset case by construction) + //   3. Lane masks use bit positions that correspond to actual lane indices + //   4.
Subreg boundaries are power-of-2 aligned in register class design + // + // Calculation: + // Shift = countTrailingZeros(MaskToRewrite) = 4 // How far "up" MaskToRewrite is + // NewMask = OpMask >> 4 = 0xC0 >> 4 = 0xC // Map to NewSSA's lane space + // 0xC corresponds to sub1 in vreg_64 ✓ + LLVM_DEBUG(dbgs() << " Subset case -> remapping subregister index\n"); + + // Find the bit offset of MaskToRewrite (position of its lowest set bit) + unsigned ShiftAmt = llvm::countr_zero(MaskToRewrite.getAsInteger()); + assert(ShiftAmt < 64 && "MaskToRewrite should have at least one bit set"); + + // Shift OpMask down into NewSSA's lane space + LaneBitmask NewMask = LaneBitmask(OpMask.getAsInteger() >> ShiftAmt); + + // Find the subregister index for NewMask in NewSSA's register class + unsigned NewSubReg = getSubRegIndexForLaneMask(NewMask, &TRI); + assert(NewSubReg && "Should find subreg index for remapped lanes"); + + LLVM_DEBUG(dbgs() << " Remapping subreg:\n" + << " OrigVReg lanes: OpMask=" << PrintLaneMask(OpMask) + << " MaskToRewrite=" << PrintLaneMask(MaskToRewrite) << "\n" + << " Shift amount: " << ShiftAmt << "\n" + << " NewSSA lanes: NewMask=" << PrintLaneMask(NewMask) + << " -> SubReg=" << TRI.getSubRegIndexName(NewSubReg) << "\n"); + + MO.setReg(NewSSA); + MO.setSubReg(NewSubReg); + + // Extend NewSSA's live interval to cover this use + SlotIndex UseIdx = LIS.getInstructionIndex(*UseMI).getRegSlot(); + LiveInterval &NewLI = LIS.getInterval(NewSSA); + LIS.extendToIndices(NewLI, {UseIdx}); + } + } + + LLVM_DEBUG(dbgs() << " Completed rewriting dominated uses\n"); +} + +//===----------------------------------------------------------------------===// +// Internal helpers +//===----------------------------------------------------------------------===// + +/// Return the VNInfo reaching this PHI operand along its predecessor edge. +VNInfo *MachineLaneSSAUpdater::incomingOnEdge(LiveInterval &LI, MachineInstr *Phi, + MachineOperand &PhiOp) { + unsigned OpIdx = Phi->getOperandNo(&PhiOp); + MachineBasicBlock *Pred = Phi->getOperand(OpIdx + 1).getMBB(); + SlotIndex EndB = LIS.getMBBEndIdx(Pred); + return LI.getVNInfoBefore(EndB); +} + +/// Check if \p DefMI's definition reaches \p UseMI's use operand. +/// During SSA reconstruction, LiveIntervals may not be complete yet, so we use +/// dominance-based checking rather than querying LiveInterval reachability. +bool MachineLaneSSAUpdater::defReachesUse(MachineInstr *DefMI, + MachineInstr *UseMI, + MachineOperand &UseOp) { + // For PHI uses, check if DefMI dominates the predecessor block + if (UseMI->isPHI()) { + unsigned OpIdx = UseMI->getOperandNo(&UseOp); + MachineBasicBlock *Pred = UseMI->getOperand(OpIdx + 1).getMBB(); + return MDT.dominates(DefMI->getParent(), Pred); + } + + // For same-block uses, check instruction order + if (UseMI->getParent() == DefMI->getParent()) { + SlotIndex DefIdx = LIS.getInstructionIndex(*DefMI); + SlotIndex UseIdx = LIS.getInstructionIndex(*UseMI); + return DefIdx < UseIdx; + } + + // For cross-block uses, check block dominance + return MDT.dominates(DefMI->getParent(), UseMI->getParent()); +} + +/// What lanes does this operand read? 
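+/// For a subregister use such as %0.sub1 this is the lane mask of sub1; for a
+/// full-register use it is the maximal lane mask of the register's class.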
+LaneBitmask MachineLaneSSAUpdater::operandLaneMask(const MachineOperand &MO) { + const TargetRegisterInfo &TRI = *MF.getSubtarget().getRegisterInfo(); + MachineRegisterInfo &MRI = MF.getRegInfo(); + + if (unsigned Sub = MO.getSubReg()) + return TRI.getSubRegIndexLaneMask(Sub); + return MRI.getMaxLaneMaskForVReg(MO.getReg()); +} + +/// Helper: Decompose a potentially non-contiguous lane mask into a vector of +/// subregister indices that together cover all lanes in the mask. +/// From getCoveringSubRegsForLaneMask in AMDGPUSSARAUtils.h (PR #156049). +/// +/// Key algorithm: Sort candidates by lane count (prefer larger subregs) to get +/// a minimal covering set with the largest possible subregisters. +/// +/// Example: For vreg_128 with LaneMask = 0x03 | 0x30 (sub0 + sub2, skipping sub1 and sub3) +/// Returns: [sub0_idx, sub2_idx] (two 32-bit subregs, not four 16-bit pieces) +static SmallVector<unsigned, 4> getCoveringSubRegsForLaneMask( LaneBitmask Mask, const TargetRegisterInfo *TRI, const TargetRegisterClass *RC) { + if (Mask.none()) + return {}; + + // Step 1: Collect all candidate subregisters that overlap with Mask + SmallVector<unsigned, 8> Candidates; + for (unsigned SubIdx = 1; SubIdx < TRI->getNumSubRegIndices(); ++SubIdx) { + // Check if this subreg index is valid for this register class + if (!TRI->getSubRegisterClass(RC, SubIdx)) + continue; + + LaneBitmask SubMask = TRI->getSubRegIndexLaneMask(SubIdx); + // Add if it covers any lanes we need + if ((SubMask & Mask).any()) { + Candidates.push_back(SubIdx); + } + } + + // Step 2: Sort by number of lanes (descending) to prefer larger subregisters + llvm::stable_sort(Candidates, [&](unsigned A, unsigned B) { + return TRI->getSubRegIndexLaneMask(A).getNumLanes() > TRI->getSubRegIndexLaneMask(B).getNumLanes(); + }); + + // Step 3: Greedily select subregisters, largest first + SmallVector<unsigned, 4> OptimalSubIndices; + for (unsigned SubIdx : Candidates) { + LaneBitmask SubMask = TRI->getSubRegIndexLaneMask(SubIdx); + // Only add if this subreg is fully contained in the remaining mask + if ((Mask & SubMask) == SubMask) { + OptimalSubIndices.push_back(SubIdx); + Mask &= ~SubMask; // Remove covered lanes + + if (Mask.none()) + break; // All lanes covered + } + } + + return OptimalSubIndices; +} + +/// Build a REG_SEQUENCE to materialize a super-reg/mixed-lane use. +/// Inserts at the PHI predecessor terminator (for PHI uses) or right before +/// UseMI otherwise. Returns the new full-width vreg, the RS index via OutIdx, +/// and the subrange lane masks that should be extended to that point. +Register MachineLaneSSAUpdater::buildRSForSuperUse(MachineInstr *UseMI, MachineOperand &MO, Register OldVR, Register NewVR, LaneBitmask MaskToRewrite, LiveInterval &LI, const TargetRegisterClass *OpRC, SlotIndex &OutIdx, SmallVectorImpl<LaneBitmask> &LanesToExtend) { + const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo(); + const TargetRegisterInfo &TRI = *MF.getSubtarget().getRegisterInfo(); + MachineRegisterInfo &MRI = MF.getRegInfo(); + + MachineBasicBlock *InsertBB = UseMI->getParent(); + MachineBasicBlock::iterator IP(UseMI); + SlotIndex QueryIdx; + + if (UseMI->isPHI()) { + unsigned OpIdx = UseMI->getOperandNo(&MO); + MachineBasicBlock *Pred = UseMI->getOperand(OpIdx + 1).getMBB(); + InsertBB = Pred; + IP = Pred->getFirstTerminator(); // ok if == end() + QueryIdx = LIS.getMBBEndIdx(Pred).getPrevSlot(); + } else { + QueryIdx = LIS.getInstructionIndex(*UseMI); + } + + Register Dest = MRI.createVirtualRegister(OpRC); + auto RS = BuildMI(*InsertBB, IP, (IP != InsertBB->end() ? IP->getDebugLoc() : DebugLoc()), TII.get(TargetOpcode::REG_SEQUENCE), Dest);
+ + // Determine what lanes the use needs + LaneBitmask UseMask = operandLaneMask(MO); + + // Decompose into lanes from NewVR (updated) and lanes from OldVR (unchanged) + LaneBitmask LanesFromNew = UseMask & MaskToRewrite; + LaneBitmask LanesFromOld = UseMask & ~MaskToRewrite; + + LLVM_DEBUG(dbgs() << "    Building REG_SEQUENCE: UseMask=" << PrintLaneMask(UseMask) << " LanesFromNew=" << PrintLaneMask(LanesFromNew) << " LanesFromOld=" << PrintLaneMask(LanesFromOld) << "\n"); + + SmallDenseSet<unsigned, 8> AddedSubIdxs; + + // Add a source for lanes from NewVR (updated lanes) + if (LanesFromNew.any()) { + unsigned SubIdx = getSubRegIndexForLaneMask(LanesFromNew, &TRI); + assert(SubIdx && "Failed to find subregister index for LanesFromNew"); + RS.addReg(NewVR, 0, 0).addImm(SubIdx); // NewVR is a full register, no subreg + AddedSubIdxs.insert(SubIdx); + LanesToExtend.push_back(LanesFromNew); + } + + // Add a source for lanes from OldVR (unchanged lanes). + // Handle both contiguous and non-contiguous lane masks. + // Non-contiguous example: Redefining only sub2 of vreg_128 leaves LanesFromOld = sub0+sub1+sub3 + if (LanesFromOld.any()) { + unsigned SubIdx = getSubRegIndexForLaneMask(LanesFromOld, &TRI); + + if (SubIdx) { + // Contiguous case: a single subregister covers all lanes + RS.addReg(OldVR, 0, SubIdx).addImm(SubIdx); // OldVR.subIdx + AddedSubIdxs.insert(SubIdx); + LanesToExtend.push_back(LanesFromOld); + } else { + // Non-contiguous case: decompose into multiple subregisters + const TargetRegisterClass *OldRC = MRI.getRegClass(OldVR); + SmallVector<unsigned, 4> CoveringSubRegs = getCoveringSubRegsForLaneMask(LanesFromOld, &TRI, OldRC); + + assert(!CoveringSubRegs.empty() && "Failed to decompose non-contiguous lane mask into covering subregs"); + + LLVM_DEBUG(dbgs() << "    Non-contiguous LanesFromOld=" << PrintLaneMask(LanesFromOld) << " decomposed into " << CoveringSubRegs.size() << " subregs\n"); + + // Add each covering subregister as a source to the REG_SEQUENCE + for (unsigned CoverSubIdx : CoveringSubRegs) { + LaneBitmask CoverMask = TRI.getSubRegIndexLaneMask(CoverSubIdx); + RS.addReg(OldVR, 0, CoverSubIdx).addImm(CoverSubIdx); // OldVR.CoverSubIdx + AddedSubIdxs.insert(CoverSubIdx); + LanesToExtend.push_back(CoverMask); + + LLVM_DEBUG(dbgs() << "      Added source: OldVR." << TRI.getSubRegIndexName(CoverSubIdx) << " covering " << PrintLaneMask(CoverMask) << "\n"); + } + } + } + + assert(!AddedSubIdxs.empty() && "REG_SEQUENCE must have at least one source"); + + LIS.InsertMachineInstrInMaps(*RS); + OutIdx = LIS.getInstructionIndex(*RS); + + // Create a live interval for the REG_SEQUENCE result + LIS.createAndComputeVirtRegInterval(Dest); + + // Extend the live intervals of all source registers to cover this REG_SEQUENCE. + // Use the register slot to ensure the live range covers the use + SlotIndex UseSlot = OutIdx.getRegSlot(); + for (MachineOperand &SrcMO : RS.getInstr()->uses()) { + if (SrcMO.isReg() && SrcMO.getReg().isVirtual()) { + Register SrcReg = SrcMO.getReg(); + LiveInterval &SrcLI = LIS.getInterval(SrcReg); + LIS.extendToIndices(SrcLI, {UseSlot}); + } + } + + LLVM_DEBUG(dbgs() << "    Built REG_SEQUENCE: "); + LLVM_DEBUG(RS->print(dbgs())); + + return Dest; +}
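+
+// For illustration (hypothetical vregs, AMDGPU-style subreg indices): if
+// sub2_sub3 of a 128-bit %old was redefined by %new and a use still reads all
+// of %old, the helper emits roughly
+//   %rs:vreg_128 = REG_SEQUENCE %new, %subreg.sub2_sub3, %old.sub0_sub1, %subreg.sub0_sub1
+// and the use is then rewritten to read %rs.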
+ +/// Extend LI (and only the specified subranges) at Idx. +void MachineLaneSSAUpdater::extendAt(LiveInterval &LI, SlotIndex Idx, ArrayRef<LaneBitmask> Lanes) { + SmallVector<SlotIndex, 4> P{Idx}; + LIS.extendToIndices(LI, P); + for (auto &SR : LI.subranges()) + for (LaneBitmask L : Lanes) + if (SR.LaneMask == L) + LIS.extendToIndices(SR, P); +} + +void MachineLaneSSAUpdater::updateDeadFlags(Register Reg) { + MachineRegisterInfo &MRI = MF.getRegInfo(); + LiveInterval &LI = LIS.getInterval(Reg); + MachineInstr *DefMI = MRI.getVRegDef(Reg); + if (!DefMI) + return; + + for (MachineOperand &MO : DefMI->defs()) { + if (MO.getReg() == Reg && MO.isDead()) { + // Check if this register is actually live (has uses) + if (!LI.empty() && !MRI.use_nodbg_empty(Reg)) { + MO.setIsDead(false); + LLVM_DEBUG(dbgs() << "  Cleared dead flag on " << Reg << "\n"); + } + } + } +} \ No newline at end of file diff --git a/llvm/unittests/CodeGen/CMakeLists.txt b/llvm/unittests/CodeGen/CMakeLists.txt index 22dbdaa4fa82e..e31c012e639bb 100644 --- a/llvm/unittests/CodeGen/CMakeLists.txt +++ b/llvm/unittests/CodeGen/CMakeLists.txt @@ -36,6 +36,8 @@ add_llvm_unittest(CodeGenTests MachineDomTreeUpdaterTest.cpp MachineInstrBundleIteratorTest.cpp MachineInstrTest.cpp + MachineLaneSSAUpdaterTest.cpp + MachineLaneSSAUpdaterSpillReloadTest.cpp MachineOperandTest.cpp RegAllocScoreTest.cpp PassManagerTest.cpp diff --git a/llvm/unittests/CodeGen/MachineLaneSSAUpdaterSpillReloadTest.cpp b/llvm/unittests/CodeGen/MachineLaneSSAUpdaterSpillReloadTest.cpp new file mode 100644 index 0000000000000..81a1f7703f2a8 --- /dev/null +++ b/llvm/unittests/CodeGen/MachineLaneSSAUpdaterSpillReloadTest.cpp @@ -0,0 +1,332 @@ +//===- MachineLaneSSAUpdaterSpillReloadTest.cpp - Spill/Reload tests -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Unit tests for MachineLaneSSAUpdater focusing on spill/reload scenarios. +// +// NOTE: This file is currently a placeholder for future spiller-specific tests. +// Analysis showed that repairSSAForNewDef() is sufficient for spill/reload +// scenarios - no special spill handling is needed. The spiller workflow is: +//   1. Insert the reload instruction before the use +//   2. Call repairSSAForNewDef(ReloadMI, SpilledReg) +//   3. Done - uses are rewritten and LiveIntervals are naturally pruned +// +// Future spiller-specific scenarios (if needed) can be added here. +// +//===----------------------------------------------------------------------===//
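+
+// A minimal sketch of that workflow (illustrative; FI, InsertPt, SpilledReg
+// and the reload opcode are stand-ins for real spiller state):
+//
+//   MachineInstr *ReloadMI =
+//       BuildMI(MBB, InsertPt, DebugLoc(), TII->get(ReloadOpc), SpilledReg)
+//           .addFrameIndex(FI);               // redefines SpilledReg (not SSA)
+//   MachineLaneSSAUpdater Updater(MF, LIS, MDT, *TRI);
+//   Register NewReg = Updater.repairSSAForNewDef(*ReloadMI, SpilledReg);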
+ +#include "llvm/CodeGen/MachineLaneSSAUpdater.h" +#include "llvm/CodeGen/LiveIntervals.h" +#include "llvm/CodeGen/MachineDominators.h" +#include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineInstrBuilder.h" +#include "llvm/CodeGen/MachineModuleInfo.h" +#include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/CodeGen/MIRParser/MIRParser.h" +#include "llvm/CodeGen/SlotIndexes.h" +#include "llvm/CodeGen/TargetInstrInfo.h" +#include "llvm/CodeGen/TargetRegisterInfo.h" +#include "llvm/CodeGen/TargetSubtargetInfo.h" +#include "llvm/IR/DebugLoc.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/InitializePasses.h" +#include "llvm/MC/LaneBitmask.h" +#include "llvm/MC/TargetRegistry.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/TargetParser/Triple.h" +#include "gtest/gtest.h" + +using namespace llvm; + +// TestPass needs to be defined outside the anonymous namespace for INITIALIZE_PASS +struct SpillReloadTestPass : public MachineFunctionPass { + static char ID; + SpillReloadTestPass() : MachineFunctionPass(ID) {} +}; + +char SpillReloadTestPass::ID = 0; + +namespace llvm { + void initializeSpillReloadTestPassPass(PassRegistry &); +} + +INITIALIZE_PASS(SpillReloadTestPass, "spillreloadtestpass", "spillreloadtestpass", false, false) + +namespace { + +void initLLVM() { + InitializeAllTargets(); + InitializeAllTargetMCs(); + InitializeAllAsmPrinters(); + InitializeAllAsmParsers(); + + PassRegistry *Registry = PassRegistry::getPassRegistry(); + initializeCore(*Registry); + initializeCodeGen(*Registry); +} + +// Helper to create a target machine for AMDGPU +std::unique_ptr<TargetMachine> createTargetMachine() { + Triple TT("amdgcn--"); + std::string Error; + const Target *T = TargetRegistry::lookupTarget("", TT, Error); + if (!T) + return nullptr; + + TargetOptions Options; + return std::unique_ptr<TargetMachine>( T->createTargetMachine(TT, "gfx900", "", Options, std::nullopt, std::nullopt, CodeGenOptLevel::Aggressive)); +} + +// Helper to parse a MIR string with the legacy PassManager +std::unique_ptr<Module> parseMIR(LLVMContext &Context, legacy::PassManagerBase &PM, std::unique_ptr<MIRParser> &MIR, const TargetMachine &TM, StringRef MIRCode) { + SMDiagnostic Diagnostic; + std::unique_ptr<MemoryBuffer> MBuffer = MemoryBuffer::getMemBuffer(MIRCode); + MIR = createMIRParser(std::move(MBuffer), Context); + if (!MIR) + return nullptr; + + std::unique_ptr<Module> M = MIR->parseIRModule(); + if (!M) + return nullptr; + + M->setDataLayout(TM.createDataLayout()); + + MachineModuleInfoWrapperPass *MMIWP = new MachineModuleInfoWrapperPass(&TM); + if (MIR->parseMachineFunctions(*M, MMIWP->getMMI())) + return nullptr; + PM.add(MMIWP); + + return M; +} + +template <typename AnalysisType> +struct SpillReloadTestPassT : public SpillReloadTestPass { + typedef std::function<void(MachineFunction &, AnalysisType &)> TestFx; + + SpillReloadTestPassT() { + // We should never call this but always use PM.add(new SpillReloadTestPass(...)) + abort(); + } + + SpillReloadTestPassT(TestFx T, bool ShouldPass) : T(T), ShouldPass(ShouldPass) { + initializeSpillReloadTestPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnMachineFunction(MachineFunction &MF) override { + AnalysisType &A = getAnalysis<AnalysisType>(); + T(MF, A); + bool VerifyResult = MF.verify(this, /* Banner=*/nullptr, /*OS=*/&llvm::errs(), /* AbortOnError=*/false);
EXPECT_EQ(VerifyResult, ShouldPass); + return true; + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesAll(); + AU.addRequired<AnalysisType>(); + AU.addPreserved<AnalysisType>(); + MachineFunctionPass::getAnalysisUsage(AU); + } + +private: + TestFx T; + bool ShouldPass; +}; + +template <typename AnalysisType> +static void doTest(StringRef MIRFunc, typename SpillReloadTestPassT<AnalysisType>::TestFx T, bool ShouldPass = true) { + initLLVM(); + + LLVMContext Context; + std::unique_ptr<TargetMachine> TM = createTargetMachine(); + if (!TM) + GTEST_SKIP() << "AMDGPU target not available"; + + legacy::PassManager PM; + std::unique_ptr<MIRParser> MIR; + std::unique_ptr<Module> M = parseMIR(Context, PM, MIR, *TM, MIRFunc); + ASSERT_TRUE(M); + + PM.add(new SpillReloadTestPassT<AnalysisType>(T, ShouldPass)); + + PM.run(*M); +} + +static void liveIntervalsTest(StringRef MIRFunc, SpillReloadTestPassT<LiveIntervalsWrapperPass>::TestFx T, bool ShouldPass = true) { + SmallString<512> S; + StringRef MIRString = (Twine(R"MIR( +--- | + define amdgpu_kernel void @func() { ret void } +... +--- +name: func +tracksRegLiveness: true +registers: + - { id: 0, class: vgpr_32 } +body: | + bb.0: +)MIR") + Twine(MIRFunc) + Twine("...\n")).toNullTerminatedStringRef(S); + + doTest<LiveIntervalsWrapperPass>(MIRString, T, ShouldPass); +} + +//===----------------------------------------------------------------------===// +// Spill/Reload Tests +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Test 1: Simple Linear Spill/Reload +//===----------------------------------------------------------------------===// +// +// This test demonstrates that repairSSAForNewDef() works for spill/reload +// scenarios without any special handling. +// +// CFG Structure: +//   BB0 (entry) +//    |  %0 = initial_def +//    | +//   BB1 (intermediate) +//    |  some operations +//    | +//   BB2 (reload & use) +//    |  %0 = RELOAD (simulated as V_MOV_B32) +//    |  use %0 +// +// Scenario: +// - %0 is defined in BB0 and used in BB2 +// - Insert a reload instruction in BB2 that redefines %0 (violating SSA) +// - Call repairSSAForNewDef() to fix the SSA violation +// - Verify that uses are rewritten and LiveIntervals are correct +// +// Expected Behavior: +// - Reload renamed to define a new register +// - Uses after the reload rewritten to the new register +// - OrigReg's LiveInterval naturally pruned to BB0 only +// - No PHI needed (linear CFG) +// +TEST(MachineLaneSSAUpdaterSpillReloadTest, SimpleLinearSpillReload) { + liveIntervalsTest(R"MIR( + %0:vgpr_32 = V_MOV_B32_e32 42, implicit $exec + S_BRANCH %bb.1 + + bb.1: + successors: %bb.2 + %1:vgpr_32 = V_MOV_B32_e32 100, implicit $exec + S_BRANCH %bb.2 + + bb.2: + %2:vgpr_32 = V_ADD_U32_e32 %0, %1, implicit $exec + S_ENDPGM 0 +)MIR", + [](MachineFunction &MF, LiveIntervalsWrapperPass &LISWrapper) { + LiveIntervals &LIS = LISWrapper.getLIS(); + MachineDominatorTree MDT(MF); + const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo(); + + // Verify we have 3 blocks as expected + ASSERT_EQ(MF.size(), 3u) << "Should have bb.0, bb.1, bb.2"; + + MachineBasicBlock *BB0 = MF.getBlockNumbered(0); + MachineBasicBlock *BB2 = MF.getBlockNumbered(2); + + // Find the %0 definition in BB0 (first instruction should be V_MOV_B32) + MachineInstr *OrigDefMI = &*BB0->begin(); + ASSERT_TRUE(OrigDefMI && OrigDefMI->getNumOperands() > 0); + Register OrigReg = OrigDefMI->getOperand(0).getReg(); + ASSERT_TRUE(OrigReg.isValid()) << "Should have valid original register %0"; + + // STEP 1: Insert a reload instruction in BB2 before the use. + // This creates a
second definition of %0, violating SSA + const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo(); + auto InsertPt = BB2->getFirstNonPHI(); + + // Get opcode and register from the existing V_MOV_B32 in BB0 + unsigned MovOpcode = OrigDefMI->getOpcode(); + Register ExecReg = OrigDefMI->getOperand(2).getReg(); + + // Insert reload: %0 = V_MOV_B32 999 (simulating load from stack) + // This violates SSA because %0 is already defined in BB0 + MachineInstr *ReloadMI = BuildMI(*BB2, InsertPt, DebugLoc(), + TII->get(MovOpcode), OrigReg) + .addImm(999) // Simulated reload value + .addReg(ExecReg, RegState::Implicit); + + // Set MachineFunction properties to allow SSA + MF.getProperties().set(MachineFunctionProperties::Property::IsSSA); + MF.getProperties().reset(MachineFunctionProperties::Property::NoPHIs); + + // STEP 2: Call repairSSAForNewDef to fix the SSA violation + // This will: + // - Rename the reload to define a new register + // - Rewrite uses dominated by the reload + // - Naturally prune OrigReg's LiveInterval via recomputation + MachineLaneSSAUpdater Updater(MF, LIS, MDT, *TRI); + Register ReloadReg = Updater.repairSSAForNewDef(*ReloadMI, OrigReg); + + // VERIFY RESULTS: + + // 1. ReloadReg should be valid and different from OrigReg + EXPECT_TRUE(ReloadReg.isValid()) << "Updater should return valid register"; + EXPECT_NE(ReloadReg, OrigReg) << "Reload register should be different from original"; + + // 2. ReloadMI should define the new ReloadReg (not OrigReg) + EXPECT_EQ(ReloadMI->getOperand(0).getReg(), ReloadReg) + << "ReloadMI should define new reload register"; + + // 3. Verify the ReloadReg has a valid LiveInterval + EXPECT_TRUE(LIS.hasInterval(ReloadReg)) + << "Reload register should have live interval"; + + // 4. No PHI should be inserted (linear CFG, reload dominates subsequent uses) + bool FoundPHI = false; + for (MachineBasicBlock &MBB : MF) { + for (MachineInstr &MI : MBB) { + if (MI.isPHI()) { + FoundPHI = true; + break; + } + } + } + EXPECT_FALSE(FoundPHI) + << "Linear CFG should not require PHI nodes"; + + // 5. Verify OrigReg's LiveInterval was naturally pruned + // It should only cover BB0 now (definition to end of block) + EXPECT_TRUE(LIS.hasInterval(OrigReg)) + << "Original register should still have live interval"; + const LiveInterval &OrigLI = LIS.getInterval(OrigReg); + + // The performSSARepair recomputation naturally prunes OrigReg + // because all uses in BB2 were rewritten to ReloadReg + SlotIndex OrigEnd = OrigLI.endIndex(); + + // OrigReg should not extend into BB2 where ReloadReg took over + SlotIndex BB2Start = LIS.getMBBStartIdx(BB2); + EXPECT_LE(OrigEnd, BB2Start) + << "Original register should not extend into BB2 after reload"; + }); +} + +} // anonymous namespace + diff --git a/llvm/unittests/CodeGen/MachineLaneSSAUpdaterTest.cpp b/llvm/unittests/CodeGen/MachineLaneSSAUpdaterTest.cpp new file mode 100644 index 0000000000000..172cbf33dce2f --- /dev/null +++ b/llvm/unittests/CodeGen/MachineLaneSSAUpdaterTest.cpp @@ -0,0 +1,2367 @@ +//===- MachineLaneSSAUpdaterTest.cpp - Unit tests for MachineLaneSSAUpdater -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/CodeGen/MachineLaneSSAUpdater.h" +#include "llvm/CodeGen/LiveIntervals.h" +#include "llvm/CodeGen/MachineDominators.h" +#include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineInstrBuilder.h" +#include "llvm/CodeGen/MachineModuleInfo.h" +#include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/CodeGen/MIRParser/MIRParser.h" +#include "llvm/CodeGen/SlotIndexes.h" +#include "llvm/CodeGen/TargetInstrInfo.h" +#include "llvm/CodeGen/TargetRegisterInfo.h" +#include "llvm/CodeGen/TargetSubtargetInfo.h" +#include "llvm/IR/DebugLoc.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/InitializePasses.h" +#include "llvm/MC/LaneBitmask.h" +#include "llvm/MC/TargetRegistry.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/TargetParser/Triple.h" +#include "gtest/gtest.h" + +#define DEBUG_TYPE "machine-lane-ssa-updater-test" + +using namespace llvm; + +// TestPass needs to be defined outside the anonymous namespace for INITIALIZE_PASS +struct TestPass : public MachineFunctionPass { + static char ID; + TestPass() : MachineFunctionPass(ID) {} +}; + +char TestPass::ID = 0; + +namespace llvm { + void initializeTestPassPass(PassRegistry &); +} + +INITIALIZE_PASS(TestPass, "testpass", "testpass", false, false) + +namespace { + +void initLLVM() { + InitializeAllTargets(); + InitializeAllTargetMCs(); + InitializeAllAsmPrinters(); + InitializeAllAsmParsers(); + + PassRegistry *Registry = PassRegistry::getPassRegistry(); + initializeCore(*Registry); + initializeCodeGen(*Registry); +} + +// Helper to create a target machine for AMDGPU +std::unique_ptr<TargetMachine> createTargetMachine() { + Triple TT("amdgcn--"); + std::string Error; + const Target *T = TargetRegistry::lookupTarget("", TT, Error); + if (!T) + return nullptr; + + TargetOptions Options; + return std::unique_ptr<TargetMachine>( T->createTargetMachine(TT, "gfx900", "", Options, std::nullopt, std::nullopt, CodeGenOptLevel::Aggressive)); +} + +// Helper to parse a MIR string with the legacy PassManager +std::unique_ptr<Module> parseMIR(LLVMContext &Context, legacy::PassManagerBase &PM, std::unique_ptr<MIRParser> &MIR, const TargetMachine &TM, StringRef MIRCode) { + SMDiagnostic Diagnostic; + std::unique_ptr<MemoryBuffer> MBuffer = MemoryBuffer::getMemBuffer(MIRCode); + MIR = createMIRParser(std::move(MBuffer), Context); + if (!MIR) + return nullptr; + + std::unique_ptr<Module> M = MIR->parseIRModule(); + if (!M) + return nullptr; + + M->setDataLayout(TM.createDataLayout()); + + MachineModuleInfoWrapperPass *MMIWP = new MachineModuleInfoWrapperPass(&TM); + if (MIR->parseMachineFunctions(*M, MMIWP->getMMI())) + return nullptr; + PM.add(MMIWP); + + return M; +} + +template <typename AnalysisType> +struct TestPassT : public TestPass { + typedef std::function<void(MachineFunction &, AnalysisType &)> TestFx; + + TestPassT() { + // We should never call this but always use PM.add(new TestPass(...)) + abort(); + } + + TestPassT(TestFx T, bool ShouldPass) : T(T), ShouldPass(ShouldPass) { + initializeTestPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnMachineFunction(MachineFunction &MF) override { + AnalysisType &A = getAnalysis<AnalysisType>(); + T(MF, A); + bool VerifyResult = MF.verify(this, /* Banner=*/nullptr, /*OS=*/&llvm::errs(), /* AbortOnError=*/false); + EXPECT_EQ(VerifyResult, ShouldPass); + return true;
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesAll();
+    AU.addRequired<AnalysisType>();
+    AU.addPreserved<AnalysisType>();
+    MachineFunctionPass::getAnalysisUsage(AU);
+  }
+
+private:
+  TestFx T;
+  bool ShouldPass;
+};
+
+template <typename AnalysisType>
+static void doTest(StringRef MIRFunc,
+                   typename TestPassT<AnalysisType>::TestFx T,
+                   bool ShouldPass = true) {
+  initLLVM();
+
+  LLVMContext Context;
+  std::unique_ptr<TargetMachine> TM = createTargetMachine();
+  if (!TM)
+    GTEST_SKIP() << "AMDGPU target not available";
+
+  legacy::PassManager PM;
+  std::unique_ptr<MIRParser> MIR;
+  std::unique_ptr<Module> M = parseMIR(Context, PM, MIR, *TM, MIRFunc);
+  ASSERT_TRUE(M);
+
+  PM.add(new TestPassT<AnalysisType>(T, ShouldPass));
+
+  PM.run(*M);
+}
+
+static void liveIntervalsTest(StringRef MIRFunc,
+                              TestPassT<LiveIntervalsWrapperPass>::TestFx T,
+                              bool ShouldPass = true) {
+  SmallString<512> S;
+  StringRef MIRString = (Twine(R"MIR(
+--- |
+  define amdgpu_kernel void @func() { ret void }
+...
+---
+name: func
+tracksRegLiveness: true
+registers:
+  - { id: 0, class: vgpr_32 }
+body: |
+  bb.0:
+)MIR") + Twine(MIRFunc) + Twine("...\n")).toNullTerminatedStringRef(S);
+
+  doTest<LiveIntervalsWrapperPass>(MIRString, T, ShouldPass);
+}
+
+//===----------------------------------------------------------------------===//
+// Test 1: Insert new definition and verify SSA repair with PHI insertion
+//===----------------------------------------------------------------------===//
+
+// Test basic PHI insertion and use rewriting in a diamond CFG
+//
+// CFG Structure:
+//        BB0 (entry)
+//         |
+//        BB1 (%1 = orig def)
+//        /  \
+//      BB2    BB3 (INSERT: %1 = new_def)
+//        \  /
+//        BB4 (use %1) → PHI expected
+//
+TEST(MachineLaneSSAUpdaterTest, NewDefInsertsPhiAndRewritesUses) {
+  liveIntervalsTest(R"MIR(
+    %0:vgpr_32 = V_MOV_B32_e32 0, implicit $exec
+    S_BRANCH %bb.1
+
+  bb.1:
+    successors: %bb.2, %bb.3
+    %1:vgpr_32 = V_ADD_U32_e32 %0, %0, implicit $exec
+    $sgpr0 = S_MOV_B32 0
+    $sgpr1 = S_MOV_B32 1
+    S_CMP_LG_U32 $sgpr0, $sgpr1, implicit-def $scc
+    S_CBRANCH_SCC1 %bb.3, implicit $scc
+
+  bb.2:
+    successors: %bb.4
+    %2:vgpr_32 = V_ADD_U32_e32 %1, %1, implicit $exec
+    S_BRANCH %bb.4
+
+  bb.3:
+    successors: %bb.4
+    S_NOP 0
+
+  bb.4:
+    %5:vgpr_32 = V_ADD_U32_e32 %1, %1, implicit $exec
+    S_ENDPGM 0
+)MIR",
+  [](MachineFunction &MF, LiveIntervalsWrapperPass &LISWrapper) {
+    LiveIntervals &LIS = LISWrapper.getLIS();
+    MachineDominatorTree MDT(MF);
+    const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
+    MachineRegisterInfo &MRI = MF.getRegInfo();
+
+    // Verify we have 5 blocks as expected
+    ASSERT_EQ(MF.size(), 5u) << "Should have bb.0, bb.1, bb.2, bb.3, bb.4";
+
+    MachineBasicBlock *BB1 = MF.getBlockNumbered(1);
+    MachineBasicBlock *BB3 = MF.getBlockNumbered(3);
+    MachineBasicBlock *BB4 = MF.getBlockNumbered(4);
+
+    // Get %1 which is defined in bb.1 (first non-PHI instruction)
+    MachineInstr *OrigDefMI = &*BB1->getFirstNonPHI();
+    ASSERT_TRUE(OrigDefMI) << "Could not find instruction in bb.1";
+    ASSERT_TRUE(OrigDefMI->getNumOperands() > 0) << "Instruction has no operands";
+
+    Register OrigReg = OrigDefMI->getOperand(0).getReg();
+    ASSERT_TRUE(OrigReg.isValid()) << "Could not get destination register %1 from bb.1";
+
+    // Count uses before SSA repair
+    unsigned UseCountBefore = 0;
+    for (const MachineInstr &MI : MRI.use_instructions(OrigReg)) {
+      (void)MI;
+      ++UseCountBefore;
+    }
+    ASSERT_GT(UseCountBefore, 0u) << "Original register should have uses";
+
+    // Find V_MOV_B32_e32 instruction in bb.0 to get its opcode
+    MachineBasicBlock *BB0 = MF.getBlockNumbered(0);
+    MachineInstr *MovInst = &*BB0->begin();
+    unsigned MovOpcode =
MovInst->getOpcode(); + Register ExecReg = MovInst->getOperand(2).getReg(); // Get EXEC register + + // Create a new definition in bb.3 that defines OrigReg (violating SSA) + // This creates a scenario where bb.4 needs a PHI to merge values from bb.2 and bb.3 + const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo(); + auto InsertPt = BB3->getFirstNonPHI(); + MachineInstr *NewDefMI = BuildMI(*BB3, InsertPt, DebugLoc(), + TII->get(MovOpcode), OrigReg) + .addImm(42) + .addReg(ExecReg, RegState::Implicit); + + // Set MachineFunction properties to allow PHIs and indicate SSA form + MF.getProperties().set(MachineFunctionProperties::Property::IsSSA); + MF.getProperties().reset(MachineFunctionProperties::Property::NoPHIs); + + // NOW TEST MachineLaneSSAUpdater: call repairSSAForNewDef + // Before: %1 defined in bb.1, used in bb.2 and bb.4 + // NewDefMI in bb.3 also defines %1 (violating SSA!) + // After repair: NewDefMI will define a new vreg, bb.4 gets PHI + MachineLaneSSAUpdater Updater(MF, LIS, MDT, *TRI); + Register NewReg = Updater.repairSSAForNewDef(*NewDefMI, OrigReg); + + // VERIFY RESULTS: + + // 1. NewReg should be valid and different from OrigReg + EXPECT_TRUE(NewReg.isValid()) << "Updater should create a new register"; + EXPECT_NE(NewReg, OrigReg) << "New register should be different from original"; + + // 2. NewDefMI should now define NewReg (not OrigReg) + EXPECT_EQ(NewDefMI->getOperand(0).getReg(), NewReg) << "NewDefMI should now define the new register"; + + + // 3. Check if PHI nodes were inserted in bb.4 + bool FoundPHI = false; + for (MachineInstr &MI : *BB4) { + if (MI.isPHI()) { + FoundPHI = true; + break; + } + } + EXPECT_TRUE(FoundPHI) << "SSA repair should have inserted PHI node in bb.4"; + + // 4. Verify LiveIntervals are still valid + EXPECT_TRUE(LIS.hasInterval(NewReg)) << "New register should have live interval"; + EXPECT_TRUE(LIS.hasInterval(OrigReg)) << "Original register should still have live interval"; + + // Verify the MachineFunction is still valid after SSA repair + EXPECT_TRUE(MF.verify(nullptr, /* Banner=*/nullptr, /*OS=*/nullptr, /* AbortOnError=*/false)) + << "MachineFunction verification failed after SSA repair"; + }); +} + +//===----------------------------------------------------------------------===// +// Test 2: Multiple PHI insertions in nested control flow +// +// CFG structure: +// bb.0 +// | +// bb.1 (%1 = original def) +// / \ +// bb.2 bb.3 +// | / \ +// | bb.4 bb.5 (new def inserted here) +// | \ / +// | bb.6 (needs first PHI: %X = PHI %1,bb.4 NewDef,bb.5) +// \ / +// bb.7 (needs second PHI: %Y = PHI %1,bb.2 %X,bb.6) +// | +// bb.8 (use) +// +// Key insight: IDF(bb.5) = {bb.6, bb.7} +// - bb.6 needs PHI because it's reachable from bb.4 (has %1) and bb.5 (has new def) +// - bb.7 needs PHI because it's reachable from bb.2 (has %1) and bb.6 (has PHI result %X) +// +// This truly requires TWO PHI nodes for proper SSA form! 
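+//
+// Expected shape after repair (a sketch only; the vreg names below are the
+// banner's %X/%Y placeholders, not registers guaranteed by the updater):
+//   bb.6:  %X:vgpr_32 = PHI %1, %bb.4, %NewReg, %bb.5
+//   bb.7:  %Y:vgpr_32 = PHI %1, %bb.2, %X, %bb.6
+//   uses of %1 dominated by bb.7 (including the one in bb.8) become %Y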
+//===----------------------------------------------------------------------===//
+
+TEST(MachineLaneSSAUpdaterTest, MultiplePhiInsertion) {
+  liveIntervalsTest(R"MIR(
+    %0:vgpr_32 = V_MOV_B32_e32 0, implicit $exec
+    S_BRANCH %bb.1
+
+  bb.1:
+    successors: %bb.2, %bb.3
+    %1:vgpr_32 = V_ADD_U32_e32 %0, %0, implicit $exec
+    $sgpr0 = S_MOV_B32 0
+    $sgpr1 = S_MOV_B32 1
+    S_CMP_LG_U32 $sgpr0, $sgpr1, implicit-def $scc
+    S_CBRANCH_SCC1 %bb.3, implicit $scc
+
+  bb.2:
+    successors: %bb.7
+    %2:vgpr_32 = V_ADD_U32_e32 %1, %1, implicit $exec
+    S_BRANCH %bb.7
+
+  bb.3:
+    successors: %bb.4, %bb.5
+    $sgpr2 = S_MOV_B32 0
+    $sgpr3 = S_MOV_B32 1
+    S_CMP_LG_U32 $sgpr2, $sgpr3, implicit-def $scc
+    S_CBRANCH_SCC1 %bb.5, implicit $scc
+
+  bb.4:
+    successors: %bb.6
+    %3:vgpr_32 = V_ADD_U32_e32 %1, %1, implicit $exec
+    S_BRANCH %bb.6
+
+  bb.5:
+    successors: %bb.6
+    S_NOP 0
+
+  bb.6:
+    successors: %bb.7
+    %4:vgpr_32 = V_SUB_U32_e32 %1, %1, implicit $exec
+
+  bb.7:
+    successors: %bb.8
+    %5:vgpr_32 = V_AND_B32_e32 %1, %1, implicit $exec
+    S_BRANCH %bb.8
+
+  bb.8:
+    %6:vgpr_32 = V_OR_B32_e32 %1, %1, implicit $exec
+    S_ENDPGM 0
+)MIR",
+  [](MachineFunction &MF, LiveIntervalsWrapperPass &LISWrapper) {
+    LiveIntervals &LIS = LISWrapper.getLIS();
+    MachineDominatorTree MDT(MF);
+    const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
+    MachineRegisterInfo &MRI = MF.getRegInfo();
+
+    // Verify we have the expected number of blocks
+    ASSERT_EQ(MF.size(), 9u) << "Should have bb.0 through bb.8";
+
+    MachineBasicBlock *BB1 = MF.getBlockNumbered(1);
+    MachineBasicBlock *BB5 = MF.getBlockNumbered(5);
+    MachineBasicBlock *BB6 = MF.getBlockNumbered(6);
+    MachineBasicBlock *BB7 = MF.getBlockNumbered(7);
+
+    // Get %1 which is defined in bb.1
+    MachineInstr *OrigDefMI = &*BB1->getFirstNonPHI();
+    Register OrigReg = OrigDefMI->getOperand(0).getReg();
+    ASSERT_TRUE(OrigReg.isValid()) << "Could not get original register";
+
+    // Count uses of %1 before SSA repair
+    unsigned UseCountBefore = 0;
+    for (const MachineInstr &MI : MRI.use_instructions(OrigReg)) {
+      (void)MI;
+      ++UseCountBefore;
+    }
+    ASSERT_GT(UseCountBefore, 0u) << "Original register should have uses";
+    LLVM_DEBUG(dbgs() << "Original register has " << UseCountBefore << " uses before SSA repair\n");
+
+    // Get V_MOV opcode from bb.0
+    MachineBasicBlock *BB0 = MF.getBlockNumbered(0);
+    MachineInstr *MovInst = &*BB0->begin();
+    unsigned MovOpcode = MovInst->getOpcode();
+    Register ExecReg = MovInst->getOperand(2).getReg();
+
+    // Insert new definition in bb.5 that defines OrigReg (violating SSA)
+    const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
+    auto InsertPt = BB5->getFirstNonPHI();
+    MachineInstr *NewDefMI = BuildMI(*BB5, InsertPt, DebugLoc(),
+                                     TII->get(MovOpcode), OrigReg)
+                                 .addImm(100)
+                                 .addReg(ExecReg, RegState::Implicit);
+
+    // Set MachineFunction properties
+    MF.getProperties().set(MachineFunctionProperties::Property::IsSSA);
+    MF.getProperties().reset(MachineFunctionProperties::Property::NoPHIs);
+
+    // Call MachineLaneSSAUpdater
+    MachineLaneSSAUpdater Updater(MF, LIS, MDT, *TRI);
+    Register NewReg = Updater.repairSSAForNewDef(*NewDefMI, OrigReg);
+
+    EXPECT_TRUE(NewReg.isValid()) << "Updater should create a new register";
+    EXPECT_NE(NewReg, OrigReg) << "New register should be different from original";
+    EXPECT_EQ(NewDefMI->getOperand(0).getReg(), NewReg) << "NewDefMI should now define the new register";
+
+    // Count PHI nodes inserted and track their locations
+    unsigned PHICount = 0;
+    std::map<int, unsigned> PHIsPerBlock;
+    for
(MachineBasicBlock &MBB : MF) { + unsigned BlockPHIs = 0; + for (MachineInstr &MI : MBB) { + if (MI.isPHI()) { + ++PHICount; + ++BlockPHIs; + LLVM_DEBUG(dbgs() << "Found PHI in BB#" << MBB.getNumber() << ": "); + LLVM_DEBUG(MI.print(dbgs())); + } + } + if (BlockPHIs > 0) { + PHIsPerBlock[MBB.getNumber()] = BlockPHIs; + } + } + + LLVM_DEBUG(dbgs() << "Total PHI nodes inserted: " << PHICount << "\n"); + + // Check for first PHI in bb.6 (joins bb.4 and bb.5) + bool FoundPHIInBB6 = false; + for (MachineInstr &MI : *BB6) { + if (MI.isPHI()) { + FoundPHIInBB6 = true; + LLVM_DEBUG(dbgs() << "First PHI in bb.6: "); + LLVM_DEBUG(MI.print(dbgs())); + // Verify it has 2 incoming values (4 operands: 2 x (reg, mbb)) + unsigned NumIncoming = (MI.getNumOperands() - 1) / 2; + EXPECT_EQ(NumIncoming, 2u) << "First PHI in bb.6 should have 2 incoming values (from bb.4 and bb.5)"; + break; + } + } + EXPECT_TRUE(FoundPHIInBB6) << "Should have first PHI in bb.6 (joins bb.4 with %1 and bb.5 with new def)"; + + // Check for second PHI in bb.7 (joins bb.2 and bb.6) + bool FoundPHIInBB7 = false; + for (MachineInstr &MI : *BB7) { + if (MI.isPHI()) { + FoundPHIInBB7 = true; + LLVM_DEBUG(dbgs() << "Second PHI in bb.7: "); + LLVM_DEBUG(MI.print(dbgs())); + // Verify it has 2 incoming values (4 operands: 2 x (reg, mbb)) + unsigned NumIncoming = (MI.getNumOperands() - 1) / 2; + EXPECT_EQ(NumIncoming, 2u) << "Second PHI in bb.7 should have 2 incoming values (from bb.2 with %1 and bb.6 with first PHI result)"; + break; + } + } + EXPECT_TRUE(FoundPHIInBB7) << "Should have second PHI in bb.7 (joins bb.2 with %1 and bb.6 with first PHI)"; + + // Should have exactly 2 PHIs + EXPECT_EQ(PHICount, 2u) << "Should have inserted exactly TWO PHI nodes (one at bb.6, one at bb.7)"; + + // Verify LiveIntervals are valid + EXPECT_TRUE(LIS.hasInterval(NewReg)) << "New register should have live interval"; + EXPECT_TRUE(LIS.hasInterval(OrigReg)) << "Original register should have live interval"; + + // Verify the MachineFunction is still valid + EXPECT_TRUE(MF.verify(nullptr, nullptr, nullptr, false)) + << "MachineFunction verification failed"; + }); +} + +//===----------------------------------------------------------------------===// +// Test 3: Subregister lane tracking with partial register updates +// +// This tests the "LaneAware" part of MachineLaneSSAUpdater. +// +// Scenario: +// - Start with a 64-bit register %3 (has sub0 and sub1 lanes) +// - Insert a new definition that only updates sub0 (lower 32 bits) +// - The SSA updater should: +// 1. Track that only sub0 lane is modified (not sub1) +// 2. Create PHI that merges only the sub0 lane +// 3. Preserve the original sub1 lane +// 4. 
Generate REG_SEQUENCE to compose full register from PHI+unchanged lanes
+//
+// CFG Structure:
+//        BB0 (entry)
+//         |
+//        BB1 (%3:vreg_64 = REG_SEQUENCE of %1:sub0, %2:sub1)
+//        /  \
+//      BB2    BB3 (INSERT: %3.sub0 = new_def)
+//       |      |
+//      use   (no use)
+//      sub0
+//        \  /
+//        BB4 (use sub0 + sub1) → PHI for sub0 lane only
+//         |
+//        BB5 (use full %3) → REG_SEQUENCE to compose full reg from PHI result + unchanged sub1
+//
+// Expected behavior:
+// - PHI in BB4 merges only sub0 lane (changed)
+// - sub1 lane flows unchanged through the diamond
+// - REG_SEQUENCE in BB5 composes full 64-bit from (PHI_sub0, original_sub1)
+//===----------------------------------------------------------------------===//
+
+TEST(MachineLaneSSAUpdaterTest, SubregisterLaneTracking) {
+  liveIntervalsTest(R"MIR(
+    %0:vgpr_32 = V_MOV_B32_e32 0, implicit $exec
+    S_BRANCH %bb.1
+
+  bb.1:
+    successors: %bb.2, %bb.3
+    ; Create vregs in order: %1, %2, %3
+    %1:vgpr_32 = V_MOV_B32_e32 10, implicit $exec
+    %2:vgpr_32 = V_MOV_B32_e32 20, implicit $exec
+    %3:vreg_64 = REG_SEQUENCE %1, %subreg.sub0, %2, %subreg.sub1
+    $sgpr0 = S_MOV_B32 0
+    $sgpr1 = S_MOV_B32 1
+    S_CMP_LG_U32 $sgpr0, $sgpr1, implicit-def $scc
+    S_CBRANCH_SCC1 %bb.3, implicit $scc
+
+  bb.2:
+    successors: %bb.4
+    ; Use sub0 lane only
+    %4:vgpr_32 = V_ADD_U32_e32 %3.sub0, %3.sub0, implicit $exec
+    S_BRANCH %bb.4
+
+  bb.3:
+    successors: %bb.4
+    S_NOP 0
+
+  bb.4:
+    successors: %bb.5
+    ; Use both sub0 and sub1 lanes separately
+    %5:vgpr_32 = V_ADD_U32_e32 %3.sub0, %3.sub1, implicit $exec
+    S_BRANCH %bb.5
+
+  bb.5:
+    ; Use full 64-bit register (tests REG_SEQUENCE path after PHI)
+    %6:vreg_64 = V_LSHLREV_B64_e64 0, %3, implicit $exec
+    S_ENDPGM 0
+)MIR",
+  [](MachineFunction &MF, LiveIntervalsWrapperPass &LISWrapper) {
+    LiveIntervals &LIS = LISWrapper.getLIS();
+    MachineDominatorTree MDT(MF);
+    const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
+    MachineRegisterInfo &MRI = MF.getRegInfo();
+
+    // Verify we have the expected number of blocks
+    ASSERT_EQ(MF.size(), 6u) << "Should have bb.0 through bb.5";
+
+    MachineBasicBlock *BB3 = MF.getBlockNumbered(3);
+
+    // Get the 64-bit register %3 (vreg_64) from the MIR
+    Register Reg64 = Register::index2VirtReg(3);
+    ASSERT_TRUE(Reg64.isValid()) << "Register %3 should be valid";
+
+    const TargetRegisterClass *RC64 = MRI.getRegClass(Reg64);
+    ASSERT_EQ(TRI->getRegSizeInBits(*RC64), 64u) << "Register %3 should be 64-bit";
+    LLVM_DEBUG(dbgs() << "Using 64-bit register: %" << Reg64.virtRegIndex()
+                      << " (raw: " << Reg64 << ")\n");
+
+    // Verify it has subranges for lane tracking
+    ASSERT_TRUE(LIS.hasInterval(Reg64)) << "Register should have live interval";
+    LiveInterval &LI = LIS.getInterval(Reg64);
+    if (LI.hasSubRanges()) {
+      LLVM_DEBUG(dbgs() << "Register has subranges (lane tracking active)\n");
+      for (const LiveInterval::SubRange &SR : LI.subranges()) {
+        LLVM_DEBUG(dbgs() << "  Lane mask: " << PrintLaneMask(SR.LaneMask) << "\n");
+      }
+    } else {
+      LLVM_DEBUG(dbgs() << "Warning: Register does not have subranges\n");
+    }
+
+    // Find the subreg index for a 32-bit subreg of the 64-bit register
+    unsigned Sub0Idx = 0;
+    for (unsigned Idx = 1, E = TRI->getNumSubRegIndices(); Idx < E; ++Idx) {
+      const TargetRegisterClass *SubRC = TRI->getSubRegisterClass(RC64, Idx);
+      if (SubRC && TRI->getRegSizeInBits(*SubRC) == 32) {
+        Sub0Idx = Idx;
+        break;
+      }
+    }
+    ASSERT_NE(Sub0Idx, 0u) << "Could not find 32-bit subregister index";
+    LaneBitmask Sub0Mask = TRI->getSubRegIndexLaneMask(Sub0Idx);
+    LLVM_DEBUG(dbgs() << "Sub0 index=" <<
Sub0Idx << " (" << TRI->getSubRegIndexName(Sub0Idx) + << "), mask=" << PrintLaneMask(Sub0Mask) << "\n"); + + // Insert new definition in bb.3 that defines Reg64.sub0 (partial update, violating SSA) + // Use V_MOV with immediate - no liveness dependencies + // It's the caller's responsibility to ensure source operands are valid + const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo(); + auto InsertPt = BB3->getFirstNonPHI(); + + // Get V_MOV opcode and EXEC register from bb.0 + MachineBasicBlock *BB0 = MF.getBlockNumbered(0); + MachineInstr *MovInst = &*BB0->begin(); + unsigned MovOpcode = MovInst->getOpcode(); + Register ExecReg = MovInst->getOperand(2).getReg(); + + // Create a 32-bit temporary register + Register TempReg = MRI.createVirtualRegister(TRI->getSubRegisterClass(RC64, Sub0Idx)); + + // Insert both instructions first (V_MOV and COPY) + MachineInstr *TempMI = BuildMI(*BB3, InsertPt, DebugLoc(), TII->get(MovOpcode), TempReg) + .addImm(99) + .addReg(ExecReg, RegState::Implicit); + + MachineInstr *NewDefMI = BuildMI(*BB3, InsertPt, DebugLoc(), + TII->get(TargetOpcode::COPY)) + .addReg(Reg64, RegState::Define, Sub0Idx) // %3.sub0 = (violates SSA) + .addReg(TempReg); // COPY from temp + + // Caller's responsibility: index instructions and create live intervals + // Do this AFTER inserting both instructions so uses are visible + LIS.InsertMachineInstrInMaps(*TempMI); + LIS.InsertMachineInstrInMaps(*NewDefMI); + LIS.createAndComputeVirtRegInterval(TempReg); + + // Set MachineFunction properties + MF.getProperties().set(MachineFunctionProperties::Property::IsSSA); + MF.getProperties().reset(MachineFunctionProperties::Property::NoPHIs); + + // Call MachineLaneSSAUpdater to repair the SSA violation + // This should create a new vreg for the subreg def and insert lane-aware PHIs + MachineLaneSSAUpdater Updater(MF, LIS, MDT, *TRI); + Register NewReg = Updater.repairSSAForNewDef(*NewDefMI, Reg64); + + LLVM_DEBUG(dbgs() << "SSA repair created new register: %" << NewReg.virtRegIndex() << " (raw: " << NewReg << ")\n"); + + // VERIFY RESULTS: + + // 1. NewReg should be a 32-bit register (for sub0), not 64-bit + EXPECT_TRUE(NewReg.isValid()) << "Updater should create a new register"; + EXPECT_NE(NewReg, Reg64) << "New register should be different from original"; + + const TargetRegisterClass *NewRC = MRI.getRegClass(NewReg); + EXPECT_EQ(TRI->getRegSizeInBits(*NewRC), 32u) << "New register should be 32-bit (subreg class)"; + + // 2. NewDefMI should now define NewReg (not Reg64.sub0) + EXPECT_EQ(NewDefMI->getOperand(0).getReg(), NewReg) << "NewDefMI should now define new 32-bit register"; + EXPECT_EQ(NewDefMI->getOperand(0).getSubReg(), 0u) << "NewDefMI should no longer have subreg index"; + + // 3. Verify PHIs were inserted where needed + MachineBasicBlock *BB4 = MF.getBlockNumbered(4); + bool FoundPHI = false; + for (MachineInstr &MI : *BB4) { + if (MI.isPHI()) { + FoundPHI = true; + LLVM_DEBUG(dbgs() << "Found PHI in bb.4: "); + LLVM_DEBUG(MI.print(dbgs())); + break; + } + } + EXPECT_TRUE(FoundPHI) << "Should have inserted PHI for sub0 lane in bb.4"; + + // 4. 
Verify LiveIntervals are valid + EXPECT_TRUE(LIS.hasInterval(NewReg)) << "New register should have live interval"; + + // Verify the MachineFunction is still valid + EXPECT_TRUE(MF.verify(nullptr, nullptr, nullptr, false)) + << "MachineFunction verification failed"; + }); +} + +//===----------------------------------------------------------------------===// +// Test 4: Subreg def → Full register PHI (REG_SEQUENCE before PHI) +// +// This tests the critical case where: +// - Input MIR has a PHI that expects full 64-bit register from both paths +// - We insert a subreg definition (X.sub0) on one path +// - The updater must build a REG_SEQUENCE before the PHI to combine: +// NewSubreg (sub0) + OriginalReg.sub1 → FullReg for PHI +// +// CFG: +// bb.0 (entry) +// | +// bb.1 (X=1, full 64-bit def) +// / \ +// bb.2 bb.3 +// (Y=2) / \ +// | bb.4 bb.5 (NEW DEF: X.sub0 = 3) ← inserted by test +// | \ / +// | bb.6 (first join: bb.4 + bb.5, may need REG_SEQUENCE) +// | / +// \ / +// bb.7 (second join: PHI Z = PHI(Y, bb.2, X, bb.6)) ← already in input MIR +// | +// bb.8 (use Z) +// +// Expected: REG_SEQUENCE in bb.6 before branching to bb.7 +//===----------------------------------------------------------------------===// + +TEST(MachineLaneSSAUpdaterTest, SubregDefToFullRegPHI) { + liveIntervalsTest(R"MIR( + %0:vgpr_32 = V_MOV_B32_e32 0, implicit $exec + S_BRANCH %bb.1 + + bb.1: + successors: %bb.2, %bb.3 + ; X = 1 (full 64-bit register) + %1:vgpr_32 = V_MOV_B32_e32 10, implicit $exec + %2:vgpr_32 = V_MOV_B32_e32 11, implicit $exec + %3:vreg_64 = REG_SEQUENCE %1, %subreg.sub0, %2, %subreg.sub1 + $sgpr0 = S_MOV_B32 0 + $sgpr1 = S_MOV_B32 1 + S_CMP_LG_U32 $sgpr0, $sgpr1, implicit-def $scc + S_CBRANCH_SCC1 %bb.3, implicit $scc + + bb.2: + successors: %bb.7 + ; Y = 2 (full 64-bit register, different from X) + %4:vgpr_32 = V_MOV_B32_e32 20, implicit $exec + %5:vgpr_32 = V_MOV_B32_e32 21, implicit $exec + %6:vreg_64 = REG_SEQUENCE %4, %subreg.sub0, %5, %subreg.sub1 + S_BRANCH %bb.7 + + bb.3: + successors: %bb.4, %bb.5 + $sgpr2 = S_MOV_B32 0 + $sgpr3 = S_MOV_B32 1 + S_CMP_LG_U32 $sgpr2, $sgpr3, implicit-def $scc + S_CBRANCH_SCC1 %bb.5, implicit $scc + + bb.4: + successors: %bb.6 + S_NOP 0 + S_BRANCH %bb.6 + + bb.5: + successors: %bb.6 + ; New def will be inserted here: X.sub0 = 3 + S_NOP 0 + + bb.6: + successors: %bb.7 + S_BRANCH %bb.7 + + bb.7: + ; PHI already in input MIR, expects full 64-bit from both paths + %7:vreg_64 = PHI %6:vreg_64, %bb.2, %3:vreg_64, %bb.6 + S_BRANCH %bb.8 + + bb.8: + ; Use Z + %8:vreg_64 = V_LSHLREV_B64_e64 0, %7, implicit $exec + S_ENDPGM 0 +)MIR", + [](MachineFunction &MF, LiveIntervalsWrapperPass &LISWrapper) { + LiveIntervals &LIS = LISWrapper.getLIS(); + MachineDominatorTree MDT(MF); + const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo(); + MachineRegisterInfo &MRI = MF.getRegInfo(); + + ASSERT_EQ(MF.size(), 9u) << "Should have bb.0 through bb.8"; + + MachineBasicBlock *BB5 = MF.getBlockNumbered(5); // New def inserted here + MachineBasicBlock *BB6 = MF.getBlockNumbered(6); // First join (bb.4 + bb.5) + MachineBasicBlock *BB7 = MF.getBlockNumbered(7); // PHI block (bb.2 + bb.6) + + // Get register X (%3, the 64-bit register from bb.1) + Register RegX = Register::index2VirtReg(3); + ASSERT_TRUE(RegX.isValid()) << "Register %3 (X) should be valid"; + + const TargetRegisterClass *RC64 = MRI.getRegClass(RegX); + ASSERT_EQ(TRI->getRegSizeInBits(*RC64), 64u) << "Register X should be 64-bit"; + + // Find sub0 index (32-bit subregister) + unsigned Sub0Idx = 0; + for 
(unsigned Idx = 1, E = TRI->getNumSubRegIndices(); Idx < E; ++Idx) {
+      const TargetRegisterClass *SubRC = TRI->getSubRegisterClass(RC64, Idx);
+      if (SubRC && TRI->getRegSizeInBits(*SubRC) == 32) {
+        Sub0Idx = Idx;
+        break;
+      }
+    }
+    ASSERT_NE(Sub0Idx, 0u) << "Could not find 32-bit subregister index";
+
+    // Insert new definition in bb.5: X.sub0 = 3
+    const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
+    auto InsertPt = BB5->getFirstNonPHI();
+
+    // Get V_MOV opcode and EXEC register
+    MachineBasicBlock *BB0 = MF.getBlockNumbered(0);
+    MachineInstr *MovInst = &*BB0->begin();
+    unsigned MovOpcode = MovInst->getOpcode();
+    Register ExecReg = MovInst->getOperand(2).getReg();
+
+    // Create temporary register
+    Register TempReg = MRI.createVirtualRegister(TRI->getSubRegisterClass(RC64, Sub0Idx));
+
+    MachineInstr *TempMI = BuildMI(*BB5, InsertPt, DebugLoc(), TII->get(MovOpcode), TempReg)
+                               .addImm(30)
+                               .addReg(ExecReg, RegState::Implicit);
+
+    MachineInstr *NewDefMI = BuildMI(*BB5, InsertPt, DebugLoc(),
+                                     TII->get(TargetOpcode::COPY))
+                                 .addReg(RegX, RegState::Define, Sub0Idx) // X.sub0 =
+                                 .addReg(TempReg);
+
+    // Index instructions and create live interval for temp
+    LIS.InsertMachineInstrInMaps(*TempMI);
+    LIS.InsertMachineInstrInMaps(*NewDefMI);
+    LIS.createAndComputeVirtRegInterval(TempReg);
+
+    // Set MachineFunction properties
+    MF.getProperties().set(MachineFunctionProperties::Property::IsSSA);
+    MF.getProperties().reset(MachineFunctionProperties::Property::NoPHIs);
+
+    // Call SSA updater
+    MachineLaneSSAUpdater Updater(MF, LIS, MDT, *TRI);
+    Register NewReg = Updater.repairSSAForNewDef(*NewDefMI, RegX);
+
+    LLVM_DEBUG(dbgs() << "SSA repair created new register: %" << NewReg.virtRegIndex()
+                      << " (raw: " << NewReg << ")\n");
+
+    // VERIFY RESULTS:
+
+    // 1. New register should be 32-bit (subreg class)
+    EXPECT_TRUE(NewReg.isValid());
+    EXPECT_NE(NewReg, RegX);
+    const TargetRegisterClass *NewRC = MRI.getRegClass(NewReg);
+    EXPECT_EQ(TRI->getRegSizeInBits(*NewRC), 32u) << "New register should be 32-bit";
+
+    // 2. NewDefMI should now define NewReg without subreg index
+    EXPECT_EQ(NewDefMI->getOperand(0).getReg(), NewReg);
+    EXPECT_EQ(NewDefMI->getOperand(0).getSubReg(), 0u);
+
+    // 3. Check the existing PHI in bb.7
+    bool FoundPHI = false;
+    Register PHIReg;
+    for (MachineInstr &MI : *BB7) {
+      if (MI.isPHI()) {
+        FoundPHI = true;
+        PHIReg = MI.getOperand(0).getReg();
+        LLVM_DEBUG(dbgs() << "PHI in bb.7 after SSA repair: ");
+        LLVM_DEBUG(MI.print(dbgs()));
+        break;
+      }
+    }
+    ASSERT_TRUE(FoundPHI) << "Should have PHI in bb.7 (from input MIR)";
+
+    // 4. CRITICAL: Check for REG_SEQUENCE in bb.6 (first join, before branch to PHI)
+    //    The updater must build REG_SEQUENCE to provide full register to the PHI
+    bool FoundREGSEQ = false;
+    for (MachineInstr &MI : *BB6) {
+      if (MI.getOpcode() == TargetOpcode::REG_SEQUENCE) {
+        FoundREGSEQ = true;
+        LLVM_DEBUG(dbgs() << "Found REG_SEQUENCE in bb.6: ");
+        LLVM_DEBUG(MI.print(dbgs()));
+
+        // Should combine new sub0 with original sub1
+        EXPECT_GE(MI.getNumOperands(), 5u) << "REG_SEQUENCE should have result + 2 source pairs";
+        break;
+      }
+    }
+    EXPECT_TRUE(FoundREGSEQ) << "Should have built REG_SEQUENCE in bb.6 to provide full register to PHI in bb.7";
+
+    // 5.
Verify LiveIntervals + EXPECT_TRUE(LIS.hasInterval(NewReg)); + EXPECT_TRUE(LIS.hasInterval(PHIReg)); + + // Verify the MachineFunction is still valid + EXPECT_TRUE(MF.verify(nullptr, nullptr, nullptr, false)) + << "MachineFunction verification failed"; + }); +} + +//===----------------------------------------------------------------------===// +// Test 5: Loop with new def in loop body (PHI in loop header) +// +// This tests SSA repair when a new definition is inserted inside a loop, +// requiring a PHI node in the loop header to merge: +// - Entry path: original value from before the loop +// - Back edge: new value from loop body +// +// CFG: +// bb.0 (entry, X = 1) +// | +// v +// bb.1 (loop header) ← PHI needed: %PHI = PHI(X, bb.0, NewReg, bb.2) +// / \ +// / \ +// bb.2 bb.3 (loop exit, use X) +// (loop +// body, +// new def) +// | +// └──→ bb.1 (back edge) +// +// Key test: Dominance-based PHI construction should correctly use NewReg +// for the back edge operand since NewDefBB (bb.2) dominates the loop latch (bb.2). +//===----------------------------------------------------------------------===// + +// Test loop with new definition in loop body requiring PHI in loop header +// +// CFG Structure: +// BB0 (entry, %1 = orig def) +// | +// +-> BB1 (loop header) +// | / \ +// | / \ +// BB2 BB3 (exit, use %1) +// | +// (INSERT: %1 = new_def) +// | +// +-(backedge) -> PHI needed in BB1 to merge initial value and loop value +// +TEST(MachineLaneSSAUpdaterTest, LoopWithDefInBody) { + liveIntervalsTest(R"MIR( + %0:vgpr_32 = V_MOV_B32_e32 0, implicit $exec + ; Original definition of %1 (before loop) + %1:vgpr_32 = V_ADD_U32_e32 %0, %0, implicit $exec + S_BRANCH %bb.1 + + bb.1: + successors: %bb.2, %bb.3 + ; Loop header - PHI should be inserted here + $sgpr0 = S_MOV_B32 0 + $sgpr1 = S_MOV_B32 1 + S_CMP_LG_U32 $sgpr0, $sgpr1, implicit-def $scc + S_CBRANCH_SCC1 %bb.3, implicit $scc + + bb.2: + successors: %bb.1 + ; Loop body - new def will be inserted here + %2:vgpr_32 = V_ADD_U32_e32 %1, %1, implicit $exec + S_BRANCH %bb.1 + + bb.3: + ; Loop exit - use %1 + %3:vgpr_32 = V_ADD_U32_e32 %1, %1, implicit $exec + S_ENDPGM 0 +)MIR", + [](MachineFunction &MF, LiveIntervalsWrapperPass &LISWrapper) { + LiveIntervals &LIS = LISWrapper.getLIS(); + MachineDominatorTree MDT(MF); + const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo(); + + ASSERT_EQ(MF.size(), 4u) << "Should have bb.0 through bb.3"; + + MachineBasicBlock *BB0 = MF.getBlockNumbered(0); // Entry with original def + MachineBasicBlock *BB1 = MF.getBlockNumbered(1); // Loop header + MachineBasicBlock *BB2 = MF.getBlockNumbered(2); // Loop body + + // Get %1 (defined in bb.0, used in loop) + // Skip the first V_MOV instruction, get the V_ADD + auto It = BB0->begin(); + ++It; // Skip %0 = V_MOV + MachineInstr *OrigDefMI = &*It; + Register OrigReg = OrigDefMI->getOperand(0).getReg(); + ASSERT_TRUE(OrigReg.isValid()) << "Could not get original register"; + + LLVM_DEBUG(dbgs() << "Original register: %" << OrigReg.virtRegIndex() << "\n"); + + // Insert new definition in loop body (bb.2) + // This violates SSA because %1 is defined both in bb.0 and bb.2 + MachineInstr *MovInst = &*BB0->begin(); + unsigned MovOpcode = MovInst->getOpcode(); + Register ExecReg = MovInst->getOperand(2).getReg(); + + const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo(); + auto InsertPt = BB2->getFirstNonPHI(); + MachineInstr *NewDefMI = BuildMI(*BB2, InsertPt, DebugLoc(), + TII->get(MovOpcode), OrigReg) + .addImm(99) + .addReg(ExecReg, 
RegState::Implicit); + + // Set MachineFunction properties + MF.getProperties().set(MachineFunctionProperties::Property::IsSSA); + MF.getProperties().reset(MachineFunctionProperties::Property::NoPHIs); + + // Call SSA updater + MachineLaneSSAUpdater Updater(MF, LIS, MDT, *TRI); + Register NewReg = Updater.repairSSAForNewDef(*NewDefMI, OrigReg); + + LLVM_DEBUG(dbgs() << "SSA repair created new register: %" << NewReg.virtRegIndex() << "\n"); + + // VERIFY RESULTS: + + // 1. NewReg should be valid and different from OrigReg + EXPECT_TRUE(NewReg.isValid()); + EXPECT_NE(NewReg, OrigReg); + + // 2. NewDefMI should now define NewReg + EXPECT_EQ(NewDefMI->getOperand(0).getReg(), NewReg); + + // 3. PHI should be inserted in loop header (bb.1) + bool FoundPHIInHeader = false; + for (MachineInstr &MI : *BB1) { + if (MI.isPHI()) { + FoundPHIInHeader = true; + LLVM_DEBUG(dbgs() << "Found PHI in loop header (bb.1): "); + LLVM_DEBUG(MI.print(dbgs())); + + // Verify PHI has 2 incoming values + unsigned NumIncoming = (MI.getNumOperands() - 1) / 2; + EXPECT_EQ(NumIncoming, 2u) << "Loop header PHI should have 2 incoming values"; + + // Check the operands + // One should be from bb.0 (entry, using OrigReg) + // One should be from bb.2 (back edge, using NewReg) + bool HasEntryPath = false; + bool HasBackEdge = false; + + for (unsigned i = 1; i < MI.getNumOperands(); i += 2) { + Register IncomingReg = MI.getOperand(i).getReg(); + MachineBasicBlock *IncomingMBB = MI.getOperand(i + 1).getMBB(); + + if (IncomingMBB == BB0) { + HasEntryPath = true; + EXPECT_EQ(IncomingReg, OrigReg) << "Entry path should use OrigReg"; + LLVM_DEBUG(dbgs() << " Entry path (bb.0): %" << IncomingReg.virtRegIndex() << "\n"); + } else if (IncomingMBB == BB2) { + HasBackEdge = true; + EXPECT_EQ(IncomingReg, NewReg) << "Back edge should use NewReg"; + LLVM_DEBUG(dbgs() << " Back edge (bb.2): %" << IncomingReg.virtRegIndex() << "\n"); + } + } + + EXPECT_TRUE(HasEntryPath) << "PHI should have entry path from bb.0"; + EXPECT_TRUE(HasBackEdge) << "PHI should have back edge from bb.2"; + + break; + } + } + EXPECT_TRUE(FoundPHIInHeader) << "Should have inserted PHI in loop header (bb.1)"; + + // 4. Verify LiveIntervals are valid + EXPECT_TRUE(LIS.hasInterval(NewReg)); + EXPECT_TRUE(LIS.hasInterval(OrigReg)); + + // Verify the MachineFunction is still valid + EXPECT_TRUE(MF.verify(nullptr, nullptr, nullptr, false)) + << "MachineFunction verification failed"; + }); +} + +//===----------------------------------------------------------------------===// +// Test 6: Complex loop with diamond CFG and use-before-def +// +// This is the most comprehensive test combining multiple SSA repair scenarios: +// 1. Loop with existing PHI (induction variable) +// 2. Use before redefinition (in loop header) +// 3. New definition in one branch of if-then-else diamond +// 4. PHI1 at diamond join +// 5. PHI2 at loop header (merges entry value and PHI1 result from back edge) +// 6. Use after diamond (in latch) should use PHI1 result +// +// CFG Structure: +// BB0 (entry: X=%1, i=0) +// | +// +-> BB1 (loop header) +// | | PHI_i = PHI(0, BB0; i+1, BB5) [already in input MIR] +// | | PHI2 = PHI(X, BB0; PHI1, BB5) [created by SSA updater] +// | | USE X (before redef!) 
[rewritten to use PHI2] +// | | if (i < 10) +// | / \ +// | BB2 BB3 (INSERT: X = 99) +// | | | +// | | (then: X unchanged) +// | | (else: NEW DEF) +// | \ / +// | BB4 (diamond join) +// | | PHI1 = PHI(X, BB2; NewReg, BB3) [created by SSA updater] +// | | +// | BB5 (loop latch) +// | | USE X [rewritten to use PHI1] +// | | i = i + 1 +// | | \ +// | | \ +// +---+ BB6 (exit, USE X) +// +// Key challenge: Use in BB1 occurs BEFORE the def in BB3 (in program order), +// requiring PHI2 in the loop header for proper SSA form. +// +// Expected SSA repair: +// - PHI1 created in BB4 (diamond join): merges unchanged X from BB2, new def from BB3 +// - PHI2 created in BB1 (loop header): merges entry X from BB0, PHI1 result from BB5 +// - Use in BB1 rewritten to PHI2 +// - Use in BB5 rewritten to PHI1 +//===----------------------------------------------------------------------===// +TEST(MachineLaneSSAUpdaterTest, ComplexLoopWithDiamondAndUseBeforeDef) { + liveIntervalsTest(R"MIR( + %0:vgpr_32 = V_MOV_B32_e32 0, implicit $exec + ; X = 1 (the register we'll redefine in loop) + %1:vgpr_32 = V_MOV_B32_e32 1, implicit $exec + ; i = 0 (induction variable) + %2:vgpr_32 = V_MOV_B32_e32 0, implicit $exec + S_BRANCH %bb.1 + + bb.1: + successors: %bb.2, %bb.3 + ; Loop header with existing PHI for induction variable + %3:vgpr_32 = PHI %2:vgpr_32, %bb.0, %10:vgpr_32, %bb.5 + ; USE X before redefinition - should be rewritten to PHI2 + %4:vgpr_32 = V_ADD_U32_e32 %1, %1, implicit $exec + ; Check if i < 10 + %5:vgpr_32 = V_MOV_B32_e32 10, implicit $exec + $sgpr0 = S_MOV_B32 0 + $sgpr1 = S_MOV_B32 1 + S_CMP_LG_U32 $sgpr0, $sgpr1, implicit-def $scc + S_CBRANCH_SCC1 %bb.3, implicit $scc + + bb.2: + successors: %bb.4 + ; Then branch - X unchanged + S_NOP 0 + S_BRANCH %bb.4 + + bb.3: + successors: %bb.4 + ; Else branch - NEW DEF will be inserted here: X = 99 + S_NOP 0 + + bb.4: + successors: %bb.5 + ; Diamond join - PHI1 should be created here + S_NOP 0 + + bb.5: + successors: %bb.1, %bb.6 + ; Loop latch - USE X (should be rewritten to PHI1) + %8:vgpr_32 = V_SUB_U32_e32 %1, %1, implicit $exec + ; i = i + 1 + %9:vgpr_32 = V_MOV_B32_e32 1, implicit $exec + %10:vgpr_32 = V_ADD_U32_e32 %3, %9, implicit $exec + ; Check loop condition + $sgpr2 = S_MOV_B32 0 + $sgpr3 = S_MOV_B32 1 + S_CMP_LG_U32 $sgpr2, $sgpr3, implicit-def $scc + S_CBRANCH_SCC1 %bb.6, implicit $scc + S_BRANCH %bb.1 + + bb.6: + ; Loop exit - USE X + %11:vgpr_32 = V_OR_B32_e32 %1, %1, implicit $exec + S_ENDPGM 0 +)MIR", + [](MachineFunction &MF, LiveIntervalsWrapperPass &LISWrapper) { + LiveIntervals &LIS = LISWrapper.getLIS(); + MachineDominatorTree MDT(MF); + const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo(); + + ASSERT_EQ(MF.size(), 7u) << "Should have bb.0 through bb.6"; + + MachineBasicBlock *BB0 = MF.getBlockNumbered(0); // Entry + MachineBasicBlock *BB1 = MF.getBlockNumbered(1); // Loop header + MachineBasicBlock *BB3 = MF.getBlockNumbered(3); // Else (new def here) + MachineBasicBlock *BB4 = MF.getBlockNumbered(4); // Diamond join + MachineBasicBlock *BB5 = MF.getBlockNumbered(5); // Latch + + // Get %1 (X, defined in bb.0) + auto It = BB0->begin(); + ++It; // Skip %0 = V_MOV_B32_e32 0 + MachineInstr *OrigDefMI = &*It; // %1 = V_MOV_B32_e32 1 + Register OrigReg = OrigDefMI->getOperand(0).getReg(); + ASSERT_TRUE(OrigReg.isValid()) << "Could not get original register X"; + + LLVM_DEBUG(dbgs() << "Original register X: %" << OrigReg.virtRegIndex() << "\n"); + + // Find the use-before-def in bb.1 (loop header) + MachineInstr *UseBeforeDefMI = 
nullptr; + for (MachineInstr &MI : *BB1) { + if (!MI.isPHI() && MI.getOpcode() != TargetOpcode::IMPLICIT_DEF) { + // First non-PHI instruction should be V_ADD using %1 + if (MI.getNumOperands() >= 3 && MI.getOperand(1).isReg() && + MI.getOperand(1).getReg() == OrigReg) { + UseBeforeDefMI = &MI; + break; + } + } + } + ASSERT_TRUE(UseBeforeDefMI) << "Could not find use-before-def in loop header"; + LLVM_DEBUG(dbgs() << "Found use-before-def in bb.1: %" + << UseBeforeDefMI->getOperand(0).getReg().virtRegIndex() << "\n"); + + // Insert new definition in bb.3 (else branch): X = 99 + MachineInstr *MovInst = &*BB0->begin(); + unsigned MovOpcode = MovInst->getOpcode(); + Register ExecReg = MovInst->getOperand(2).getReg(); + + const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo(); + auto InsertPt = BB3->getFirstNonPHI(); + MachineInstr *NewDefMI = BuildMI(*BB3, InsertPt, DebugLoc(), + TII->get(MovOpcode), OrigReg) + .addImm(99) + .addReg(ExecReg, RegState::Implicit); + + // Set MachineFunction properties + MF.getProperties().set(MachineFunctionProperties::Property::IsSSA); + MF.getProperties().reset(MachineFunctionProperties::Property::NoPHIs); + + // Call SSA updater + MachineLaneSSAUpdater Updater(MF, LIS, MDT, *TRI); + Register NewReg = Updater.repairSSAForNewDef(*NewDefMI, OrigReg); + + LLVM_DEBUG(dbgs() << "SSA repair created new register: %" << NewReg.virtRegIndex() << "\n"); + + // VERIFY RESULTS: + + // 1. NewReg should be valid and different from OrigReg + EXPECT_TRUE(NewReg.isValid()); + EXPECT_NE(NewReg, OrigReg); + EXPECT_EQ(NewDefMI->getOperand(0).getReg(), NewReg); + + // 2. PHI1 should exist in diamond join (bb.4) + bool FoundPHI1 = false; + Register PHI1Reg; + for (MachineInstr &MI : *BB4) { + if (MI.isPHI()) { + FoundPHI1 = true; + PHI1Reg = MI.getOperand(0).getReg(); + LLVM_DEBUG(dbgs() << "Found PHI1 in diamond join (bb.4): "); + LLVM_DEBUG(MI.print(dbgs())); + + // Should have 2 incoming: OrigReg from bb.2, NewReg from bb.3 + unsigned NumIncoming = (MI.getNumOperands() - 1) / 2; + EXPECT_EQ(NumIncoming, 2u) << "Diamond join PHI should have 2 incoming"; + break; + } + } + EXPECT_TRUE(FoundPHI1) << "Should have PHI1 in diamond join (bb.4)"; + + // 3. 
PHI2 should exist in loop header (bb.1) + // First, count all PHIs + unsigned TotalPHICount = 0; + for (MachineInstr &MI : *BB1) { + if (MI.isPHI()) + TotalPHICount++; + } + LLVM_DEBUG(dbgs() << "Total PHIs in loop header: " << TotalPHICount << "\n"); + EXPECT_EQ(TotalPHICount, 2u) << "Loop header should have 2 PHIs (induction var + SSA repair)"; + + // Now find the SSA repair PHI (not the induction variable PHI %3) + bool FoundPHI2 = false; + Register PHI2Reg; + Register InductionVarPHI = Register::index2VirtReg(3); // %3 from input MIR + for (MachineInstr &MI : *BB1) { + if (MI.isPHI()) { + Register PHIResult = MI.getOperand(0).getReg(); + + // Skip the induction variable PHI (%3 from input MIR) when looking for SSA repair PHI + if (PHIResult == InductionVarPHI) + continue; + + FoundPHI2 = true; + PHI2Reg = PHIResult; + LLVM_DEBUG(dbgs() << "Found PHI2 (SSA repair) in loop header (bb.1): "); + LLVM_DEBUG(MI.print(dbgs())); + + // Should have 2 incoming: OrigReg from bb.0, PHI1Reg from bb.5 + unsigned NumIncoming = (MI.getNumOperands() - 1) / 2; + EXPECT_EQ(NumIncoming, 2u) << "Loop header PHI2 should have 2 incoming"; + + // Verify operands + bool HasEntryPath = false; + bool HasBackEdge = false; + for (unsigned i = 1; i < MI.getNumOperands(); i += 2) { + Register IncomingReg = MI.getOperand(i).getReg(); + MachineBasicBlock *IncomingMBB = MI.getOperand(i + 1).getMBB(); + + if (IncomingMBB == BB0) { + HasEntryPath = true; + EXPECT_EQ(IncomingReg, OrigReg) << "Entry path should use OrigReg"; + } else if (IncomingMBB == BB5) { + HasBackEdge = true; + EXPECT_EQ(IncomingReg, PHI1Reg) << "Back edge should use PHI1 result"; + } + } + + EXPECT_TRUE(HasEntryPath) << "PHI2 should have entry path from bb.0"; + EXPECT_TRUE(HasBackEdge) << "PHI2 should have back edge from bb.5"; + break; + } + } + EXPECT_TRUE(FoundPHI2) << "Should have PHI2 (SSA repair) in loop header (bb.1)"; + + // 4. Use-before-def in bb.1 should be rewritten to PHI2 + EXPECT_EQ(UseBeforeDefMI->getOperand(1).getReg(), PHI2Reg) + << "Use-before-def should be rewritten to PHI2 result"; + LLVM_DEBUG(dbgs() << "Use-before-def correctly rewritten to PHI2: %" + << PHI2Reg.virtRegIndex() << "\n"); + + // 5. Use in latch (bb.5) should be rewritten to PHI1 + // Find instruction using PHI1 (originally used %1) + bool FoundLatchUse = false; + for (MachineInstr &MI : *BB5) { + // Skip PHIs and branches + if (MI.isPHI() || MI.isBranch()) + continue; + + // Look for any instruction that uses PHI1Reg + for (unsigned i = 0; i < MI.getNumOperands(); ++i) { + MachineOperand &MO = MI.getOperand(i); + if (MO.isReg() && MO.isUse() && MO.getReg() == PHI1Reg) { + LLVM_DEBUG(dbgs() << "Latch use correctly rewritten to PHI1: %" + << PHI1Reg.virtRegIndex() << " in: "); + LLVM_DEBUG(MI.print(dbgs())); + FoundLatchUse = true; + break; + } + } + if (FoundLatchUse) + break; + } + EXPECT_TRUE(FoundLatchUse) << "Should find use of PHI1 in latch (bb.5)"; + + // 6. Verify LiveIntervals + EXPECT_TRUE(LIS.hasInterval(NewReg)); + EXPECT_TRUE(LIS.hasInterval(PHI1Reg)); + EXPECT_TRUE(LIS.hasInterval(PHI2Reg)); + + // Verify the MachineFunction is still valid + EXPECT_TRUE(MF.verify(nullptr, nullptr, nullptr, false)) + << "MachineFunction verification failed"; + }); +} + +// Test 7: Multiple subreg redefinitions in loop (X.sub0 in one branch, X.sub1 in latch) +// This tests the most complex scenario: two separate lane redefinitions with REG_SEQUENCE +// composition at the backedge. 
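+//
+// Driver pattern exercised below (a sketch of the intended usage; the same
+// updater instance services both repairs in sequence, with LIS and MDT kept
+// current between the calls):
+//
+//   MachineLaneSSAUpdater Updater(MF, LIS, MDT, *TRI);
+//   Register NewReg1 = Updater.repairSSAForNewDef(NewDefMI1, OrigReg);      // %0.sub0
+//   Register NewReg2 = Updater.repairSSAForNewDef(NewDefMI2, IncrementReg); // %3.sub1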
+// Test multiple subregister redefinitions in different paths within a loop +// +// CFG Structure: +// BB0 (entry, %1:vreg_64 = IMPLICIT_DEF) +// | +// +-> BB1 (loop header, PHI for %0) +// | | (use %0.sub0) +// | / \ +// | BB2 BB5 +// | | | +// | use INSERT: %0.sub0 = new_def1 +// |sub1 use %0.sub0 +// | \ / +// | BB3 (latch) +// | | (INSERT: %3.sub1 = new_def2, where %3 is increment result) +// | | (%3 = %0 << 1) +// +---+ +// | +// BB4 (exit) +// +// Key: Two separate lane redefinitions requiring separate SSA repairs: +// 1. %0.sub0 in BB5 → PHI for sub0 in BB3 +// 2. %3.sub1 in BB3 (after increment) → PHI for sub1 in BB1 +// +TEST(MachineLaneSSAUpdaterTest, MultipleSubregRedefsInLoop) { + SmallString<2048> S; + StringRef MIRString = (Twine(R"MIR( +--- | + define amdgpu_kernel void @func() { ret void } +... +--- +name: func +tracksRegLiveness: true +registers: + - { id: 0, class: vreg_64 } + - { id: 1, class: vreg_64 } + - { id: 2, class: vgpr_32 } + - { id: 3, class: vreg_64 } +body: | + bb.0: + successors: %bb.1 + %1:vreg_64 = IMPLICIT_DEF + + bb.1: + successors: %bb.2, %bb.5 + %0:vreg_64 = PHI %1:vreg_64, %bb.0, %3:vreg_64, %bb.3 + %2:vgpr_32 = V_MOV_B32_e32 10, implicit $exec + dead %4:vgpr_32 = V_ADD_U32_e32 %0.sub0:vreg_64, %2:vgpr_32, implicit $exec + $sgpr0 = S_MOV_B32 0 + $sgpr1 = S_MOV_B32 1 + S_CMP_LG_U32 $sgpr0, $sgpr1, implicit-def $scc + S_CBRANCH_SCC1 %bb.2, implicit $scc + S_BRANCH %bb.5 + + bb.2: + successors: %bb.3 + dead %5:vgpr_32 = V_MOV_B32_e32 %0.sub1:vreg_64, implicit $exec + S_BRANCH %bb.3 + + bb.5: + successors: %bb.3 + dead %6:vgpr_32 = V_MOV_B32_e32 %0.sub0:vreg_64, implicit $exec + S_BRANCH %bb.3 + + bb.3: + successors: %bb.1, %bb.4 + %3:vreg_64 = V_LSHLREV_B64_e64 1, %0:vreg_64, implicit $exec + $sgpr2 = S_MOV_B32 0 + $sgpr3 = S_MOV_B32 10 + S_CMP_LT_U32 $sgpr2, $sgpr3, implicit-def $scc + S_CBRANCH_SCC1 %bb.1, implicit $scc + S_BRANCH %bb.4 + + bb.4: + S_ENDPGM 0 +... 
+)MIR")).toNullTerminatedStringRef(S); + + doTest(MIRString, + [](MachineFunction &MF, LiveIntervalsWrapperPass &LISWrapper) { + LiveIntervals &LIS = LISWrapper.getLIS(); + MachineDominatorTree MDT(MF); + LLVM_DEBUG(dbgs() << "\n=== MultipleSubregRedefsInLoop Test ===\n"); + + // Get basic blocks + auto BBI = MF.begin(); + ++BBI; // Skip BB0 (Entry) + MachineBasicBlock *BB1 = &*BBI++; // Loop header + ++BBI; // Skip BB2 (True branch) + MachineBasicBlock *BB5 = &*BBI++; // False branch (uses X.LO, INSERT def X.LO) + MachineBasicBlock *BB3 = &*BBI++; // Latch (increment, INSERT def X.HI) + // Skip BB4 (Exit) + + MachineRegisterInfo &MRI = MF.getRegInfo(); + const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo(); + const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo(); + (void)MRI; // May be unused, suppress warning + + // Find the 64-bit register and its subregister indices + Register OrigReg = Register::index2VirtReg(0); // %0 from MIR + ASSERT_TRUE(OrigReg.isValid()) << "Register %0 should be valid"; + unsigned Sub0Idx = 0, Sub1Idx = 0; + + // Find sub0 (low 32 bits) and sub1 (high 32 bits) + for (unsigned Idx = 1; Idx < TRI->getNumSubRegIndices(); ++Idx) { + LaneBitmask Mask = TRI->getSubRegIndexLaneMask(Idx); + unsigned SubRegSize = TRI->getSubRegIdxSize(Idx); + + if (SubRegSize == 32) { + if (Mask.getAsInteger() == 0x3) { // Low lanes + Sub0Idx = Idx; + } else if (Mask.getAsInteger() == 0xC) { // High lanes + Sub1Idx = Idx; + } + } + } + + ASSERT_NE(Sub0Idx, 0u) << "Should find sub0 index"; + ASSERT_NE(Sub1Idx, 0u) << "Should find sub1 index"; + + LLVM_DEBUG(dbgs() << "Using 64-bit register: %" << OrigReg.virtRegIndex() + << " with sub0=" << Sub0Idx << ", sub1=" << Sub1Idx << "\n"); + + // Get V_MOV opcode and EXEC register from existing instruction + MachineInstr *MovInst = nullptr; + Register ExecReg; + for (MachineInstr &MI : *BB1) { + if (!MI.isPHI() && MI.getNumOperands() >= 3 && MI.getOperand(2).isReg()) { + MovInst = &MI; + ExecReg = MI.getOperand(2).getReg(); + break; + } + } + ASSERT_NE(MovInst, nullptr) << "Should find V_MOV in BB1"; + unsigned MovOpcode = MovInst->getOpcode(); + + // === FIRST INSERTION: X.sub0 in BB5 (else branch) === + LLVM_DEBUG(dbgs() << "\n=== First insertion: X.sub0 in BB5 ===\n"); + + // Find insertion point in BB5 (after the use of X.sub0) + MachineInstr *InsertPoint1 = nullptr; + for (MachineInstr &MI : *BB5) { + if (MI.isBranch()) { + InsertPoint1 = &MI; + break; + } + } + ASSERT_NE(InsertPoint1, nullptr) << "Should find branch in BB5"; + + // Create first new def: X.sub0 = 99 + MachineInstrBuilder MIB1 = BuildMI(*BB5, InsertPoint1, DebugLoc(), + TII->get(MovOpcode)) + .addReg(OrigReg, RegState::Define, Sub0Idx) + .addImm(99) + .addReg(ExecReg, RegState::Implicit); + + MachineInstr &NewDefMI1 = *MIB1; + LLVM_DEBUG(dbgs() << "Created first def in BB5: "); + LLVM_DEBUG(NewDefMI1.print(dbgs())); + + // Create SSA updater and repair after first insertion + MachineLaneSSAUpdater Updater(MF, LIS, MDT, *TRI); + Register NewReg1 = Updater.repairSSAForNewDef(NewDefMI1, OrigReg); + + LLVM_DEBUG(dbgs() << "SSA repair #1 created new register: %" << NewReg1.virtRegIndex() << "\n"); + + // === SECOND INSERTION: X.sub1 in BB3 (after increment) === + LLVM_DEBUG(dbgs() << "\n=== Second insertion: X.sub1 in BB3 (after increment) ===\n"); + + // Find the increment instruction in BB3 (look for vreg_64 def) + MachineInstr *IncrementMI = nullptr; + Register IncrementReg; + for (MachineInstr &MI : *BB3) { + if (!MI.isPHI() && MI.getNumOperands() > 0 
&& MI.getOperand(0).isReg() && + MI.getOperand(0).isDef()) { + Register DefReg = MI.getOperand(0).getReg(); + if (DefReg.isVirtual() && DefReg == Register::index2VirtReg(3)) { + IncrementMI = &MI; + IncrementReg = DefReg; // This is %3 + LLVM_DEBUG(dbgs() << "Found increment: "); + LLVM_DEBUG(MI.print(dbgs())); + break; + } + } + } + ASSERT_NE(IncrementMI, nullptr) << "Should find increment (def of %3) in BB3"; + ASSERT_TRUE(IncrementReg.isValid()) << "Increment register should be valid"; + + // Create second new def: %3.sub1 = 200 (redefine increment result's sub1) + MachineBasicBlock::iterator InsertPoint2 = std::next(IncrementMI->getIterator()); + MachineInstrBuilder MIB2 = BuildMI(*BB3, InsertPoint2, DebugLoc(), + TII->get(MovOpcode)) + .addReg(IncrementReg, RegState::Define, Sub1Idx) // Redefine %3.sub1, not %0.sub1! + .addImm(200) + .addReg(ExecReg, RegState::Implicit); + + MachineInstr &NewDefMI2 = *MIB2; + LLVM_DEBUG(dbgs() << "Created second def in BB3 (redefining %3.sub1): "); + LLVM_DEBUG(NewDefMI2.print(dbgs())); + + // Repair SSA after second insertion (for %3, the increment result) + Register NewReg2 = Updater.repairSSAForNewDef(NewDefMI2, IncrementReg); + + LLVM_DEBUG(dbgs() << "SSA repair #2 created new register: %" << NewReg2.virtRegIndex() << "\n"); + + // === Verification === + LLVM_DEBUG(dbgs() << "\n=== Verification ===\n"); + + // Print final MIR + LLVM_DEBUG(dbgs() << "Final BB3 (latch):\n"); + LLVM_DEBUG(for (MachineInstr &MI : *BB3) { + MI.print(dbgs()); + }); + + // 1. Should have PHI for 32-bit X.sub0 at BB3 (diamond join) + bool FoundSub0PHI = false; + for (MachineInstr &MI : *BB3) { + if (MI.isPHI()) { + Register PHIResult = MI.getOperand(0).getReg(); + if (PHIResult != Register::index2VirtReg(3)) { // Not the increment result PHI + FoundSub0PHI = true; + LLVM_DEBUG(dbgs() << "Found sub0 PHI in BB3: "); + LLVM_DEBUG(MI.print(dbgs())); + } + } + } + EXPECT_TRUE(FoundSub0PHI) << "Should have PHI for sub0 lane in BB3"; + + // 2. Should have REG_SEQUENCE in BB3 before backedge to compose full 64-bit + bool FoundREGSEQ = false; + for (MachineInstr &MI : *BB3) { + if (MI.getOpcode() == TargetOpcode::REG_SEQUENCE) { + FoundREGSEQ = true; + LLVM_DEBUG(dbgs() << "Found REG_SEQUENCE in BB3: "); + LLVM_DEBUG(MI.print(dbgs())); + + // Verify it composes both lanes + unsigned NumSources = (MI.getNumOperands() - 1) / 2; + EXPECT_GE(NumSources, 2u) << "REG_SEQUENCE should have at least 2 sources (sub0 and sub1)"; + } + } + + EXPECT_TRUE(FoundREGSEQ) << "Should have REG_SEQUENCE at backedge in BB3"; + + // 3. Verify LiveIntervals + EXPECT_TRUE(LIS.hasInterval(NewReg1)); + EXPECT_TRUE(LIS.hasInterval(NewReg2)); + + // Verify the MachineFunction is still valid + EXPECT_TRUE(MF.verify(nullptr, nullptr, nullptr, false)) + << "MachineFunction verification failed"; + }); +} + +// Test 8: Nested loops with SSA repair across multiple loop levels +// This tests SSA repair with a new definition in an inner loop body that propagates +// to both the inner loop header and outer loop header PHIs. 
+// Test nested loops with SSA repair across multiple loop levels +// +// CFG Structure: +// BB0 (entry, %0 = 100) +// | +// +-> BB1 (outer loop header) +// | | PHI for %1 (outer induction var) +// | | +// | +->BB2 (inner loop header) +// | | | PHI for %2 (inner induction var) +// | | |\ +// | | | \ +// | | BB3 BB4 (outer loop body) +// | | | +// | | INSERT: %0 = new_def +// | | (%3 = %2 + %0) +// | | | +// | +--+ (inner backedge) -> PHI in BB2 for %0 expected +// | | +// | (%4 = %1 + %0, use %0) +// +----+ (outer backedge) +// | +// BB5 (exit) +// +// Key: New def in inner loop body propagates to: +// 1. Inner loop header PHI (BB2) +// 2. Outer loop body uses (BB4) +// 3. Outer loop header PHI (BB1) +// +TEST(MachineLaneSSAUpdaterTest, NestedLoopsWithSSARepair) { + SmallString<2048> S; + StringRef MIRString = (Twine(R"MIR( +--- | + define amdgpu_kernel void @func() { ret void } +... +--- +name: func +tracksRegLiveness: true +registers: + - { id: 0, class: vgpr_32 } + - { id: 1, class: vgpr_32 } + - { id: 2, class: vgpr_32 } + - { id: 3, class: vgpr_32 } +body: | + bb.0: + successors: %bb.1 + %0:vgpr_32 = V_MOV_B32_e32 100, implicit $exec + S_BRANCH %bb.1 + + bb.1: + successors: %bb.2 + ; Outer loop header: %1 = PHI(initial, result_from_outer_body) + %1:vgpr_32 = PHI %0:vgpr_32, %bb.0, %4:vgpr_32, %bb.4 + dead %5:vgpr_32 = V_ADD_U32_e32 %1:vgpr_32, %1:vgpr_32, implicit $exec + S_BRANCH %bb.2 + + bb.2: + successors: %bb.3, %bb.4 + ; Inner loop header: %2 = PHI(from_outer, from_inner_body) + %2:vgpr_32 = PHI %1:vgpr_32, %bb.1, %3:vgpr_32, %bb.3 + dead %6:vgpr_32 = V_MOV_B32_e32 %2:vgpr_32, implicit $exec + $sgpr0 = S_MOV_B32 0 + $sgpr1 = S_MOV_B32 5 + S_CMP_LT_U32 $sgpr0, $sgpr1, implicit-def $scc + S_CBRANCH_SCC1 %bb.3, implicit $scc + S_BRANCH %bb.4 + + bb.3: + successors: %bb.2 + ; Inner loop body - accumulate value, then we'll insert new def for %0 + %3:vgpr_32 = V_ADD_U32_e32 %2:vgpr_32, %0:vgpr_32, implicit $exec + S_BRANCH %bb.2 + + bb.4: + successors: %bb.1, %bb.5 + ; Outer loop body after inner loop exit + ; Increment outer induction variable %1 and use %0 (which we'll redefine) + %4:vgpr_32 = V_ADD_U32_e32 %1:vgpr_32, %0:vgpr_32, implicit $exec + dead %7:vgpr_32 = V_MOV_B32_e32 %0:vgpr_32, implicit $exec + $sgpr2 = S_MOV_B32 0 + $sgpr3 = S_MOV_B32 10 + S_CMP_LT_U32 $sgpr2, $sgpr3, implicit-def $scc + S_CBRANCH_SCC1 %bb.1, implicit $scc + S_BRANCH %bb.5 + + bb.5: + ; Exit + S_ENDPGM 0 +... 
+)MIR")).toNullTerminatedStringRef(S); + + doTest(MIRString, + [](MachineFunction &MF, LiveIntervalsWrapperPass &LISWrapper) { + LiveIntervals &LIS = LISWrapper.getLIS(); + MachineDominatorTree MDT(MF); + LLVM_DEBUG(dbgs() << "\n=== NestedLoopsWithSSARepair Test ===\n"); + + // Get basic blocks + auto BBI = MF.begin(); + MachineBasicBlock *BB0 = &*BBI++; // Entry + MachineBasicBlock *BB1 = &*BBI++; // Outer loop header + MachineBasicBlock *BB2 = &*BBI++; // Inner loop header + MachineBasicBlock *BB3 = &*BBI++; // Inner loop body (INSERT HERE) + MachineBasicBlock *BB4 = &*BBI++; // Outer loop body (after inner) + // BB5 = Exit (not needed) + + const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo(); + const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo(); + + // Get the register that will be redefined (%0 is the initial value) + Register OrigReg = Register::index2VirtReg(0); + ASSERT_TRUE(OrigReg.isValid()) << "Register %0 should be valid"; + + LLVM_DEBUG(dbgs() << "Original register: %" << OrigReg.virtRegIndex() << "\n"); + + // Get V_MOV opcode and EXEC register + MachineInstr *MovInst = &*BB0->begin(); + unsigned MovOpcode = MovInst->getOpcode(); + Register ExecReg = MovInst->getOperand(2).getReg(); + + // Print initial state + LLVM_DEBUG(dbgs() << "\nInitial BB2 (inner loop header):\n"); + for (MachineInstr &MI : *BB2) { + LLVM_DEBUG(MI.print(dbgs())); + } + + LLVM_DEBUG(dbgs() << "\nInitial BB1 (outer loop header):\n"); + for (MachineInstr &MI : *BB1) { + LLVM_DEBUG(MI.print(dbgs())); + } + + // Insert new definition in BB3 (inner loop body) + // Find insertion point before the branch + MachineInstr *InsertPt = nullptr; + for (MachineInstr &MI : *BB3) { + if (MI.isBranch()) { + InsertPt = &MI; + break; + } + } + ASSERT_NE(InsertPt, nullptr) << "Should find branch in BB3"; + + // Insert: X = 999 (violates SSA) + MachineInstr *NewDefMI = BuildMI(*BB3, InsertPt, DebugLoc(), + TII->get(MovOpcode), OrigReg) + .addImm(999) + .addReg(ExecReg, RegState::Implicit); + + LLVM_DEBUG(dbgs() << "\nInserted new def in BB3 (inner loop body): "); + LLVM_DEBUG(NewDefMI->print(dbgs())); + + // Create SSA updater and repair + MachineLaneSSAUpdater Updater(MF, LIS, MDT, *TRI); + Register NewReg = Updater.repairSSAForNewDef(*NewDefMI, OrigReg); + + LLVM_DEBUG(dbgs() << "SSA repair created new register: %" << NewReg.virtRegIndex() << "\n"); + + // === Verification === + LLVM_DEBUG(dbgs() << "\n=== Verification ===\n"); + + LLVM_DEBUG(dbgs() << "\nFinal BB2 (inner loop header):\n"); + for (MachineInstr &MI : *BB2) { + LLVM_DEBUG(MI.print(dbgs())); + } + + LLVM_DEBUG(dbgs() << "\nFinal BB1 (outer loop header):\n"); + for (MachineInstr &MI : *BB1) { + LLVM_DEBUG(MI.print(dbgs())); + } + + LLVM_DEBUG(dbgs() << "\nFinal BB4 (outer loop body after inner):\n"); + for (MachineInstr &MI : *BB4) { + LLVM_DEBUG(MI.print(dbgs())); + } + + // 1. 
Inner loop header (BB2) should have NEW PHI created by SSA repair + bool FoundSSARepairPHI = false; + Register SSARepairPHIReg; + for (MachineInstr &MI : *BB2) { + if (MI.isPHI()) { + // Look for a PHI that has NewReg as one of its incoming values + for (unsigned i = 1; i < MI.getNumOperands(); i += 2) { + Register IncomingReg = MI.getOperand(i).getReg(); + MachineBasicBlock *IncomingMBB = MI.getOperand(i + 1).getMBB(); + + if (IncomingMBB == BB3 && IncomingReg == NewReg) { + FoundSSARepairPHI = true; + SSARepairPHIReg = MI.getOperand(0).getReg(); + LLVM_DEBUG(dbgs() << "Found SSA repair PHI in inner loop header: "); + LLVM_DEBUG(MI.print(dbgs())); + + // Should have incoming from BB1 and BB3 + unsigned NumIncoming = (MI.getNumOperands() - 1) / 2; + EXPECT_EQ(NumIncoming, 2u) << "SSA repair PHI should have 2 incoming"; + break; + } + } + if (FoundSSARepairPHI) + break; + } + } + EXPECT_TRUE(FoundSSARepairPHI) << "Should find SSA repair PHI in BB2 (inner loop header)"; + + // 2. Outer loop header (BB1) may have PHI updated if needed + bool FoundOuterPHI = false; + for (MachineInstr &MI : *BB1) { + if (MI.isPHI() && MI.getOperand(0).getReg() == Register::index2VirtReg(1)) { + FoundOuterPHI = true; + LLVM_DEBUG(dbgs() << "Found outer loop PHI: "); + LLVM_DEBUG(MI.print(dbgs())); + } + } + EXPECT_TRUE(FoundOuterPHI) << "Should find outer loop PHI in BB1"; + + // 3. Use in BB4 should be updated + bool FoundUseInBB4 = false; + for (MachineInstr &MI : *BB4) { + if (!MI.isPHI() && MI.getNumOperands() > 1) { + for (unsigned i = 0; i < MI.getNumOperands(); ++i) { + if (MI.getOperand(i).isReg() && MI.getOperand(i).isUse()) { + Register UseReg = MI.getOperand(i).getReg(); + if (UseReg.isVirtual()) { + FoundUseInBB4 = true; + LLVM_DEBUG(dbgs() << "Found use in BB4: %" << UseReg.virtRegIndex() << " in "); + LLVM_DEBUG(MI.print(dbgs())); + } + } + } + } + } + EXPECT_TRUE(FoundUseInBB4) << "Should find uses in outer loop body (BB4)"; + + // 4. Verify LiveIntervals + EXPECT_TRUE(LIS.hasInterval(NewReg)); + + // Verify the MachineFunction is still valid + EXPECT_TRUE(MF.verify(nullptr, nullptr, nullptr, false)) + << "MachineFunction verification failed"; + }); +} + +//===----------------------------------------------------------------------===// +// Test 9: 128-bit register with 64-bit subreg redef and multiple lane uses +// +// This comprehensive test covers: +// 1. Large register (128-bit) with multiple subregisters (sub0, sub1, sub2, sub3) +// 2. Partial redefinition (64-bit sub2_3 covering two lanes: sub2+sub3) +// 3. Uses of changed lanes (sub2, sub3) in different paths +// 4. Uses of unchanged lanes (sub0, sub1) in different paths +// 5. Diamond CFG with redef in one branch +// 6. 
+
+//===----------------------------------------------------------------------===//
+// Test 9: 128-bit register with 64-bit subreg redef and multiple lane uses
+//
+// This comprehensive test covers:
+// 1. Large register (128-bit) with multiple subregisters (sub0, sub1, sub2, sub3)
+// 2. Partial redefinition (64-bit sub2_3 covering two lanes: sub2+sub3)
+// 3. Uses of changed lanes (sub2, sub3) in different paths
+// 4. Uses of unchanged lanes (sub0, sub1) in different paths
+// 5. Diamond CFG with redef in one branch
+// 6. Second diamond to test propagation of the PHI result
+//
+// CFG Structure:
+//        BB0 (entry)
+//         |
+//        BB1 (%0:vreg_128 = initial 128-bit value)
+//         |
+//        BB2 (diamond1 split)
+//        /  \
+//     BB3    BB4 (INSERT: %0.sub2_3 = new_def)
+//      |      |
+//     use    use
+//     sub0   sub3 (changed)
+//        \  /
+//        BB5 (join) -> PHI for sub2_3 lanes (sub2+sub3 changed, sub0+sub1 unchanged)
+//         |
+//        use sub1 (unchanged, flows from BB1)
+//         |
+//        BB6 (diamond2 split)
+//        /  \
+//     BB7    BB8
+//      |      |
+//     use   (no use)
+//     sub2
+//        \  /
+//        BB9 (join, no PHI - BB5's PHI dominates)
+//         |
+//        BB10 (use sub0, exit)
+//
+// Expected behavior:
+// - PHI in BB5 merges the sub2_3 lanes ONLY (sub2+sub3 changed)
+// - sub0+sub1 lanes flow unchanged from BB1 through the entire CFG
+// - Uses in BB5, BB7, BB10 use the PHI result or unchanged lanes
+// - No PHI in BB9 (BB5 dominates, the PHI result flows through)
+//
+// This test validates:
+// ✓ Partial redefinition (64-bit of 128-bit)
+// ✓ Multiple different subreg uses (sub0, sub1, sub2, sub3)
+// ✓ Changed vs unchanged lane tracking
+// ✓ PHI result propagation to dominated blocks
+//===----------------------------------------------------------------------===//
+TEST(MachineLaneSSAUpdaterTest, MultipleSubregUsesAcrossDiamonds) {
+  SmallString<4096> S;
+  StringRef MIRString = (Twine(R"MIR(
+--- |
+  define amdgpu_kernel void @func() { ret void }
+...
+---
+name: func
+tracksRegLiveness: true
+registers:
+  - { id: 0, class: vreg_128 }
+  - { id: 1, class: vgpr_32 }
+  - { id: 2, class: vgpr_32 }
+  - { id: 3, class: vgpr_32 }
+  - { id: 4, class: vgpr_32 }
+body: |
+  bb.0:
+    successors: %bb.1
+    S_BRANCH %bb.1
+
+  bb.1:
+    successors: %bb.2
+    ; Initialize 128-bit register %0 with IMPLICIT_DEF
+    %0:vreg_128 = IMPLICIT_DEF
+    S_BRANCH %bb.2
+
+  bb.2:
+    successors: %bb.3, %bb.4
+    ; Diamond 1 split
+    $sgpr0 = S_MOV_B32 0
+    $sgpr1 = S_MOV_B32 1
+    S_CMP_LG_U32 $sgpr0, $sgpr1, implicit-def $scc
+    S_CBRANCH_SCC1 %bb.4, implicit $scc
+
+  bb.3:
+    successors: %bb.5
+    ; Use sub0 (unchanged lane, low 32 bits)
+    %1:vgpr_32 = V_MOV_B32_e32 %0.sub0:vreg_128, implicit $exec
+    S_BRANCH %bb.5
+
+  bb.4:
+    successors: %bb.5
+    ; This is where we'll INSERT: %0.sub2_3 = new_def (64-bit, covers sub2+sub3)
+    ; After insertion, use sub3 (high 32 bits of sub2_3)
+    %2:vgpr_32 = V_MOV_B32_e32 %0.sub3:vreg_128, implicit $exec
+    S_BRANCH %bb.5
+
+  bb.5:
+    successors: %bb.6
+    ; Diamond 1 join - PHI expected for sub2_3 lanes
+    ; Use sub1 (unchanged lane, bits 32-63)
+    %3:vgpr_32 = V_MOV_B32_e32 %0.sub1:vreg_128, implicit $exec
+    S_BRANCH %bb.6
+
+  bb.6:
+    successors: %bb.7, %bb.8
+    ; Diamond 2 split
+    $sgpr2 = S_MOV_B32 0
+    $sgpr3 = S_MOV_B32 1
+    S_CMP_LG_U32 $sgpr2, $sgpr3, implicit-def $scc
+    S_CBRANCH_SCC1 %bb.8, implicit $scc
+
+  bb.7:
+    successors: %bb.9
+    ; Use sub2 (changed lane, bits 64-95)
+    dead %4:vgpr_32 = V_MOV_B32_e32 %0.sub2:vreg_128, implicit $exec
+    S_BRANCH %bb.9
+
+  bb.8:
+    successors: %bb.9
+    ; No use - sparse use pattern
+    S_NOP 0
+
+  bb.9:
+    successors: %bb.10
+    ; Diamond 2 join - no PHI needed (BB5 dominates)
+    S_NOP 0
+
+  bb.10:
+    ; Exit - use sub0 again (unchanged lane)
+    dead %5:vgpr_32 = V_MOV_B32_e32 %0.sub0:vreg_128, implicit $exec
+    S_ENDPGM 0
+...
+)MIR")).toNullTerminatedStringRef(S);
+
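+  // Aside: on AMDGPU a 32-bit subregister of a vreg_128 covers two 16-bit
+  // lanes, so the masks the search below relies on are expected to be
+  // (illustrative; the test recomputes them through TRI rather than
+  // hard-coding them):
+  //   sub0 = 0x03, sub1 = 0x0C, sub2 = 0x30, sub3 = 0xC0, sub2_3 = 0xF0
+  constexpr LaneBitmask::Type ExpectedSub2_3Mask = 0xF0;
+  (void)ExpectedSub2_3Mask;
+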
+)MIR")).toNullTerminatedStringRef(S); + + doTest(MIRString, + [](MachineFunction &MF, LiveIntervalsWrapperPass &LISWrapper) { + LiveIntervals &LIS = LISWrapper.getLIS(); + MachineDominatorTree MDT(MF); + LLVM_DEBUG(dbgs() << "\n=== MultipleSubregUsesAcrossDiamonds Test ===\n"); + + // Get basic blocks + auto BBI = MF.begin(); + ++BBI; // Skip BB0 (entry) + ++BBI; // Skip BB1 (Initial def) + ++BBI; // Skip BB2 (Diamond1 split) + MachineBasicBlock *BB3 = &*BBI++; // Diamond1 true (no redef) + MachineBasicBlock *BB4 = &*BBI++; // Diamond1 false (INSERT HERE) + MachineBasicBlock *BB5 = &*BBI++; // Diamond1 join + + const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo(); + const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo(); + MachineRegisterInfo &MRI = MF.getRegInfo(); + (void)MRI; // May be unused, suppress warning + + // Find the 128-bit register %0 + Register OrigReg = Register::index2VirtReg(0); + ASSERT_TRUE(OrigReg.isValid()) << "Register %0 should be valid"; + + LLVM_DEBUG(dbgs() << "Using 128-bit register: %" << OrigReg.virtRegIndex() << "\n"); + + // Find sub2_3 subregister index (64-bit covering bits 64-127) + unsigned Sub2_3Idx = 0; + for (unsigned Idx = 1; Idx < TRI->getNumSubRegIndices(); ++Idx) { + unsigned SubRegSize = TRI->getSubRegIdxSize(Idx); + LaneBitmask Mask = TRI->getSubRegIndexLaneMask(Idx); + + // Looking for 64-bit subreg covering upper half (lanes for sub2+sub3) + // sub2_3 should have mask 0xF0 (lanes for bits 64-127) + if (SubRegSize == 64 && (Mask.getAsInteger() & 0xF0) == 0xF0) { + Sub2_3Idx = Idx; + LLVM_DEBUG(dbgs() << "Found sub2_3 index: " << Idx + << " (size=" << SubRegSize + << ", mask=0x" << llvm::format("%X", Mask.getAsInteger()) << ")\n"); + break; + } + } + + ASSERT_NE(Sub2_3Idx, 0u) << "Should find sub2_3 subregister index"; + + // Insert new definition in BB4: %0.sub2_3 = IMPLICIT_DEF + // Find insertion point (before the use of sub3) + MachineInstr *UseOfSub3 = nullptr; + + for (MachineInstr &MI : *BB4) { + if (MI.getNumOperands() >= 2 && MI.getOperand(0).isReg() && + MI.getOperand(1).isReg() && MI.getOperand(1).getReg() == OrigReg) { + UseOfSub3 = &MI; + break; + } + } + ASSERT_NE(UseOfSub3, nullptr) << "Should find use of sub3 in BB4"; + + // Create new def: %0.sub2_3 = IMPLICIT_DEF + // We use IMPLICIT_DEF because it works for any register size and the SSA updater + // doesn't care about the specific instruction semantics - we're just testing SSA repair + MachineInstrBuilder MIB = BuildMI(*BB4, UseOfSub3, DebugLoc(), + TII->get(TargetOpcode::IMPLICIT_DEF)) + .addDef(OrigReg, RegState::Define, Sub2_3Idx); + + MachineInstr *NewDefMI = MIB.getInstr(); + LLVM_DEBUG(dbgs() << "Inserted new def in BB4: "); + LLVM_DEBUG(NewDefMI->print(dbgs())); + + // Index the new instruction + LIS.InsertMachineInstrInMaps(*NewDefMI); + + // Set MachineFunction properties to allow PHI insertion + MF.getProperties().set(MachineFunctionProperties::Property::IsSSA); + MF.getProperties().reset(MachineFunctionProperties::Property::NoPHIs); + + // Create SSA updater and repair + MachineLaneSSAUpdater Updater(MF, LIS, MDT, *TRI); + Register NewReg = Updater.repairSSAForNewDef(*NewDefMI, OrigReg); + + LLVM_DEBUG(dbgs() << "SSA repair created new register: %" << NewReg.virtRegIndex() << "\n"); + + // Print final state of key blocks + LLVM_DEBUG(dbgs() << "\nFinal BB5 (diamond1 join):\n"); + for (MachineInstr &MI : *BB5) { + LLVM_DEBUG(MI.print(dbgs())); + } + + // Verify SSA repair results + + // 1. 
+    // Verify SSA repair results
+
+    // 1. Should have a PHI in BB5 for the sub2+sub3 lanes
+    bool FoundPHI = false;
+    for (MachineInstr &MI : *BB5) {
+      if (MI.isPHI()) {
+        Register PHIResult = MI.getOperand(0).getReg();
+        if (PHIResult.isVirtual()) {
+          LLVM_DEBUG(dbgs() << "Found PHI in BB5: ");
+          LLVM_DEBUG(MI.print(dbgs()));
+
+          // Check that it has 2 incoming values
+          unsigned NumIncoming = (MI.getNumOperands() - 1) / 2;
+          EXPECT_EQ(NumIncoming, 2u) << "PHI should have 2 incoming values";
+
+          // Check that one incoming is the new register from BB4 and that
+          // the other incoming, from BB3, uses %0.sub2_3.
+          bool HasNewRegFromBB4 = false;
+          bool HasCorrectSubregFromBB3 = false;
+          for (unsigned i = 1; i < MI.getNumOperands(); i += 2) {
+            Register IncomingReg = MI.getOperand(i).getReg();
+            unsigned IncomingSubReg = MI.getOperand(i).getSubReg();
+            MachineBasicBlock *IncomingMBB = MI.getOperand(i + 1).getMBB();
+
+            if (IncomingMBB == BB4) {
+              HasNewRegFromBB4 = (IncomingReg == NewReg);
+              LLVM_DEBUG(dbgs() << "  Incoming from BB4: %"
+                                << IncomingReg.virtRegIndex() << "\n");
+            } else if (IncomingMBB == BB3) {
+              // Should be %0.sub2_3 (the lanes we redefined)
+              LLVM_DEBUG(dbgs() << "  Incoming from BB3: %"
+                                << IncomingReg.virtRegIndex());
+              if (IncomingSubReg) {
+                LLVM_DEBUG(dbgs() << "."
+                                  << TRI->getSubRegIndexName(IncomingSubReg));
+              }
+              LLVM_DEBUG(dbgs() << "\n");
+
+              // Verify it's using sub2_3
+              if (IncomingReg == OrigReg && IncomingSubReg == Sub2_3Idx) {
+                HasCorrectSubregFromBB3 = true;
+              }
+            }
+          }
+          EXPECT_TRUE(HasNewRegFromBB4) << "PHI should use NewReg from BB4";
+          EXPECT_TRUE(HasCorrectSubregFromBB3)
+              << "PHI should use %0.sub2_3 from BB3";
+          FoundPHI = true;
+        }
+      }
+    }
+    EXPECT_TRUE(FoundPHI) << "Should find PHI in BB5 for sub2_3 lanes";
+
+    // 2. Verify LiveIntervals
+    EXPECT_TRUE(LIS.hasInterval(NewReg));
+    EXPECT_TRUE(LIS.hasInterval(OrigReg));
+
+    // 3. Verify that OrigReg's LiveInterval has subranges for the changed
+    // lanes
+    LiveInterval &OrigLI = LIS.getInterval(OrigReg);
+    EXPECT_TRUE(OrigLI.hasSubRanges())
+        << "OrigReg should have subranges after partial redef";
+
+    // Verify the MachineFunction is still valid
+    EXPECT_TRUE(MF.verify(nullptr, nullptr, nullptr, false))
+        << "MachineFunction verification failed";
+  });
+}
+
+// Test 10: Non-contiguous lane mask - redefine sub1 of a 128-bit register,
+// then use the full register. This specifically tests the multi-source
+// REG_SEQUENCE code path for non-contiguous lanes.
+//
+// CFG Structure (matching the MIR below):
+//        bb.0 (entry)
+//         |
+//         v
+//        bb.1 (%0:vreg_128 = IMPLICIT_DEF, diamond split)
+//        /  \
+//       v    v
+//    bb.2    bb.3 (INSERT: %0.sub1 = IMPLICIT_DEF - redefine middle lane!)
+//        \  /
+//         v
+//        bb.4 (diamond join - USE %0 as full register, exit)
+//
+// Key Property: Redefining sub1 leaves LanesFromOld = sub0 + sub2 + sub3 (non-contiguous!)
+// This requires getCoveringSubRegsForLaneMask to decompose into multiple subregs
+// Expected REG_SEQUENCE: %RS = REG_SEQUENCE %6, sub1, %0.sub0, sub0, %0.sub2_3, sub2_3
+//
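+// Aside: "non-contiguous" here means no single subreg index covers the
+// remaining lanes. A greedy sketch of the decomposition the updater must
+// perform (illustrative only; the real getCoveringSubRegsForLaneMask may
+// prefer larger composite indices such as sub2_3 over many small pieces):
+LLVM_ATTRIBUTE_UNUSED
+static SmallVector<unsigned, 4>
+coveringSubRegsSketch(const TargetRegisterInfo &TRI, LaneBitmask Mask) {
+  SmallVector<unsigned, 4> Parts;
+  for (unsigned Idx = 1; Idx < TRI.getNumSubRegIndices() && Mask.any(); ++Idx) {
+    LaneBitmask M = TRI.getSubRegIndexLaneMask(Idx);
+    // Greedily take any index whose lanes lie entirely inside the remainder.
+    if (M.any() && (M & ~Mask).none()) {
+      Parts.push_back(Idx);
+      Mask &= ~M;
+    }
+  }
+  return Parts;
+}
+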
+TEST(MachineLaneSSAUpdaterTest, NonContiguousLaneMaskREGSEQUENCE) {
+  SmallString<4096> S;
+  StringRef MIRString = (Twine(R"MIR(
+--- |
+  define amdgpu_kernel void @func() { ret void }
+...
+---
+name: func
+tracksRegLiveness: true
+registers:
+  - { id: 0, class: vreg_128 }
+  - { id: 1, class: vreg_128 }
+body: |
+  bb.0:
+    successors: %bb.1
+    S_BRANCH %bb.1
+
+  bb.1:
+    successors: %bb.2, %bb.3
+    %0:vreg_128 = IMPLICIT_DEF
+    $sgpr0 = S_MOV_B32 0
+    $sgpr1 = S_MOV_B32 1
+    S_CMP_LG_U32 $sgpr0, $sgpr1, implicit-def $scc
+    S_CBRANCH_SCC1 %bb.3, implicit $scc
+
+  bb.2:
+    successors: %bb.4
+    ; Left path - no redefinition
+    S_NOP 0
+    S_BRANCH %bb.4
+
+  bb.3:
+    successors: %bb.4
+    ; Right path - THIS IS WHERE WE'LL INSERT: %0.sub1 = IMPLICIT_DEF
+    S_NOP 0
+    S_BRANCH %bb.4
+
+  bb.4:
+    ; Diamond join - use FULL register (this will need REG_SEQUENCE!)
+    ; Using full %0 (not a subreg) forces composition of non-contiguous lanes
+    dead %1:vreg_128 = COPY %0:vreg_128
+    S_ENDPGM 0
+...
+)MIR")).toNullTerminatedStringRef(S);
+
+  doTest(MIRString,
+         [](MachineFunction &MF, LiveIntervalsWrapperPass &LISWrapper) {
+    LiveIntervals &LIS = LISWrapper.getLIS();
+    MachineDominatorTree MDT(MF);
+    LLVM_DEBUG(dbgs() << "\n=== NonContiguousLaneMaskREGSEQUENCE Test ===\n");
+
+    const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
+    const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
+    MachineRegisterInfo &MRI = MF.getRegInfo();
+    (void)MRI; // May be unused; suppress the warning.
+
+    // Find the blocks:
+    //   bb.0 = entry
+    //   bb.1 = IMPLICIT_DEF + diamond split
+    //   bb.2 = left path (no redef)
+    //   bb.3 = right path (INSERT sub1 def here)
+    //   bb.4 = diamond join (use full register)
+    MachineBasicBlock *BB3 = MF.getBlockNumbered(3); // Right path - insert here
+    MachineBasicBlock *BB4 = MF.getBlockNumbered(4); // Join - needs REG_SEQUENCE
+
+    // Find %0 (the vreg_128)
+    Register OrigReg = Register::index2VirtReg(0);
+    ASSERT_TRUE(OrigReg.isValid()) << "Register %0 should be valid";
+    LLVM_DEBUG(dbgs() << "Using 128-bit register: %" << OrigReg.virtRegIndex()
+                      << "\n");
+
+    // Find the sub1 subregister index
+    unsigned Sub1Idx = 0;
+    for (unsigned Idx = 1; Idx < TRI->getNumSubRegIndices(); ++Idx) {
+      StringRef Name = TRI->getSubRegIndexName(Idx);
+      if (Name == "sub1") {
+        Sub1Idx = Idx;
+        break;
+      }
+    }
+
+    ASSERT_NE(Sub1Idx, 0u) << "Should find sub1 subregister index";
+
+    // Insert new definition in BB3 (right path): %0.sub1 = IMPLICIT_DEF
+    MachineInstrBuilder MIB = BuildMI(*BB3, BB3->getFirstNonPHI(), DebugLoc(),
+                                      TII->get(TargetOpcode::IMPLICIT_DEF))
+                                  .addDef(OrigReg, RegState::Define, Sub1Idx);
+
+    MachineInstr *NewDefMI = MIB.getInstr();
+    LLVM_DEBUG(dbgs() << "Inserted new def in BB3: ");
+    LLVM_DEBUG(NewDefMI->print(dbgs()));
+
+    // Index the new instruction
+    LIS.InsertMachineInstrInMaps(*NewDefMI);
+
+    // Set MachineFunction properties to allow PHI insertion
+    MF.getProperties().set(MachineFunctionProperties::Property::IsSSA);
+    MF.getProperties().reset(MachineFunctionProperties::Property::NoPHIs);
+
+    // Create SSA updater and repair
+    MachineLaneSSAUpdater Updater(MF, LIS, MDT, *TRI);
+    Register NewReg = Updater.repairSSAForNewDef(*NewDefMI, OrigReg);
+
+    LLVM_DEBUG(dbgs() << "SSA repair created new register: %"
+                      << NewReg.virtRegIndex() << "\n");
+
+    // Print final state
+    LLVM_DEBUG(dbgs() << "\nFinal BB4 (diamond join):\n");
+    for (MachineInstr &MI : *BB4) {
+      LLVM_DEBUG(MI.print(dbgs()));
+    }
+
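+    // Aside: the remainder that must still come from the old register,
+    // computed explicitly for the log (sub0 + sub2 + sub3, non-contiguous;
+    // uses only values already in scope):
+    LLVM_DEBUG(dbgs() << "LanesFromOld = "
+                      << PrintLaneMask(MRI.getMaxLaneMaskForVReg(OrigReg) &
+                                       ~TRI->getSubRegIndexLaneMask(Sub1Idx))
+                      << "\n");
+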
+    // Verify SSA repair results
+
+    // 1. Should have a PHI in BB4 for the sub1 lane
+    bool FoundPHI = false;
+    Register PHIReg;
+    for (MachineInstr &MI : *BB4) {
+      if (MI.isPHI()) {
+        PHIReg = MI.getOperand(0).getReg();
+        if (PHIReg.isVirtual()) {
+          LLVM_DEBUG(dbgs() << "Found PHI in BB4: ");
+          LLVM_DEBUG(MI.print(dbgs()));
+          FoundPHI = true;
+
+          // Check that it has 2 incoming values
+          unsigned NumIncoming = (MI.getNumOperands() - 1) / 2;
+          EXPECT_EQ(NumIncoming, 2u) << "PHI should have 2 incoming values";
+
+          // One incoming should be the new register (a vgpr_32 from BB3)
+          bool HasNewRegFromBB3 = false;
+          for (unsigned i = 1; i < MI.getNumOperands(); i += 2) {
+            if (MI.getOperand(i).isReg() &&
+                MI.getOperand(i).getReg() == NewReg) {
+              EXPECT_EQ(MI.getOperand(i + 1).getMBB(), BB3)
+                  << "NewReg should come from BB3";
+              HasNewRegFromBB3 = true;
+            }
+          }
+          EXPECT_TRUE(HasNewRegFromBB3) << "PHI should have NewReg from BB3";
+
+          break;
+        }
+      }
+    }
+
+    EXPECT_TRUE(FoundPHI) << "Should create PHI in BB4 for sub1 lane";
+
+    // 2. Most importantly: there should be a REG_SEQUENCE with MULTIPLE
+    // sources for the non-contiguous lanes. After the PHI for sub1, the full
+    // register must be composed from LanesFromOld = sub0 + sub2 + sub3
+    // (non-contiguous!), which requires multiple REG_SEQUENCE operands.
+    bool FoundREGSEQUENCE = false;
+    unsigned NumREGSEQSources = 0;
+
+    for (MachineInstr &MI : *BB4) {
+      if (MI.getOpcode() == TargetOpcode::REG_SEQUENCE) {
+        LLVM_DEBUG(dbgs() << "Found REG_SEQUENCE: ");
+        LLVM_DEBUG(MI.print(dbgs()));
+        FoundREGSEQUENCE = true;
+
+        // Count the sources (each source is a register + subreg-index pair)
+        NumREGSEQSources = (MI.getNumOperands() - 1) / 2;
+        LLVM_DEBUG(dbgs() << "  REG_SEQUENCE has " << NumREGSEQSources
+                          << " sources\n");
+
+        // We expect at least 2 sources for the non-contiguous case:
+        //   1. The PHI result covering sub1
+        //   2. One or more sources from OrigReg covering sub0, sub2, sub3
+        EXPECT_GE(NumREGSEQSources, 2u)
+            << "REG_SEQUENCE should have multiple sources for non-contiguous "
+               "lanes";
+
+        // Verify that at least one source is the PHI result
+        bool HasPHISource = false;
+        for (unsigned i = 1; i < MI.getNumOperands(); i += 2) {
+          if (MI.getOperand(i).isReg() &&
+              MI.getOperand(i).getReg() == PHIReg) {
+            HasPHISource = true;
+            break;
+          }
+        }
+        EXPECT_TRUE(HasPHISource) << "REG_SEQUENCE should use PHI result";
+
+        break;
+      }
+    }
+
+    EXPECT_TRUE(FoundREGSEQUENCE)
+        << "Should create REG_SEQUENCE to compose the full register from "
+           "non-contiguous lanes";
+
+    // 3. The COPY use should now reference the REG_SEQUENCE result (not %0)
+    bool FoundRewrittenUse = false;
+    for (MachineInstr &MI : *BB4) {
+      if (MI.getOpcode() == TargetOpcode::COPY) {
+        MachineOperand &SrcOp = MI.getOperand(1);
+        if (SrcOp.isReg() && SrcOp.getReg().isVirtual() &&
+            SrcOp.getReg() != OrigReg) {
+          LLVM_DEBUG(dbgs() << "Found rewritten COPY: ");
+          LLVM_DEBUG(MI.print(dbgs()));
+          FoundRewrittenUse = true;
+          break;
+        }
+      }
+    }
+
+    EXPECT_TRUE(FoundRewrittenUse)
+        << "COPY should be rewritten to use the REG_SEQUENCE result";
+
+    // Print summary
+    LLVM_DEBUG(dbgs() << "\n=== Test Summary ===\n");
+    LLVM_DEBUG(dbgs() << "✓ Redefined sub1 (middle lane) of vreg_128\n");
+    LLVM_DEBUG(dbgs() << "✓ Created PHI for sub1 lane\n");
+    LLVM_DEBUG(dbgs() << "✓ Created REG_SEQUENCE with " << NumREGSEQSources
+                      << " sources to handle non-contiguous lanes "
+                         "(sub0 + sub2 + sub3)\n");
+    LLVM_DEBUG(dbgs() << "✓ This test exercises "
+                         "getCoveringSubRegsForLaneMask!\n");
+  });
+}
+
+} // anonymous namespace