Commit d813716

[AArch64][MachineCombiner] Combine sequences of gather patterns
1 parent 76d2f0f commit d813716

11 files changed: +1168 −346 lines

llvm/lib/Target/AArch64/AArch64InstrInfo.cpp

Lines changed: 356 additions & 0 deletions
@@ -20,6 +20,7 @@
 #include "Utils/AArch64BaseInfo.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/CodeGen/CFIInstBuilder.h"
 #include "llvm/CodeGen/LivePhysRegs.h"
@@ -7412,11 +7413,347 @@ static bool getMiscPatterns(MachineInstr &Root,
   return false;
 }
 
+/// Check if there are any stores or calls between two instructions in the same
+/// basic block.
+static bool hasInterveningStoreOrCall(const MachineInstr *First,
+                                      const MachineInstr *Last) {
+  if (!First || !Last || First == Last)
+    return false;
+
+  // Both instructions must be in the same basic block.
+  if (First->getParent() != Last->getParent())
+    return false;
+
+  // Sanity check that First comes before Last.
+  const MachineBasicBlock *MBB = First->getParent();
+  auto InstrIt = First->getIterator();
+  auto LastIt = Last->getIterator();
+
+  for (; InstrIt != MBB->end(); ++InstrIt) {
+    if (InstrIt == LastIt)
+      break;
+
+    // Check for stores or calls that could interfere.
+    if (InstrIt->mayStore() || InstrIt->isCall())
+      return true;
+  }
+
+  // If we reached the end of the basic block, our instructions must not have
+  // been ordered correctly and the analysis is invalid.
+  assert(InstrIt != MBB->end() &&
+         "Got bad machine instructions, First should come before Last!");
+  return false;
+}
+
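A hedged illustration of why this check exists (hypothetical source, not part of this commit): if the lane-load chain has a store between its first and last load, splitting the chain would in effect move later lane loads across that store, which is only safe if the store cannot alias their addresses. Rather than reason about aliasing, the pattern is rejected. Whether Clang lowers this exact C++ to an LD1 lane chain depends on the front end and instruction selection.

#include <arm_neon.h>

// Hypothetical example: lanes 0-1 are loaded, then a possibly-aliasing store,
// then lanes 2-3. hasInterveningStoreOrCall() sees the store between the
// first and last load of the chain, so the pattern match gives up.
float32x4_t gather_with_store(const float *a, const float *b, const float *c,
                              const float *d, float *out) {
  float32x4_t v = vdupq_n_f32(0.0f);
  v = vsetq_lane_f32(*a, v, 0);
  v = vsetq_lane_f32(*b, v, 1);
  *out = 1.0f; // intervening store: blocks the machine-combiner rewrite
  v = vsetq_lane_f32(*c, v, 2);
  v = vsetq_lane_f32(*d, v, 3);
  return v;
}
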
+/// Check if the given instruction forms a gather load pattern that can be
+/// optimized for better Memory-Level Parallelism (MLP). This function
+/// identifies chains of NEON lane load instructions that load data from
+/// different memory addresses into individual lanes of a 128-bit vector
+/// register, then attempts to split the pattern into parallel loads to break
+/// the serial dependency between instructions.
+///
+/// Pattern Matched:
+///   Initial scalar load -> SUBREG_TO_REG (lane 0) -> LD1i* (lane 1) ->
+///   LD1i* (lane 2) -> ... -> LD1i* (lane N-1, Root)
+///
+/// Transformed Into:
+///   Two parallel vector loads using fewer lanes each, followed by ZIP1v2i64
+///   to combine the results, enabling better memory-level parallelism.
+///
+/// Supported Element Types:
+///   - 32-bit elements (LD1i32, 4 lanes total)
+///   - 16-bit elements (LD1i16, 8 lanes total)
+///   - 8-bit elements (LD1i8, 16 lanes total)
+static bool getGatherPattern(MachineInstr &Root,
+                             SmallVectorImpl<unsigned> &Patterns,
+                             unsigned LoadLaneOpCode, unsigned NumLanes) {
+  const MachineFunction *MF = Root.getMF();
+
+  // Early exit if optimizing for size.
+  if (MF->getFunction().hasMinSize())
+    return false;
+
+  const MachineRegisterInfo &MRI = MF->getRegInfo();
+  const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();
+
+  // The root of the pattern must load into the last lane of the vector.
+  if (Root.getOperand(2).getImm() != NumLanes - 1)
+    return false;
+
+  // Check that we have a load into all lanes except lane 0.
+  // For each load we also want to check that:
+  // 1. It has a single non-debug use (since we will be replacing the virtual
+  //    register)
+  // 2. That the addressing mode only uses a single pointer operand
+  auto *CurrInstr = MRI.getUniqueVRegDef(Root.getOperand(1).getReg());
+  auto Range = llvm::seq<unsigned>(1, NumLanes - 1);
+  SmallSet<unsigned, 16> RemainingLanes(Range.begin(), Range.end());
+  SmallSet<const MachineInstr *, 16> LoadInstrs = {};
+  while (!RemainingLanes.empty() && CurrInstr &&
+         CurrInstr->getOpcode() == LoadLaneOpCode &&
+         MRI.hasOneNonDBGUse(CurrInstr->getOperand(0).getReg()) &&
+         CurrInstr->getNumOperands() == 4) {
+    RemainingLanes.erase(CurrInstr->getOperand(2).getImm());
+    LoadInstrs.insert(CurrInstr);
+    CurrInstr = MRI.getUniqueVRegDef(CurrInstr->getOperand(1).getReg());
+  }
+
+  // Check that we have found a match for lanes N-1 .. 1.
+  if (!RemainingLanes.empty())
+    return false;
+
+  // Match the SUBREG_TO_REG sequence.
+  if (CurrInstr->getOpcode() != TargetOpcode::SUBREG_TO_REG)
+    return false;
+
+  // Verify that the SUBREG_TO_REG loads an integer into the first lane.
+  auto Lane0LoadReg = CurrInstr->getOperand(2).getReg();
+  unsigned SingleLaneSizeInBits = 128 / NumLanes;
+  if (TRI->getRegSizeInBits(Lane0LoadReg, MRI) != SingleLaneSizeInBits)
+    return false;
+
+  // Verify that it also has a single non-debug use.
+  if (!MRI.hasOneNonDBGUse(Lane0LoadReg))
+    return false;
+
+  LoadInstrs.insert(MRI.getUniqueVRegDef(Lane0LoadReg));
+
+  // Check for intervening stores or calls between the first and last load.
+  // Sort the load instructions by program order.
+  SmallVector<const MachineInstr *, 16> SortedLoads(LoadInstrs.begin(),
+                                                    LoadInstrs.end());
+  llvm::sort(SortedLoads, [](const MachineInstr *A, const MachineInstr *B) {
+    if (A->getParent() != B->getParent()) {
+      // If in different blocks, this shouldn't happen for gather patterns.
+      return false;
+    }
+    // Compare positions within the same basic block.
+    for (const MachineInstr &MI : *A->getParent()) {
+      if (&MI == A)
+        return true;
+      if (&MI == B)
+        return false;
+    }
+    return false;
+  });
+
+  const MachineInstr *FirstLoad = SortedLoads.front();
+  const MachineInstr *LastLoad = SortedLoads.back();
+
+  if (hasInterveningStoreOrCall(FirstLoad, LastLoad))
+    return false;
+
+  switch (NumLanes) {
+  case 4:
+    Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i32);
+    break;
+  case 8:
+    Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i16);
+    break;
+  case 16:
+    Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i8);
+    break;
+  default:
+    llvm_unreachable("Got bad number of lanes for gather pattern.");
+  }
+
+  return true;
+}
+
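For intuition, a minimal sketch of source that typically produces the chain documented above for the 32-bit case (names are hypothetical and the exact lowering depends on instruction selection): a scalar load feeds lane 0 through SUBREG_TO_REG, and three dependent LD1i32 lane loads fill the remaining lanes, the last of which is the Root this function starts from.

#include <arm_neon.h>

// Hypothetical gather of four 32-bit values from unrelated addresses. Before
// the combine, the expected lowering is a serial dependency chain roughly of
// the form:
//   ldr s0, [x0]           // lane 0, zeroing the upper lanes
//   ld1 { v0.s }[1], [x1]  // each ld1 depends on the previous value of v0
//   ld1 { v0.s }[2], [x2]
//   ld1 { v0.s }[3], [x3]  // Root of the matched pattern
float32x4_t gather4_f32(const float *a, const float *b, const float *c,
                        const float *d) {
  float32x4_t v = vdupq_n_f32(0.0f);
  v = vsetq_lane_f32(*a, v, 0);
  v = vsetq_lane_f32(*b, v, 1);
  v = vsetq_lane_f32(*c, v, 2);
  v = vsetq_lane_f32(*d, v, 3);
  return v;
}
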
+/// Search for patterns of LD instructions we can optimize.
+static bool getLoadPatterns(MachineInstr &Root,
+                            SmallVectorImpl<unsigned> &Patterns) {
+
+  // The pattern searches for loads into single lanes.
+  switch (Root.getOpcode()) {
+  case AArch64::LD1i32:
+    return getGatherPattern(Root, Patterns, Root.getOpcode(), 4);
+  case AArch64::LD1i16:
+    return getGatherPattern(Root, Patterns, Root.getOpcode(), 8);
+  case AArch64::LD1i8:
+    return getGatherPattern(Root, Patterns, Root.getOpcode(), 16);
+  default:
+    return false;
+  }
+}
+
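The same shape repeats at the other element widths; only the lane-load opcode, the lane count, and (in the rewrite below) the scalar load opcode and subregister index change. A hedged sketch of the 16-bit flavour, with hypothetical names:

#include <arm_neon.h>

// Hypothetical 8 x 16-bit gather; with typical lowering this becomes a scalar
// 16-bit load for lane 0 (feeding SUBREG_TO_REG) plus seven chained LD1i16
// lane loads, i.e. the NumLanes == 8 case of getGatherPattern().
uint16x8_t gather8_u16(const uint16_t *a, const uint16_t *b,
                       const uint16_t *c, const uint16_t *d,
                       const uint16_t *e, const uint16_t *f,
                       const uint16_t *g, const uint16_t *h) {
  uint16x8_t v = vdupq_n_u16(0);
  v = vsetq_lane_u16(*a, v, 0);
  v = vsetq_lane_u16(*b, v, 1);
  v = vsetq_lane_u16(*c, v, 2);
  v = vsetq_lane_u16(*d, v, 3);
  v = vsetq_lane_u16(*e, v, 4);
  v = vsetq_lane_u16(*f, v, 5);
  v = vsetq_lane_u16(*g, v, 6);
  v = vsetq_lane_u16(*h, v, 7);
  return v;
}
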
+/// Generate an optimized instruction sequence for gather load patterns to
+/// improve Memory-Level Parallelism (MLP). This function transforms a chain of
+/// sequential NEON lane loads into parallel vector loads that can execute
+/// concurrently.
+static void
+generateGatherPattern(MachineInstr &Root,
+                      SmallVectorImpl<MachineInstr *> &InsInstrs,
+                      SmallVectorImpl<MachineInstr *> &DelInstrs,
+                      DenseMap<Register, unsigned> &InstrIdxForVirtReg,
+                      unsigned Pattern, unsigned NumLanes) {
+  MachineFunction &MF = *Root.getParent()->getParent();
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
+
+  // Gather the initial load instructions to build the pattern.
+  SmallVector<MachineInstr *, 16> LoadToLaneInstrs;
+  MachineInstr *CurrInstr = &Root;
+  for (unsigned i = 0; i < NumLanes - 1; ++i) {
+    LoadToLaneInstrs.push_back(CurrInstr);
+    CurrInstr = MRI.getUniqueVRegDef(CurrInstr->getOperand(1).getReg());
+  }
+
+  // Sort the load instructions according to the lane.
+  llvm::sort(LoadToLaneInstrs,
+             [](const MachineInstr *A, const MachineInstr *B) {
+               return A->getOperand(2).getImm() > B->getOperand(2).getImm();
+             });
+
+  MachineInstr *SubregToReg = CurrInstr;
+  LoadToLaneInstrs.push_back(
+      MRI.getUniqueVRegDef(SubregToReg->getOperand(2).getReg()));
+  auto LoadToLaneInstrsAscending = llvm::reverse(LoadToLaneInstrs);
+
+  const TargetRegisterClass *FPR128RegClass =
+      MRI.getRegClass(Root.getOperand(0).getReg());
+
+  // Helper lambda to create a LD1 instruction.
+  auto CreateLD1Instruction = [&](MachineInstr *OriginalInstr,
+                                  Register SrcRegister, unsigned Lane,
+                                  Register OffsetRegister,
+                                  bool OffsetRegisterKillState) {
+    auto NewRegister = MRI.createVirtualRegister(FPR128RegClass);
+    MachineInstrBuilder LoadIndexIntoRegister =
+        BuildMI(MF, MIMetadata(*OriginalInstr), TII->get(Root.getOpcode()),
+                NewRegister)
+            .addReg(SrcRegister)
+            .addImm(Lane)
+            .addReg(OffsetRegister, getKillRegState(OffsetRegisterKillState));
+    InstrIdxForVirtReg.insert(std::make_pair(NewRegister, InsInstrs.size()));
+    InsInstrs.push_back(LoadIndexIntoRegister);
+    return NewRegister;
+  };
+
+  // Helper to create a load instruction based on the NumLanes in the NEON
+  // register we are rewriting.
+  auto CreateLDRInstruction = [&](unsigned NumLanes, Register DestReg,
+                                  Register OffsetReg,
+                                  bool KillState) -> MachineInstrBuilder {
+    unsigned Opcode;
+    switch (NumLanes) {
+    case 4:
+      Opcode = AArch64::LDRSui;
+      break;
+    case 8:
+      Opcode = AArch64::LDRHui;
+      break;
+    case 16:
+      Opcode = AArch64::LDRBui;
+      break;
+    default:
+      llvm_unreachable(
+          "Got unsupported number of lanes in machine-combiner gather pattern");
+    }
+    // Immediate offset load.
+    return BuildMI(MF, MIMetadata(Root), TII->get(Opcode), DestReg)
+        .addReg(OffsetReg)
+        .addImm(0);
+  };
+
+  // Load the remaining lanes into register 0.
+  auto LanesToLoadToReg0 =
+      llvm::make_range(LoadToLaneInstrsAscending.begin() + 1,
+                       LoadToLaneInstrsAscending.begin() + NumLanes / 2);
+  Register PrevReg = SubregToReg->getOperand(0).getReg();
+  for (auto [Index, LoadInstr] : llvm::enumerate(LanesToLoadToReg0)) {
+    const MachineOperand &OffsetRegOperand = LoadInstr->getOperand(3);
+    PrevReg = CreateLD1Instruction(LoadInstr, PrevReg, Index + 1,
+                                   OffsetRegOperand.getReg(),
+                                   OffsetRegOperand.isKill());
+    DelInstrs.push_back(LoadInstr);
+  }
+  Register LastLoadReg0 = PrevReg;
+
+  // First load into register 1. Perform an integer load to zero out the upper
+  // lanes in a single instruction.
+  MachineInstr *Lane0Load = *LoadToLaneInstrsAscending.begin();
+  MachineInstr *OriginalSplitLoad =
+      *std::next(LoadToLaneInstrsAscending.begin(), NumLanes / 2);
+  Register DestRegForMiddleIndex = MRI.createVirtualRegister(
+      MRI.getRegClass(Lane0Load->getOperand(0).getReg()));
+
+  const MachineOperand &OriginalSplitToLoadOffsetOperand =
+      OriginalSplitLoad->getOperand(3);
+  MachineInstrBuilder MiddleIndexLoadInstr =
+      CreateLDRInstruction(NumLanes, DestRegForMiddleIndex,
+                           OriginalSplitToLoadOffsetOperand.getReg(),
+                           OriginalSplitToLoadOffsetOperand.isKill());
+
+  InstrIdxForVirtReg.insert(
+      std::make_pair(DestRegForMiddleIndex, InsInstrs.size()));
+  InsInstrs.push_back(MiddleIndexLoadInstr);
+  DelInstrs.push_back(OriginalSplitLoad);
+
+  // SUBREG_TO_REG instruction for register 1.
+  Register DestRegForSubregToReg = MRI.createVirtualRegister(FPR128RegClass);
+  unsigned SubregType;
+  switch (NumLanes) {
+  case 4:
+    SubregType = AArch64::ssub;
+    break;
+  case 8:
+    SubregType = AArch64::hsub;
+    break;
+  case 16:
+    SubregType = AArch64::bsub;
+    break;
+  default:
+    llvm_unreachable(
+        "Got invalid NumLanes for machine-combiner gather pattern");
+  }
+
+  auto SubRegToRegInstr =
+      BuildMI(MF, MIMetadata(Root), TII->get(SubregToReg->getOpcode()),
+              DestRegForSubregToReg)
+          .addImm(0)
+          .addReg(DestRegForMiddleIndex, getKillRegState(true))
+          .addImm(SubregType);
+  InstrIdxForVirtReg.insert(
+      std::make_pair(DestRegForSubregToReg, InsInstrs.size()));
+  InsInstrs.push_back(SubRegToRegInstr);
+
+  // Load the remaining lanes into register 1.
+  auto LanesToLoadToReg1 =
+      llvm::make_range(LoadToLaneInstrsAscending.begin() + NumLanes / 2 + 1,
+                       LoadToLaneInstrsAscending.end());
+  PrevReg = SubRegToRegInstr->getOperand(0).getReg();
+  for (auto [Index, LoadInstr] : llvm::enumerate(LanesToLoadToReg1)) {
+    const MachineOperand &OffsetRegOperand = LoadInstr->getOperand(3);
+    PrevReg = CreateLD1Instruction(LoadInstr, PrevReg, Index + 1,
+                                   OffsetRegOperand.getReg(),
+                                   OffsetRegOperand.isKill());
+
+    // Do not add the last reg to DelInstrs - it will be removed later.
+    if (Index == NumLanes / 2 - 2) {
+      break;
+    }
+    DelInstrs.push_back(LoadInstr);
+  }
+  Register LastLoadReg1 = PrevReg;
+
+  // Create the final zip instruction to combine the results.
+  MachineInstrBuilder ZipInstr =
+      BuildMI(MF, MIMetadata(Root), TII->get(AArch64::ZIP1v2i64),
+              Root.getOperand(0).getReg())
+          .addReg(LastLoadReg0)
+          .addReg(LastLoadReg1);
+  InsInstrs.push_back(ZipInstr);
+}
+
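The net effect of the rewrite, expressed back at the intrinsics level (a hedged sketch; names are hypothetical and assume the two half-gathers lower to independent chains): lanes 0..N/2-1 are assembled in one register, lanes N/2..N-1 in a second register that also starts from a scalar load zeroing its upper bits, and a ZIP1 on 64-bit elements concatenates the two lower halves. The chains share no registers, so their loads can issue in parallel, cutting the critical path from roughly N dependent loads to N/2 plus the zip.

#include <arm_neon.h>

// Equivalent data flow for the 4 x 32-bit case after the combine. zip1 on
// 2 x 64-bit elements takes the low 64 bits of each source, so viewed as four
// floats the result is { lo[0], lo[1], hi[0], hi[1] }.
float32x4_t gather4_split(const float *a, const float *b, const float *c,
                          const float *d) {
  // Chain 0: lanes 0-1 of the final result.
  float32x4_t lo = vdupq_n_f32(0.0f);
  lo = vsetq_lane_f32(*a, lo, 0);
  lo = vsetq_lane_f32(*b, lo, 1);
  // Chain 1: lanes 2-3 of the final result, built independently of chain 0.
  float32x4_t hi = vdupq_n_f32(0.0f);
  hi = vsetq_lane_f32(*c, hi, 0);
  hi = vsetq_lane_f32(*d, hi, 1);
  // ZIP1v2i64: interleave the low 64-bit halves of the two registers.
  return vreinterpretq_f32_u64(
      vzip1q_u64(vreinterpretq_u64_f32(lo), vreinterpretq_u64_f32(hi)));
}
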
 CombinerObjective
 AArch64InstrInfo::getCombinerObjective(unsigned Pattern) const {
   switch (Pattern) {
   case AArch64MachineCombinerPattern::SUBADD_OP1:
   case AArch64MachineCombinerPattern::SUBADD_OP2:
+  case AArch64MachineCombinerPattern::GATHER_LANE_i32:
+  case AArch64MachineCombinerPattern::GATHER_LANE_i16:
+  case AArch64MachineCombinerPattern::GATHER_LANE_i8:
     return CombinerObjective::MustReduceDepth;
   default:
     return TargetInstrInfo::getCombinerObjective(Pattern);
@@ -7446,6 +7783,10 @@ bool AArch64InstrInfo::getMachineCombinerPatterns(
   if (getMiscPatterns(Root, Patterns))
     return true;
 
+  // Load patterns
+  if (getLoadPatterns(Root, Patterns))
+    return true;
+
   return TargetInstrInfo::getMachineCombinerPatterns(Root, Patterns,
                                                      DoRegPressureReduce);
 }
@@ -8701,6 +9042,21 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     MUL = genFNegatedMAD(MF, MRI, TII, Root, InsInstrs);
     break;
   }
+  case AArch64MachineCombinerPattern::GATHER_LANE_i32: {
+    generateGatherPattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
+                          Pattern, 4);
+    break;
+  }
+  case AArch64MachineCombinerPattern::GATHER_LANE_i16: {
+    generateGatherPattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
+                          Pattern, 8);
+    break;
+  }
+  case AArch64MachineCombinerPattern::GATHER_LANE_i8: {
+    generateGatherPattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
+                          Pattern, 16);
+    break;
+  }
 
   } // end switch (Pattern)
   // Record MUL and ADD/SUB for deletion

llvm/lib/Target/AArch64/AArch64InstrInfo.h

Lines changed: 4 additions & 0 deletions
@@ -172,6 +172,10 @@ enum AArch64MachineCombinerPattern : unsigned {
   FMULv8i16_indexed_OP2,
 
   FNMADD,
+
+  GATHER_LANE_i32,
+  GATHER_LANE_i16,
+  GATHER_LANE_i8
 };
 class AArch64InstrInfo final : public AArch64GenInstrInfo {
   const AArch64RegisterInfo RI;
