Commit c6fe567

[AArch64][MachineCombiner] Combine sequences of gather patterns (#152979)
Reland of #142941, squashed with fixes for #150004 and #149585. This combine matches gather-like sequences where values are loaded per lane into NEON registers, and replaces them with loads into two separate registers, which are then combined with a zip instruction. This decreases the critical path length and improves Memory Level Parallelism. rdar://151851094
1 parent: 673750f

12 files changed: +1179 -346 lines
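To make the transformation concrete, here is a rough sketch of the i32 (4-lane) case at the machine-IR level. This is illustrative pseudo-MIR, not taken from the commit or its tests: the virtual register names and pointer operands (%p0..%p3) are invented, and the syntax is simplified.

Before (the serial chain the combiner matches; each LD1i32 depends on the previous one):

    %s0 = LDRSui %p0, 0                 ; scalar load of lane 0
    %v0 = SUBREG_TO_REG 0, %s0, ssub    ; place it in lane 0 of a 128-bit register
    %v1 = LD1i32 %v0, 1, %p1            ; insert lane 1
    %v2 = LD1i32 %v1, 2, %p2            ; insert lane 2
    %v3 = LD1i32 %v2, 3, %p3            ; insert lane 3 (Root)

After (two independent two-lane chains that can be in flight at the same time, joined by a zip):

    %a0  = LDRSui %p0, 0
    %lo0 = SUBREG_TO_REG 0, %a0, ssub
    %lo1 = LD1i32 %lo0, 1, %p1          ; lanes 0-1
    %b0  = LDRSui %p2, 0
    %hi0 = SUBREG_TO_REG 0, %b0, ssub
    %hi1 = LD1i32 %hi0, 1, %p3          ; lanes 2-3
    %v3  = ZIP1v2i64 %lo1, %hi1         ; interleave the two 64-bit halves

The dependency chain shrinks from four dependent loads to two, which is where the critical-path and MLP improvement comes from.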

llvm/lib/Target/AArch64/AArch64InstrInfo.cpp

Lines changed: 334 additions & 0 deletions
@@ -20,7 +20,9 @@
 #include "Utils/AArch64BaseInfo.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/CodeGen/CFIInstBuilder.h"
 #include "llvm/CodeGen/LivePhysRegs.h"
 #include "llvm/CodeGen/MachineBasicBlock.h"
@@ -83,6 +85,11 @@ static cl::opt<unsigned>
     BDisplacementBits("aarch64-b-offset-bits", cl::Hidden, cl::init(26),
                       cl::desc("Restrict range of B instructions (DEBUG)"));
 
+static cl::opt<unsigned> GatherOptSearchLimit(
+    "aarch64-search-limit", cl::Hidden, cl::init(2048),
+    cl::desc("Restrict range of instructions to search for the "
+             "machine-combiner gather pattern optimization"));
+
 AArch64InstrInfo::AArch64InstrInfo(const AArch64Subtarget &STI)
     : AArch64GenInstrInfo(AArch64::ADJCALLSTACKDOWN, AArch64::ADJCALLSTACKUP,
                           AArch64::CATCHRET),
@@ -7412,11 +7419,319 @@ static bool getMiscPatterns(MachineInstr &Root,
   return false;
 }
 
+/// Check if the given instruction forms a gather load pattern that can be
+/// optimized for better Memory-Level Parallelism (MLP). This function
+/// identifies chains of NEON lane load instructions that load data from
+/// different memory addresses into individual lanes of a 128-bit vector
+/// register, then attempts to split the pattern into parallel loads to break
+/// the serial dependency between instructions.
+///
+/// Pattern Matched:
+///   Initial scalar load -> SUBREG_TO_REG (lane 0) -> LD1i* (lane 1) ->
+///   LD1i* (lane 2) -> ... -> LD1i* (lane N-1, Root)
+///
+/// Transformed Into:
+///   Two parallel vector loads using fewer lanes each, followed by ZIP1v2i64
+///   to combine the results, enabling better memory-level parallelism.
+///
+/// Supported Element Types:
+///   - 32-bit elements (LD1i32, 4 lanes total)
+///   - 16-bit elements (LD1i16, 8 lanes total)
+///   - 8-bit elements (LD1i8, 16 lanes total)
+static bool getGatherLanePattern(MachineInstr &Root,
+                                 SmallVectorImpl<unsigned> &Patterns,
+                                 unsigned LoadLaneOpCode, unsigned NumLanes) {
+  const MachineFunction *MF = Root.getMF();
+
+  // Early exit if optimizing for size.
+  if (MF->getFunction().hasMinSize())
+    return false;
+
+  const MachineRegisterInfo &MRI = MF->getRegInfo();
+  const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();
+
+  // The root of the pattern must load into the last lane of the vector.
+  if (Root.getOperand(2).getImm() != NumLanes - 1)
+    return false;
+
+  // Check that we have load into all lanes except lane 0.
+  // For each load we also want to check that:
+  // 1. It has a single non-debug use (since we will be replacing the virtual
+  //    register)
+  // 2. That the addressing mode only uses a single pointer operand
+  auto *CurrInstr = MRI.getUniqueVRegDef(Root.getOperand(1).getReg());
+  auto Range = llvm::seq<unsigned>(1, NumLanes - 1);
+  SmallSet<unsigned, 16> RemainingLanes(Range.begin(), Range.end());
+  SmallVector<const MachineInstr *, 16> LoadInstrs;
+  while (!RemainingLanes.empty() && CurrInstr &&
+         CurrInstr->getOpcode() == LoadLaneOpCode &&
+         MRI.hasOneNonDBGUse(CurrInstr->getOperand(0).getReg()) &&
+         CurrInstr->getNumOperands() == 4) {
+    RemainingLanes.erase(CurrInstr->getOperand(2).getImm());
+    LoadInstrs.push_back(CurrInstr);
+    CurrInstr = MRI.getUniqueVRegDef(CurrInstr->getOperand(1).getReg());
+  }
+
+  // Check that we have found a match for lanes N-1..1.
+  if (!RemainingLanes.empty())
+    return false;
+
+  // Match the SUBREG_TO_REG sequence.
+  if (CurrInstr->getOpcode() != TargetOpcode::SUBREG_TO_REG)
+    return false;
+
+  // Verify that the subreg to reg loads an integer into the first lane.
+  auto Lane0LoadReg = CurrInstr->getOperand(2).getReg();
+  unsigned SingleLaneSizeInBits = 128 / NumLanes;
+  if (TRI->getRegSizeInBits(Lane0LoadReg, MRI) != SingleLaneSizeInBits)
+    return false;
+
+  // Verify that it also has a single non debug use.
+  if (!MRI.hasOneNonDBGUse(Lane0LoadReg))
+    return false;
+
+  LoadInstrs.push_back(MRI.getUniqueVRegDef(Lane0LoadReg));
+
+  // If there is any chance of aliasing, do not apply the pattern.
+  // Walk backward through the MBB starting from Root.
+  // Exit early if we've encountered all load instructions or hit the search
+  // limit.
+  auto MBBItr = Root.getIterator();
+  unsigned RemainingSteps = GatherOptSearchLimit;
+  SmallSet<const MachineInstr *, 16> RemainingLoadInstrs;
+  RemainingLoadInstrs.insert(LoadInstrs.begin(), LoadInstrs.end());
+  const MachineBasicBlock *MBB = Root.getParent();
+
+  for (; MBBItr != MBB->begin() && RemainingSteps > 0 &&
+         !RemainingLoadInstrs.empty();
+       --MBBItr, --RemainingSteps) {
+    const MachineInstr &CurrInstr = *MBBItr;
+
+    // Remove this instruction from remaining loads if it's one we're tracking.
+    RemainingLoadInstrs.erase(&CurrInstr);
+
+    // Check for potential aliasing with any of the load instructions to
+    // optimize.
+    if (CurrInstr.isLoadFoldBarrier())
+      return false;
+  }
+
+  // If we hit the search limit without finding all load instructions,
+  // don't match the pattern.
+  if (RemainingSteps == 0 && !RemainingLoadInstrs.empty())
+    return false;
+
+  switch (NumLanes) {
+  case 4:
+    Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i32);
+    break;
+  case 8:
+    Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i16);
+    break;
+  case 16:
+    Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i8);
+    break;
+  default:
+    llvm_unreachable("Got bad number of lanes for gather pattern.");
+  }
+
+  return true;
+}
+
+/// Search for patterns of LD instructions we can optimize.
+static bool getLoadPatterns(MachineInstr &Root,
+                            SmallVectorImpl<unsigned> &Patterns) {
+
+  // The pattern searches for loads into single lanes.
+  switch (Root.getOpcode()) {
+  case AArch64::LD1i32:
+    return getGatherLanePattern(Root, Patterns, Root.getOpcode(), 4);
+  case AArch64::LD1i16:
+    return getGatherLanePattern(Root, Patterns, Root.getOpcode(), 8);
+  case AArch64::LD1i8:
+    return getGatherLanePattern(Root, Patterns, Root.getOpcode(), 16);
+  default:
+    return false;
+  }
+}
+
+/// Generate optimized instruction sequence for gather load patterns to improve
+/// Memory-Level Parallelism (MLP). This function transforms a chain of
+/// sequential NEON lane loads into parallel vector loads that can execute
+/// concurrently.
+static void
+generateGatherLanePattern(MachineInstr &Root,
+                          SmallVectorImpl<MachineInstr *> &InsInstrs,
+                          SmallVectorImpl<MachineInstr *> &DelInstrs,
+                          DenseMap<Register, unsigned> &InstrIdxForVirtReg,
+                          unsigned Pattern, unsigned NumLanes) {
+  MachineFunction &MF = *Root.getParent()->getParent();
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
+
+  // Gather the initial load instructions to build the pattern.
+  SmallVector<MachineInstr *, 16> LoadToLaneInstrs;
+  MachineInstr *CurrInstr = &Root;
+  for (unsigned i = 0; i < NumLanes - 1; ++i) {
+    LoadToLaneInstrs.push_back(CurrInstr);
+    CurrInstr = MRI.getUniqueVRegDef(CurrInstr->getOperand(1).getReg());
+  }
+
+  // Sort the load instructions according to the lane.
+  llvm::sort(LoadToLaneInstrs,
+             [](const MachineInstr *A, const MachineInstr *B) {
+               return A->getOperand(2).getImm() > B->getOperand(2).getImm();
+             });
+
+  MachineInstr *SubregToReg = CurrInstr;
+  LoadToLaneInstrs.push_back(
+      MRI.getUniqueVRegDef(SubregToReg->getOperand(2).getReg()));
+  auto LoadToLaneInstrsAscending = llvm::reverse(LoadToLaneInstrs);
+
+  const TargetRegisterClass *FPR128RegClass =
+      MRI.getRegClass(Root.getOperand(0).getReg());
+
+  // Helper lambda to create a LD1 instruction.
+  auto CreateLD1Instruction = [&](MachineInstr *OriginalInstr,
+                                  Register SrcRegister, unsigned Lane,
+                                  Register OffsetRegister,
+                                  bool OffsetRegisterKillState) {
+    auto NewRegister = MRI.createVirtualRegister(FPR128RegClass);
+    MachineInstrBuilder LoadIndexIntoRegister =
+        BuildMI(MF, MIMetadata(*OriginalInstr), TII->get(Root.getOpcode()),
+                NewRegister)
+            .addReg(SrcRegister)
+            .addImm(Lane)
+            .addReg(OffsetRegister, getKillRegState(OffsetRegisterKillState));
+    InstrIdxForVirtReg.insert(std::make_pair(NewRegister, InsInstrs.size()));
+    InsInstrs.push_back(LoadIndexIntoRegister);
+    return NewRegister;
+  };
+
+  // Helper to create load instruction based on the NumLanes in the NEON
+  // register we are rewriting.
+  auto CreateLDRInstruction = [&](unsigned NumLanes, Register DestReg,
+                                  Register OffsetReg,
+                                  bool KillState) -> MachineInstrBuilder {
+    unsigned Opcode;
+    switch (NumLanes) {
+    case 4:
+      Opcode = AArch64::LDRSui;
+      break;
+    case 8:
+      Opcode = AArch64::LDRHui;
+      break;
+    case 16:
+      Opcode = AArch64::LDRBui;
+      break;
+    default:
+      llvm_unreachable(
+          "Got unsupported number of lanes in machine-combiner gather pattern");
+    }
+    // Immediate offset load
+    return BuildMI(MF, MIMetadata(Root), TII->get(Opcode), DestReg)
+        .addReg(OffsetReg)
+        .addImm(0);
+  };
+
+  // Load the remaining lanes into register 0.
+  auto LanesToLoadToReg0 =
+      llvm::make_range(LoadToLaneInstrsAscending.begin() + 1,
+                       LoadToLaneInstrsAscending.begin() + NumLanes / 2);
+  Register PrevReg = SubregToReg->getOperand(0).getReg();
+  for (auto [Index, LoadInstr] : llvm::enumerate(LanesToLoadToReg0)) {
+    const MachineOperand &OffsetRegOperand = LoadInstr->getOperand(3);
+    PrevReg = CreateLD1Instruction(LoadInstr, PrevReg, Index + 1,
+                                   OffsetRegOperand.getReg(),
+                                   OffsetRegOperand.isKill());
+    DelInstrs.push_back(LoadInstr);
+  }
+  Register LastLoadReg0 = PrevReg;
+
+  // First load into register 1. Perform an integer load to zero out the upper
+  // lanes in a single instruction.
+  MachineInstr *Lane0Load = *LoadToLaneInstrsAscending.begin();
+  MachineInstr *OriginalSplitLoad =
+      *std::next(LoadToLaneInstrsAscending.begin(), NumLanes / 2);
+  Register DestRegForMiddleIndex = MRI.createVirtualRegister(
+      MRI.getRegClass(Lane0Load->getOperand(0).getReg()));
+
+  const MachineOperand &OriginalSplitToLoadOffsetOperand =
+      OriginalSplitLoad->getOperand(3);
+  MachineInstrBuilder MiddleIndexLoadInstr =
+      CreateLDRInstruction(NumLanes, DestRegForMiddleIndex,
+                           OriginalSplitToLoadOffsetOperand.getReg(),
+                           OriginalSplitToLoadOffsetOperand.isKill());
+
+  InstrIdxForVirtReg.insert(
+      std::make_pair(DestRegForMiddleIndex, InsInstrs.size()));
+  InsInstrs.push_back(MiddleIndexLoadInstr);
+  DelInstrs.push_back(OriginalSplitLoad);
+
+  // Subreg To Reg instruction for register 1.
+  Register DestRegForSubregToReg = MRI.createVirtualRegister(FPR128RegClass);
+  unsigned SubregType;
+  switch (NumLanes) {
+  case 4:
+    SubregType = AArch64::ssub;
+    break;
+  case 8:
+    SubregType = AArch64::hsub;
+    break;
+  case 16:
+    SubregType = AArch64::bsub;
+    break;
+  default:
+    llvm_unreachable(
+        "Got invalid NumLanes for machine-combiner gather pattern");
+  }
+
+  auto SubRegToRegInstr =
+      BuildMI(MF, MIMetadata(Root), TII->get(SubregToReg->getOpcode()),
+              DestRegForSubregToReg)
+          .addImm(0)
+          .addReg(DestRegForMiddleIndex, getKillRegState(true))
+          .addImm(SubregType);
+  InstrIdxForVirtReg.insert(
+      std::make_pair(DestRegForSubregToReg, InsInstrs.size()));
+  InsInstrs.push_back(SubRegToRegInstr);
+
+  // Load remaining lanes into register 1.
+  auto LanesToLoadToReg1 =
+      llvm::make_range(LoadToLaneInstrsAscending.begin() + NumLanes / 2 + 1,
+                       LoadToLaneInstrsAscending.end());
+  PrevReg = SubRegToRegInstr->getOperand(0).getReg();
+  for (auto [Index, LoadInstr] : llvm::enumerate(LanesToLoadToReg1)) {
+    const MachineOperand &OffsetRegOperand = LoadInstr->getOperand(3);
+    PrevReg = CreateLD1Instruction(LoadInstr, PrevReg, Index + 1,
+                                   OffsetRegOperand.getReg(),
+                                   OffsetRegOperand.isKill());
+
+    // Do not add the last reg to DelInstrs - it will be removed later.
+    if (Index == NumLanes / 2 - 2) {
+      break;
+    }
+    DelInstrs.push_back(LoadInstr);
+  }
+  Register LastLoadReg1 = PrevReg;
+
+  // Create the final zip instruction to combine the results.
+  MachineInstrBuilder ZipInstr =
+      BuildMI(MF, MIMetadata(Root), TII->get(AArch64::ZIP1v2i64),
+              Root.getOperand(0).getReg())
+          .addReg(LastLoadReg0)
+          .addReg(LastLoadReg1);
+  InsInstrs.push_back(ZipInstr);
+}
+
 CombinerObjective
 AArch64InstrInfo::getCombinerObjective(unsigned Pattern) const {
   switch (Pattern) {
   case AArch64MachineCombinerPattern::SUBADD_OP1:
   case AArch64MachineCombinerPattern::SUBADD_OP2:
+  case AArch64MachineCombinerPattern::GATHER_LANE_i32:
+  case AArch64MachineCombinerPattern::GATHER_LANE_i16:
+  case AArch64MachineCombinerPattern::GATHER_LANE_i8:
     return CombinerObjective::MustReduceDepth;
   default:
     return TargetInstrInfo::getCombinerObjective(Pattern);
@@ -7446,6 +7761,10 @@ bool AArch64InstrInfo::getMachineCombinerPatterns(
   if (getMiscPatterns(Root, Patterns))
     return true;
 
+  // Load patterns
+  if (getLoadPatterns(Root, Patterns))
+    return true;
+
   return TargetInstrInfo::getMachineCombinerPatterns(Root, Patterns,
                                                      DoRegPressureReduce);
 }
@@ -8701,6 +9020,21 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     MUL = genFNegatedMAD(MF, MRI, TII, Root, InsInstrs);
     break;
   }
+  case AArch64MachineCombinerPattern::GATHER_LANE_i32: {
+    generateGatherLanePattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
+                              Pattern, 4);
+    break;
+  }
+  case AArch64MachineCombinerPattern::GATHER_LANE_i16: {
+    generateGatherLanePattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
+                              Pattern, 8);
+    break;
+  }
+  case AArch64MachineCombinerPattern::GATHER_LANE_i8: {
+    generateGatherLanePattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
+                              Pattern, 16);
+    break;
+  }
 
   } // end switch (Pattern)
   // Record MUL and ADD/SUB for deletion
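For the wider element types the same split applies, just with more LD1 lane inserts per half. As a rough sketch of what generateGatherLanePattern emits for the i16 (8-lane) case, again in illustrative pseudo-MIR with invented register and pointer names (%p1..%p7, %h0, %h4), not taken from the commit or its tests:

    ; register 0: lanes 0-3 (reuses the original lane-0 scalar load and SUBREG_TO_REG)
    %lo0 = SUBREG_TO_REG 0, %h0, hsub
    %lo1 = LD1i16 %lo0, 1, %p1
    %lo2 = LD1i16 %lo1, 2, %p2
    %lo3 = LD1i16 %lo2, 3, %p3
    ; register 1: lanes 4-7, started with an integer load (CreateLDRInstruction)
    ; that also zeroes the upper lanes
    %h4  = LDRHui %p4, 0
    %hi0 = SUBREG_TO_REG 0, %h4, hsub
    %hi1 = LD1i16 %hi0, 1, %p5
    %hi2 = LD1i16 %hi1, 2, %p6
    %hi3 = LD1i16 %hi2, 3, %p7
    ; final combine, written to the Root's destination register
    %dst = ZIP1v2i64 %lo3, %hi3

The new loads and the zip are appended to InsInstrs in dependency order, the superseded LD1i16 lane inserts go into DelInstrs, and the MachineCombiner splices the new sequence in place of the old chain.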

llvm/lib/Target/AArch64/AArch64InstrInfo.h

Lines changed: 4 additions & 0 deletions
@@ -172,6 +172,10 @@ enum AArch64MachineCombinerPattern : unsigned {
   FMULv8i16_indexed_OP2,
 
   FNMADD,
+
+  GATHER_LANE_i32,
+  GATHER_LANE_i16,
+  GATHER_LANE_i8
 };
 class AArch64InstrInfo final : public AArch64GenInstrInfo {
   const AArch64RegisterInfo RI;
