Skip to content
2 changes: 2 additions & 0 deletions llvm/include/llvm/CodeGen/LiveIntervals.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ class VirtRegMap;
LiveIntervals();
~LiveIntervals() override;

const TargetInstrInfo &getTargetInstrInfo() const { return *TII; }

/// Calculate the spill weight to assign to a single instruction.
static float getSpillWeight(bool isDef, bool isUse,
const MachineBlockFrequencyInfo *MBFI,
Expand Down
2 changes: 2 additions & 0 deletions llvm/include/llvm/CodeGen/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,8 @@ namespace llvm {
///
FunctionPass *createGreedyRegisterAllocator();
FunctionPass *createGreedyRegisterAllocator(RegClassFilterFunc F);
FunctionPass *createGreedyRegisterAllocator(RegClassFilterFunc F,
LiveIntervalFilterFunc LIF);

/// PBQPRegisterAllocation Pass - This pass implements the Partitioned Boolean
/// Quadratic Prograaming (PBQP) based register allocator.
Expand Down
15 changes: 15 additions & 0 deletions llvm/include/llvm/CodeGen/RegAllocCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ namespace llvm {
class TargetRegisterClass;
class TargetRegisterInfo;

class MachineRegisterInfo;
class TargetInstrInfo;
class LiveInterval;

typedef std::function<bool(const TargetRegisterInfo &TRI,
const TargetRegisterClass &RC)> RegClassFilterFunc;

Expand All @@ -26,6 +30,17 @@ static inline bool allocateAllRegClasses(const TargetRegisterInfo &,
return true;
}

typedef std::function<bool(MachineRegisterInfo &MRI, const TargetInstrInfo &TII,
const LiveInterval *LI)>
LiveIntervalFilterFunc;
/// Default live interval filter function for register allocation. All live
/// intervals should be allocated.
static inline bool allocateAllLiveIntervals(MachineRegisterInfo &,
const TargetInstrInfo &,
const LiveInterval *) {
return true;
}

} // namespace llvm

#endif // LLVM_CODEGEN_REGALLOCCOMMON_H
9 changes: 7 additions & 2 deletions llvm/lib/CodeGen/RegAllocBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,13 @@ void RegAllocBase::enqueue(const LiveInterval *LI) {

const TargetRegisterClass &RC = *MRI->getRegClass(Reg);
if (ShouldAllocateClass(*TRI, RC)) {
LLVM_DEBUG(dbgs() << "Enqueuing " << printReg(Reg, TRI) << '\n');
enqueueImpl(LI);
if (ShouldAllocateLiveInterval(*MRI, LIS->getTargetInstrInfo(), LI)) {
LLVM_DEBUG(dbgs() << "Enqueuing " << printReg(Reg, TRI) << '\n');
enqueueImpl(LI);
} else {
LLVM_DEBUG(dbgs() << "Not enqueueing " << printReg(Reg, TRI)
<< " in skipped live interval\n");
}
} else {
LLVM_DEBUG(dbgs() << "Not enqueueing " << printReg(Reg, TRI)
<< " in skipped register class\n");
Expand Down
6 changes: 4 additions & 2 deletions llvm/lib/CodeGen/RegAllocBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,17 @@ class RegAllocBase {
LiveRegMatrix *Matrix = nullptr;
RegisterClassInfo RegClassInfo;
const RegClassFilterFunc ShouldAllocateClass;
const LiveIntervalFilterFunc ShouldAllocateLiveInterval;

/// Inst which is a def of an original reg and whose defs are already all
/// dead after remat is saved in DeadRemats. The deletion of such inst is
/// postponed till all the allocations are done, so its remat expr is
/// always available for the remat of all the siblings of the original reg.
SmallPtrSet<MachineInstr *, 32> DeadRemats;

RegAllocBase(const RegClassFilterFunc F = allocateAllRegClasses) :
ShouldAllocateClass(F) {}
RegAllocBase(const RegClassFilterFunc F = allocateAllRegClasses,
const LiveIntervalFilterFunc LIF = allocateAllLiveIntervals)
: ShouldAllocateClass(F), ShouldAllocateLiveInterval(LIF) {}

virtual ~RegAllocBase() = default;

Expand Down
10 changes: 7 additions & 3 deletions llvm/lib/CodeGen/RegAllocGreedy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,11 +205,15 @@ FunctionPass *llvm::createGreedyRegisterAllocator(RegClassFilterFunc Ftor) {
return new RAGreedy(Ftor);
}

RAGreedy::RAGreedy(RegClassFilterFunc F):
MachineFunctionPass(ID),
RegAllocBase(F) {
FunctionPass *
llvm::createGreedyRegisterAllocator(RegClassFilterFunc Ftor,
LiveIntervalFilterFunc LIFtor) {
return new RAGreedy(Ftor, LIFtor);
}

RAGreedy::RAGreedy(RegClassFilterFunc F, LiveIntervalFilterFunc LIF)
: MachineFunctionPass(ID), RegAllocBase(F, LIF) {}

void RAGreedy::getAnalysisUsage(AnalysisUsage &AU) const {
AU.setPreservesCFG();
AU.addRequired<MachineBlockFrequencyInfo>();
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/CodeGen/RegAllocGreedy.h
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,8 @@ class LLVM_LIBRARY_VISIBILITY RAGreedy : public MachineFunctionPass,
bool ReverseLocalAssignment = false;

public:
RAGreedy(const RegClassFilterFunc F = allocateAllRegClasses);
RAGreedy(const RegClassFilterFunc F = allocateAllRegClasses,
const LiveIntervalFilterFunc LIF = allocateAllLiveIntervals);

/// Return the pass name.
StringRef getPassName() const override { return "Greedy Register Allocator"; }
Expand Down
3 changes: 0 additions & 3 deletions llvm/lib/Target/AIE/AIE2InstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -835,7 +835,6 @@ AIE2InstrInfo::getSpillPseudoExpandInfo(const MachineInstr &MI) const {
{AIE2::LDA_dms_spill, AIE2::sub_dim_size},
{AIE2::LDA_dms_spill, AIE2::sub_dim_stride},
{AIE2::LDA_dms_spill, AIE2::sub_dim_count},
{AIE2::LDA_dms_spill, AIE2::sub_hi_dim_then_sub_mod},
{AIE2::LDA_dms_spill, AIE2::sub_hi_dim_then_sub_dim_size},
{AIE2::LDA_dms_spill, AIE2::sub_hi_dim_then_sub_dim_stride},
{AIE2::LDA_dms_spill, AIE2::sub_hi_dim_then_sub_dim_count}};
Expand All @@ -844,7 +843,6 @@ AIE2InstrInfo::getSpillPseudoExpandInfo(const MachineInstr &MI) const {
{AIE2::ST_dms_spill, AIE2::sub_dim_size},
{AIE2::ST_dms_spill, AIE2::sub_dim_stride},
{AIE2::ST_dms_spill, AIE2::sub_dim_count},
{AIE2::ST_dms_spill, AIE2::sub_hi_dim_then_sub_mod},
{AIE2::ST_dms_spill, AIE2::sub_hi_dim_then_sub_dim_size},
{AIE2::ST_dms_spill, AIE2::sub_hi_dim_then_sub_dim_stride},
{AIE2::ST_dms_spill, AIE2::sub_hi_dim_then_sub_dim_count}};
Expand Down Expand Up @@ -1205,7 +1203,6 @@ AIE2InstrInfo::getTiedRegInfo(unsigned Opcode) const {
SubRegSplit(AIE2::sub_dim_size),
SubRegSplit(AIE2::sub_dim_stride),
SubRegSplit(AIE2::sub_dim_count),
SubRegSplit(AIE2::sub_hi_dim_then_sub_mod, /*IsUndef=*/true),
SubRegSplit(AIE2::sub_hi_dim_then_sub_dim_size),
SubRegSplit(AIE2::sub_hi_dim_then_sub_dim_stride),
SubRegSplit(AIE2::sub_hi_dim_then_sub_dim_count)};
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/AIE/AIE2InstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ foreach instr = [VST_2D_SRS_D8_S32, VST_2D_SRS_D16_S64, VST_2D_SRS_D16_S32,
// Define _split variants for instructions using 3D registers
class Split3DInstr<Instruction RealInst, int opidx> : SplitPseudo<RealInst,
opidx, (ins eM:$mod1, eDN:$dim_size1, eDJ:$dim_stride1, eDC:$dim_count1,
eM:$mod2, eDN:$dim_size2, eDJ:$dim_stride2, eDC:$dim_count2)> {}
eDN:$dim_size2, eDJ:$dim_stride2, eDC:$dim_count2)> {}
foreach instr = [VLDA_3D_dmw_lda_w, VLDA_3D_dmw_lda_am, VLDA_3D_CONV_FP32_BF16,
VLDB_3D, VLDB_3D_128, LDA_3D_dmv_lda_q, VLDB_3D_UNPACK_S8_S4,
VLDB_3D_UNPACK_S16_S8, VLDB_3D_UNPACK_D8_D4, VLDB_3D_UNPACK_D16_D8,
Expand Down
1 change: 0 additions & 1 deletion llvm/lib/Target/AIE/AIE2RegisterInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,6 @@ const std::set<int> &AIE2RegisterInfo::getSubRegSplit(int RegClassId) const {
AIE2::sub_dim_size,
AIE2::sub_dim_stride,
AIE2::sub_dim_count,
AIE2::sub_hi_dim_then_sub_mod,
AIE2::sub_hi_dim_then_sub_dim_size,
AIE2::sub_hi_dim_then_sub_dim_stride,
AIE2::sub_hi_dim_then_sub_dim_count};
Expand Down
30 changes: 22 additions & 8 deletions llvm/lib/Target/AIE/AIEBaseInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -649,15 +649,29 @@ void AIEBaseInstrInfo::copyThroughSubRegs(MachineBasicBlock &MBB,
MCRegister SrcReg,
bool KillSrc) const {
MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
const TargetRegisterInfo &TRI = *MRI.getTargetRegisterInfo();

SmallSet<MCRegister, 8> SrcSubRegs;
collectSubRegs(SrcReg, SrcSubRegs, TRI);
auto &TRI =
*static_cast<const AIEBaseRegisterInfo *>(MRI.getTargetRegisterInfo());

const auto *RC = Register::isPhysicalRegister(SrcReg.id())
? TRI.getMinimalPhysRegClass(SrcReg)
: MRI.getRegClass(SrcReg);
auto &SubRegSplit = TRI.getSubRegSplit(RC->getID());

if (SubRegSplit.size() > 1) {
for (const auto &SubRegIdx : SubRegSplit) {
MCRegister SrcSubReg = TRI.getSubReg(SrcReg, SubRegIdx);
MCRegister DstSubReg = TRI.getSubReg(DstReg, SubRegIdx);
copyPhysReg(MBB, MBBI, DL, DstSubReg, SrcSubReg, KillSrc);
}
} else {
SmallSet<MCRegister, 8> SrcSubRegs;
collectSubRegs(SrcReg, SrcSubRegs, TRI);

for (MCRegister SrcSubReg : SrcSubRegs) {
unsigned SubRegIdx = TRI.getSubRegIndex(SrcReg, SrcSubReg);
MCRegister DstSubReg = TRI.getSubReg(DstReg, SubRegIdx);
copyPhysReg(MBB, MBBI, DL, DstSubReg, SrcSubReg, KillSrc);
for (MCRegister SrcSubReg : SrcSubRegs) {
unsigned SubRegIdx = TRI.getSubRegIndex(SrcReg, SrcSubReg);
MCRegister DstSubReg = TRI.getSubReg(DstReg, SubRegIdx);
copyPhysReg(MBB, MBBI, DL, DstSubReg, SrcSubReg, KillSrc);
}
}
}

Expand Down
Loading
Loading