
Commit 13fb03c

[AArch64][MachineCombiner] Combine sequences of gather patterns
rdar://151851094
1 parent d1f0702 commit 13fb03c

12 files changed: 1195 additions, 346 deletions

llvm/lib/Target/AArch64/AArch64InstrInfo.cpp

Lines changed: 350 additions & 0 deletions
@@ -20,7 +20,9 @@
 #include "Utils/AArch64BaseInfo.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/CodeGen/CFIInstBuilder.h"
 #include "llvm/CodeGen/LivePhysRegs.h"
 #include "llvm/CodeGen/MachineBasicBlock.h"
@@ -83,6 +85,11 @@ static cl::opt<unsigned>
     BDisplacementBits("aarch64-b-offset-bits", cl::Hidden, cl::init(26),
                       cl::desc("Restrict range of B instructions (DEBUG)"));
 
+static cl::opt<unsigned> GatherOptSearchLimit(
+    "aarch64-search-limit", cl::Hidden, cl::init(2048),
+    cl::desc("Restrict range of instructions to search for the "
+             "machine-combiner gather pattern optimization"));
+
 AArch64InstrInfo::AArch64InstrInfo(const AArch64Subtarget &STI)
     : AArch64GenInstrInfo(AArch64::ADJCALLSTACKDOWN, AArch64::ADJCALLSTACKUP,
                           AArch64::CATCHRET),
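Since GatherOptSearchLimit is a hidden cl::opt, the backward alias-scan window can be tuned from the command line when experimenting with this combine. A minimal sketch of such an invocation, relying only on the standard LLVM plumbing for hidden backend options (the input file name is made up for illustration):

    llc -mtriple=aarch64 -aarch64-search-limit=512 gather.ll -o gather.s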
@@ -7485,11 +7492,335 @@ static bool getMiscPatterns(MachineInstr &Root,
   return false;
 }
 
+/// Check if a given MachineInstr `MIa` may alias with any of the instructions
+/// in `MemInstrs`.
+static bool mayAlias(const MachineInstr &MIa,
+                     SmallVectorImpl<const MachineInstr *> &MemInstrs,
+                     AliasAnalysis *AA) {
+  for (const MachineInstr *MIb : MemInstrs) {
+    if (MIa.mayAlias(AA, *MIb, /*UseTBAA*/ false))
+      return true;
+  }
+
+  return false;
+}
+
+/// Check if the given instruction forms a gather load pattern that can be
+/// optimized for better Memory-Level Parallelism (MLP). This function
+/// identifies chains of NEON lane load instructions that load data from
+/// different memory addresses into individual lanes of a 128-bit vector
+/// register, then attempts to split the pattern into parallel loads to break
+/// the serial dependency between instructions.
+///
+/// Pattern Matched:
+///   Initial scalar load -> SUBREG_TO_REG (lane 0) -> LD1i* (lane 1) ->
+///   LD1i* (lane 2) -> ... -> LD1i* (lane N-1, Root)
+///
+/// Transformed Into:
+///   Two parallel vector loads using fewer lanes each, followed by ZIP1v2i64
+///   to combine the results, enabling better memory-level parallelism.
+///
+/// Supported Element Types:
+///   - 32-bit elements (LD1i32, 4 lanes total)
+///   - 16-bit elements (LD1i16, 8 lanes total)
+///   - 8-bit elements (LD1i8, 16 lanes total)
+static bool getGatherLanePattern(MachineInstr &Root,
+                                 SmallVectorImpl<unsigned> &Patterns,
+                                 unsigned LoadLaneOpCode, unsigned NumLanes) {
+  const MachineFunction *MF = Root.getMF();
+
+  // Early exit if optimizing for size.
+  if (MF->getFunction().hasMinSize())
+    return false;
+
+  const MachineRegisterInfo &MRI = MF->getRegInfo();
+  const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();
+
+  // The root of the pattern must load into the last lane of the vector.
+  if (Root.getOperand(2).getImm() != NumLanes - 1)
+    return false;
+
+  // Check that we have loads into all lanes except lane 0.
+  // For each load we also want to check that:
+  //   1. It has a single non-debug use (since we will be replacing the
+  //      virtual register), and
+  //   2. Its addressing mode uses only a single pointer operand.
+  auto *CurrInstr = MRI.getUniqueVRegDef(Root.getOperand(1).getReg());
+  auto Range = llvm::seq<unsigned>(1, NumLanes - 1);
+  SmallSet<unsigned, 16> RemainingLanes(Range.begin(), Range.end());
+  SmallVector<const MachineInstr *, 16> LoadInstrs = {};
+  while (!RemainingLanes.empty() && CurrInstr &&
+         CurrInstr->getOpcode() == LoadLaneOpCode &&
+         MRI.hasOneNonDBGUse(CurrInstr->getOperand(0).getReg()) &&
+         CurrInstr->getNumOperands() == 4) {
+    RemainingLanes.erase(CurrInstr->getOperand(2).getImm());
+    LoadInstrs.push_back(CurrInstr);
+    CurrInstr = MRI.getUniqueVRegDef(CurrInstr->getOperand(1).getReg());
+  }
+
+  // Check that we have found a match for lanes N-1 .. 1.
+  if (!RemainingLanes.empty())
+    return false;
+
+  // Match the SUBREG_TO_REG sequence.
+  if (CurrInstr->getOpcode() != TargetOpcode::SUBREG_TO_REG)
+    return false;
+
+  // Verify that the SUBREG_TO_REG loads an integer into the first lane.
+  auto Lane0LoadReg = CurrInstr->getOperand(2).getReg();
+  unsigned SingleLaneSizeInBits = 128 / NumLanes;
+  if (TRI->getRegSizeInBits(Lane0LoadReg, MRI) != SingleLaneSizeInBits)
+    return false;
+
+  // Verify that it also has a single non-debug use.
+  if (!MRI.hasOneNonDBGUse(Lane0LoadReg))
+    return false;
+
+  LoadInstrs.push_back(MRI.getUniqueVRegDef(Lane0LoadReg));
+
+  // If there is any chance of aliasing, do not apply the pattern.
+  // Walk backward through the MBB starting from Root.
+  // Exit early if we've encountered all load instructions or hit the search
+  // limit.
+  auto MBBItr = Root.getIterator();
+  unsigned RemainingSteps = GatherOptSearchLimit;
+  SmallSet<const MachineInstr *, 16> RemainingLoadInstrs;
+  RemainingLoadInstrs.insert(LoadInstrs.begin(), LoadInstrs.end());
+  const MachineBasicBlock *MBB = Root.getParent();
+
+  for (; MBBItr != MBB->begin() && RemainingSteps > 0 &&
+         !RemainingLoadInstrs.empty();
+       --MBBItr, --RemainingSteps) {
+    const MachineInstr &CurrInstr = *MBBItr;
+
+    // Remove this instruction from remaining loads if it's one we're tracking.
+    RemainingLoadInstrs.erase(&CurrInstr);
+
+    // Check for potential aliasing with any of the load instructions to
+    // optimize.
+    if ((CurrInstr.mayLoadOrStore() || CurrInstr.isCall()) &&
+        mayAlias(CurrInstr, LoadInstrs, nullptr))
+      return false;
+  }
+
+  // If we hit the search limit without finding all load instructions,
+  // don't match the pattern.
+  if (RemainingSteps == 0 && !RemainingLoadInstrs.empty())
+    return false;
+
+  switch (NumLanes) {
+  case 4:
+    Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i32);
+    break;
+  case 8:
+    Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i16);
+    break;
+  case 16:
+    Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i8);
+    break;
+  default:
+    llvm_unreachable("Got bad number of lanes for gather pattern.");
+  }
+
+  return true;
+}
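As a concrete picture of the chain this matcher walks, here is a hand-written sketch of the 4 x 32-bit case in pseudo-MIR (virtual register and pointer names are invented for illustration; the real input is whatever the instruction selector produced):

    %lane0:fpr32 = LDRSui %ptr0, 0                         ; scalar load of element 0
    %v0:fpr128   = SUBREG_TO_REG 0, %lane0, %subreg.ssub   ; element 0 placed in lane 0
    %v1:fpr128   = LD1i32 %v0, 1, %ptr1                    ; insert element into lane 1
    %v2:fpr128   = LD1i32 %v1, 2, %ptr2                    ; insert element into lane 2
    %v3:fpr128   = LD1i32 %v2, 3, %ptr3                    ; insert element into lane 3 (Root)

Each LD1i32 depends on the previous one through its vector operand, so the loads issue serially; getGatherLanePattern walks this chain from the Root (lane 3) back to the SUBREG_TO_REG and only reports a pattern when every intermediate value has a single non-debug use and no intervening instruction in the scanned window may alias the loads.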
+
+/// Search for patterns of LD instructions we can optimize.
+static bool getLoadPatterns(MachineInstr &Root,
+                            SmallVectorImpl<unsigned> &Patterns) {
+
+  // The pattern searches for loads into single lanes.
+  switch (Root.getOpcode()) {
+  case AArch64::LD1i32:
+    return getGatherLanePattern(Root, Patterns, Root.getOpcode(), 4);
+  case AArch64::LD1i16:
+    return getGatherLanePattern(Root, Patterns, Root.getOpcode(), 8);
+  case AArch64::LD1i8:
+    return getGatherLanePattern(Root, Patterns, Root.getOpcode(), 16);
+  default:
+    return false;
+  }
+}
+
+/// Generate optimized instruction sequence for gather load patterns to improve
+/// Memory-Level Parallelism (MLP). This function transforms a chain of
+/// sequential NEON lane loads into parallel vector loads that can execute
+/// concurrently.
+static void
+generateGatherLanePattern(MachineInstr &Root,
+                          SmallVectorImpl<MachineInstr *> &InsInstrs,
+                          SmallVectorImpl<MachineInstr *> &DelInstrs,
+                          DenseMap<Register, unsigned> &InstrIdxForVirtReg,
+                          unsigned Pattern, unsigned NumLanes) {
+  MachineFunction &MF = *Root.getParent()->getParent();
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
+
+  // Gather the initial load instructions to build the pattern.
+  SmallVector<MachineInstr *, 16> LoadToLaneInstrs;
+  MachineInstr *CurrInstr = &Root;
+  for (unsigned i = 0; i < NumLanes - 1; ++i) {
+    LoadToLaneInstrs.push_back(CurrInstr);
+    CurrInstr = MRI.getUniqueVRegDef(CurrInstr->getOperand(1).getReg());
+  }
+
+  // Sort the load instructions according to the lane.
+  llvm::sort(LoadToLaneInstrs,
+             [](const MachineInstr *A, const MachineInstr *B) {
+               return A->getOperand(2).getImm() > B->getOperand(2).getImm();
+             });
+
+  MachineInstr *SubregToReg = CurrInstr;
+  LoadToLaneInstrs.push_back(
+      MRI.getUniqueVRegDef(SubregToReg->getOperand(2).getReg()));
+  auto LoadToLaneInstrsAscending = llvm::reverse(LoadToLaneInstrs);
+
+  const TargetRegisterClass *FPR128RegClass =
+      MRI.getRegClass(Root.getOperand(0).getReg());
+
+  // Helper lambda to create an LD1 instruction.
+  auto CreateLD1Instruction = [&](MachineInstr *OriginalInstr,
+                                  Register SrcRegister, unsigned Lane,
+                                  Register OffsetRegister,
+                                  bool OffsetRegisterKillState) {
+    auto NewRegister = MRI.createVirtualRegister(FPR128RegClass);
+    MachineInstrBuilder LoadIndexIntoRegister =
+        BuildMI(MF, MIMetadata(*OriginalInstr), TII->get(Root.getOpcode()),
+                NewRegister)
+            .addReg(SrcRegister)
+            .addImm(Lane)
+            .addReg(OffsetRegister, getKillRegState(OffsetRegisterKillState));
+    InstrIdxForVirtReg.insert(std::make_pair(NewRegister, InsInstrs.size()));
+    InsInstrs.push_back(LoadIndexIntoRegister);
+    return NewRegister;
+  };
+
+  // Helper to create a load instruction based on the NumLanes in the NEON
+  // register we are rewriting.
+  auto CreateLDRInstruction = [&](unsigned NumLanes, Register DestReg,
+                                  Register OffsetReg,
+                                  bool KillState) -> MachineInstrBuilder {
+    unsigned Opcode;
+    switch (NumLanes) {
+    case 4:
+      Opcode = AArch64::LDRSui;
+      break;
+    case 8:
+      Opcode = AArch64::LDRHui;
+      break;
+    case 16:
+      Opcode = AArch64::LDRBui;
+      break;
+    default:
+      llvm_unreachable(
+          "Got unsupported number of lanes in machine-combiner gather pattern");
+    }
+    // Immediate offset load
+    return BuildMI(MF, MIMetadata(Root), TII->get(Opcode), DestReg)
+        .addReg(OffsetReg)
+        .addImm(0);
+  };
+
+  // Load the remaining lanes into register 0.
+  auto LanesToLoadToReg0 =
+      llvm::make_range(LoadToLaneInstrsAscending.begin() + 1,
+                       LoadToLaneInstrsAscending.begin() + NumLanes / 2);
+  Register PrevReg = SubregToReg->getOperand(0).getReg();
+  for (auto [Index, LoadInstr] : llvm::enumerate(LanesToLoadToReg0)) {
+    const MachineOperand &OffsetRegOperand = LoadInstr->getOperand(3);
+    PrevReg = CreateLD1Instruction(LoadInstr, PrevReg, Index + 1,
+                                   OffsetRegOperand.getReg(),
+                                   OffsetRegOperand.isKill());
+    DelInstrs.push_back(LoadInstr);
+  }
+  Register LastLoadReg0 = PrevReg;
+
+  // First load into register 1. Perform an integer load to zero out the upper
+  // lanes in a single instruction.
+  MachineInstr *Lane0Load = *LoadToLaneInstrsAscending.begin();
+  MachineInstr *OriginalSplitLoad =
+      *std::next(LoadToLaneInstrsAscending.begin(), NumLanes / 2);
+  Register DestRegForMiddleIndex = MRI.createVirtualRegister(
+      MRI.getRegClass(Lane0Load->getOperand(0).getReg()));
+
+  const MachineOperand &OriginalSplitToLoadOffsetOperand =
+      OriginalSplitLoad->getOperand(3);
+  MachineInstrBuilder MiddleIndexLoadInstr =
+      CreateLDRInstruction(NumLanes, DestRegForMiddleIndex,
+                           OriginalSplitToLoadOffsetOperand.getReg(),
+                           OriginalSplitToLoadOffsetOperand.isKill());
+
+  InstrIdxForVirtReg.insert(
+      std::make_pair(DestRegForMiddleIndex, InsInstrs.size()));
+  InsInstrs.push_back(MiddleIndexLoadInstr);
+  DelInstrs.push_back(OriginalSplitLoad);
+
+  // SUBREG_TO_REG instruction for register 1.
+  Register DestRegForSubregToReg = MRI.createVirtualRegister(FPR128RegClass);
+  unsigned SubregType;
+  switch (NumLanes) {
+  case 4:
+    SubregType = AArch64::ssub;
+    break;
+  case 8:
+    SubregType = AArch64::hsub;
+    break;
+  case 16:
+    SubregType = AArch64::bsub;
+    break;
+  default:
+    llvm_unreachable(
+        "Got invalid NumLanes for machine-combiner gather pattern");
+  }
+
+  auto SubRegToRegInstr =
+      BuildMI(MF, MIMetadata(Root), TII->get(SubregToReg->getOpcode()),
+              DestRegForSubregToReg)
+          .addImm(0)
+          .addReg(DestRegForMiddleIndex, getKillRegState(true))
+          .addImm(SubregType);
+  InstrIdxForVirtReg.insert(
+      std::make_pair(DestRegForSubregToReg, InsInstrs.size()));
+  InsInstrs.push_back(SubRegToRegInstr);
+
+  // Load the remaining lanes into register 1.
+  auto LanesToLoadToReg1 =
+      llvm::make_range(LoadToLaneInstrsAscending.begin() + NumLanes / 2 + 1,
+                       LoadToLaneInstrsAscending.end());
+  PrevReg = SubRegToRegInstr->getOperand(0).getReg();
+  for (auto [Index, LoadInstr] : llvm::enumerate(LanesToLoadToReg1)) {
+    const MachineOperand &OffsetRegOperand = LoadInstr->getOperand(3);
+    PrevReg = CreateLD1Instruction(LoadInstr, PrevReg, Index + 1,
+                                   OffsetRegOperand.getReg(),
+                                   OffsetRegOperand.isKill());
+
+    // Do not add the last instruction (the Root) to DelInstrs - it will be
+    // removed later.
+    if (Index == NumLanes / 2 - 2) {
+      break;
+    }
+    DelInstrs.push_back(LoadInstr);
+  }
+  Register LastLoadReg1 = PrevReg;
+
+  // Create the final zip instruction to combine the results.
+  MachineInstrBuilder ZipInstr =
+      BuildMI(MF, MIMetadata(Root), TII->get(AArch64::ZIP1v2i64),
+              Root.getOperand(0).getReg())
+          .addReg(LastLoadReg0)
+          .addReg(LastLoadReg1);
+  InsInstrs.push_back(ZipInstr);
+}
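For the same 4 x 32-bit example sketched above, the sequence this function emits looks roughly like the following pseudo-MIR (again with invented register and pointer names): the first half of the lanes stays on the original accumulator, the second half is restarted as an independent chain, and ZIP1v2i64 stitches the two 64-bit halves back together.

    ; accumulator 0: lanes 0 and 1, reusing the original scalar load + SUBREG_TO_REG
    %a1:fpr128 = LD1i32 %v0, 1, %ptr1
    ; accumulator 1: element 2 restarted as a scalar load, element 3 inserted on top
    %s:fpr32   = LDRSui %ptr2, 0
    %b0:fpr128 = SUBREG_TO_REG 0, %s, %subreg.ssub
    %b1:fpr128 = LD1i32 %b0, 1, %ptr3
    ; interleave the low 64 bits of both accumulators into the final result
    %v3:fpr128 = ZIP1v2i64 %a1, %b1

The two LD1i32 chains no longer depend on each other, so the four memory accesses can issue in parallel; the extra ZIP1v2i64 is the cost of breaking the dependency, which is why these patterns are classified as MustReduceDepth below and skipped when optimizing for minimum size.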
+
 CombinerObjective
 AArch64InstrInfo::getCombinerObjective(unsigned Pattern) const {
   switch (Pattern) {
   case AArch64MachineCombinerPattern::SUBADD_OP1:
   case AArch64MachineCombinerPattern::SUBADD_OP2:
+  case AArch64MachineCombinerPattern::GATHER_LANE_i32:
+  case AArch64MachineCombinerPattern::GATHER_LANE_i16:
+  case AArch64MachineCombinerPattern::GATHER_LANE_i8:
     return CombinerObjective::MustReduceDepth;
   default:
     return TargetInstrInfo::getCombinerObjective(Pattern);
@@ -7519,6 +7850,10 @@ bool AArch64InstrInfo::getMachineCombinerPatterns(
   if (getMiscPatterns(Root, Patterns))
     return true;
 
+  // Load patterns
+  if (getLoadPatterns(Root, Patterns))
+    return true;
+
   return TargetInstrInfo::getMachineCombinerPatterns(Root, Patterns,
                                                      DoRegPressureReduce);
 }
@@ -8774,6 +9109,21 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     MUL = genFNegatedMAD(MF, MRI, TII, Root, InsInstrs);
     break;
   }
+  case AArch64MachineCombinerPattern::GATHER_LANE_i32: {
+    generateGatherLanePattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
+                              Pattern, 4);
+    break;
+  }
+  case AArch64MachineCombinerPattern::GATHER_LANE_i16: {
+    generateGatherLanePattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
+                              Pattern, 8);
+    break;
+  }
+  case AArch64MachineCombinerPattern::GATHER_LANE_i8: {
+    generateGatherLanePattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
+                              Pattern, 16);
+    break;
+  }
 
   } // end switch (Pattern)
   // Record MUL and ADD/SUB for deletion
