
Commit e8a891b

[AArch64][Machine-Combiner] Split gather patterns into neon regs to multiple vectors (llvm#142941)
This change optimizes gather-like sequences, where values are loaded one by one into the lanes of a single Neon vector. Because each lane load depends on the previous value of the vector register, the loads form a serial dependency chain; when performing, for example, four i32 loads into a 128-bit vector, it is more profitable to load into two separate vector registers in parallel and zip them together at the end.

rdar://151851094
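As a rough illustration of the rewrite (not part of the commit, with invented pointer names; the pass itself operates on MachineIR after instruction selection), here is the idea at the C intrinsics level. The first function models the serial LD1-lane chain the pattern matches; the second models the two independent half-chains plus a final combine, which plays the role of the ZIP1 the combiner emits:

#include <arm_neon.h>

// Before: one serial chain; each lane load depends on the previous value.
int32x4_t gather_serial(const int32_t *p0, const int32_t *p1,
                        const int32_t *p2, const int32_t *p3) {
  int32x4_t v = vmovq_n_s32(*p0);   // scalar load into lane 0
  v = vld1q_lane_s32(p1, v, 1);
  v = vld1q_lane_s32(p2, v, 2);
  v = vld1q_lane_s32(p3, v, 3);
  return v;
}

// After: two independent two-lane chains that can issue in parallel,
// merged at the end (vcombine_s32 stands in for the ZIP1 of the two
// registers' low 64-bit halves).
int32x4_t gather_split(const int32_t *p0, const int32_t *p1,
                       const int32_t *p2, const int32_t *p3) {
  int32x2_t lo = vmov_n_s32(*p0);
  lo = vld1_lane_s32(p1, lo, 1);
  int32x2_t hi = vmov_n_s32(*p2);
  hi = vld1_lane_s32(p3, hi, 1);
  return vcombine_s32(lo, hi);
}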
1 parent a6fb3b3 commit e8a891b

File tree: 10 files changed (996 additions, 346 deletions)


llvm/lib/Target/AArch64/AArch64InstrInfo.cpp

Lines changed: 265 additions & 0 deletions
@@ -20,6 +20,7 @@
 #include "Utils/AArch64BaseInfo.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/CodeGen/CFIInstBuilder.h"
 #include "llvm/CodeGen/LivePhysRegs.h"
@@ -35,6 +36,7 @@
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/RegisterScavenging.h"
 #include "llvm/CodeGen/StackMaps.h"
+#include "llvm/CodeGen/TargetOpcodes.h"
 #include "llvm/CodeGen/TargetRegisterInfo.h"
 #include "llvm/CodeGen/TargetSubtargetInfo.h"
 #include "llvm/IR/DebugInfoMetadata.h"
@@ -7351,6 +7353,9 @@ bool AArch64InstrInfo::isThroughputPattern(unsigned Pattern) const {
   case AArch64MachineCombinerPattern::MULSUBv2i32_indexed_OP2:
   case AArch64MachineCombinerPattern::MULSUBv4i32_indexed_OP1:
   case AArch64MachineCombinerPattern::MULSUBv4i32_indexed_OP2:
+  case AArch64MachineCombinerPattern::GATHER_LANE_i32:
+  case AArch64MachineCombinerPattern::GATHER_LANE_i16:
+  case AArch64MachineCombinerPattern::GATHER_LANE_i8:
     return true;
   } // end switch (Pattern)
   return false;
@@ -7391,11 +7396,252 @@ static bool getMiscPatterns(MachineInstr &Root,
   return false;
 }
 
+static bool getGatherPattern(MachineInstr &Root,
+                             SmallVectorImpl<unsigned> &Patterns,
+                             unsigned LoadLaneOpCode, unsigned NumLanes) {
+  const MachineFunction *MF = Root.getMF();
+
+  // Early exit if optimizing for size.
+  if (MF->getFunction().hasMinSize())
+    return false;
+
+  const MachineRegisterInfo &MRI = MF->getRegInfo();
+  const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();
+
+  // The root of the pattern must load into the last lane of the vector.
+  if (Root.getOperand(2).getImm() != NumLanes - 1)
+    return false;
+
+  // Check that we have loads into all lanes except lane 0.
+  // For each load we also want to check that:
+  // 1. It has a single non-debug use (since we will be replacing the
+  //    virtual register).
+  // 2. The addressing mode only uses a single offset register.
+  auto *CurrInstr = MRI.getUniqueVRegDef(Root.getOperand(1).getReg());
+  auto Range = llvm::seq<unsigned>(1, NumLanes - 1);
+  SmallSet<unsigned, 4> RemainingLanes(Range.begin(), Range.end());
+  while (!RemainingLanes.empty() && CurrInstr &&
+         CurrInstr->getOpcode() == LoadLaneOpCode &&
+         MRI.hasOneNonDBGUse(CurrInstr->getOperand(0).getReg()) &&
+         CurrInstr->getNumOperands() == 4) {
+    RemainingLanes.erase(CurrInstr->getOperand(2).getImm());
+    CurrInstr = MRI.getUniqueVRegDef(CurrInstr->getOperand(1).getReg());
+  }
+
+  if (!RemainingLanes.empty())
+    return false;
+
+  // Match the SUBREG_TO_REG sequence.
+  if (CurrInstr->getOpcode() != TargetOpcode::SUBREG_TO_REG)
+    return false;
+
+  // Verify that the SUBREG_TO_REG loads an integer into the first lane.
+  auto Lane0LoadReg = CurrInstr->getOperand(2).getReg();
+  unsigned SingleLaneSizeInBits = 128 / NumLanes;
+  if (TRI->getRegSizeInBits(Lane0LoadReg, MRI) != SingleLaneSizeInBits)
+    return false;
+
+  // Verify that it also has a single non-debug use.
+  if (!MRI.hasOneNonDBGUse(Lane0LoadReg))
+    return false;
+
+  switch (NumLanes) {
+  case 4:
+    Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i32);
+    break;
+  case 8:
+    Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i16);
+    break;
+  case 16:
+    Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i8);
+    break;
+  default:
+    llvm_unreachable("Got bad number of lanes for gather pattern.");
+  }
+
+  return true;
+}
+
+/// Search for patterns where we use LD1 instructions to load into
+/// separate lanes of a 128-bit Neon register. We can increase Memory Level
+/// Parallelism by loading into 2 Neon registers instead.
+static bool getLoadPatterns(MachineInstr &Root,
+                            SmallVectorImpl<unsigned> &Patterns) {
+
+  // The pattern searches for loads into single lanes.
+  switch (Root.getOpcode()) {
+  case AArch64::LD1i32:
+    return getGatherPattern(Root, Patterns, Root.getOpcode(), 4);
+  case AArch64::LD1i16:
+    return getGatherPattern(Root, Patterns, Root.getOpcode(), 8);
+  case AArch64::LD1i8:
+    return getGatherPattern(Root, Patterns, Root.getOpcode(), 16);
+  default:
+    return false;
+  }
+}
+
+static void
+generateGatherPattern(MachineInstr &Root,
+                      SmallVectorImpl<MachineInstr *> &InsInstrs,
+                      SmallVectorImpl<MachineInstr *> &DelInstrs,
+                      DenseMap<Register, unsigned> &InstrIdxForVirtReg,
+                      unsigned Pattern, unsigned NumLanes) {
+
+  MachineFunction &MF = *Root.getParent()->getParent();
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
+
+  // Gather the initial load instructions to build the pattern.
+  SmallVector<MachineInstr *, 16> LoadToLaneInstrs;
+  MachineInstr *CurrInstr = &Root;
+  for (unsigned i = 0; i < NumLanes - 1; ++i) {
+    LoadToLaneInstrs.push_back(CurrInstr);
+    CurrInstr = MRI.getUniqueVRegDef(CurrInstr->getOperand(1).getReg());
+  }
+
+  // Sort the load instructions according to the lane.
+  llvm::sort(LoadToLaneInstrs,
+             [](const MachineInstr *A, const MachineInstr *B) {
+               return A->getOperand(2).getImm() > B->getOperand(2).getImm();
+             });
+
+  MachineInstr *SubregToReg = CurrInstr;
+  LoadToLaneInstrs.push_back(
+      MRI.getUniqueVRegDef(SubregToReg->getOperand(2).getReg()));
+  auto LoadToLaneInstrsAscending = llvm::reverse(LoadToLaneInstrs);
+
+  const TargetRegisterClass *FPR128RegClass =
+      MRI.getRegClass(Root.getOperand(0).getReg());
+
+  auto LoadLaneToRegister = [&](MachineInstr *OriginalInstr,
+                                Register SrcRegister, unsigned Lane,
+                                Register OffsetRegister) {
+    auto NewRegister = MRI.createVirtualRegister(FPR128RegClass);
+    MachineInstrBuilder LoadIndexIntoRegister =
+        BuildMI(MF, MIMetadata(*OriginalInstr), TII->get(Root.getOpcode()),
+                NewRegister)
+            .addReg(SrcRegister)
+            .addImm(Lane)
+            .addReg(OffsetRegister, getKillRegState(true));
+    InstrIdxForVirtReg.insert(std::make_pair(NewRegister, InsInstrs.size()));
+    InsInstrs.push_back(LoadIndexIntoRegister);
+    return NewRegister;
+  };
+
+  // Helper to create a load instruction based on the opcode.
+  auto CreateLoadInstruction = [&](unsigned NumLanes, Register DestReg,
+                                   Register OffsetReg) -> MachineInstrBuilder {
+    unsigned Opcode;
+    switch (NumLanes) {
+    case 4:
+      Opcode = AArch64::LDRSui;
+      break;
+    case 8:
+      Opcode = AArch64::LDRHui;
+      break;
+    case 16:
+      Opcode = AArch64::LDRBui;
+      break;
+    default:
+      llvm_unreachable(
+          "Got unsupported number of lanes in machine-combiner gather pattern");
+    }
+    // Immediate offset load.
+    return BuildMI(MF, MIMetadata(Root), TII->get(Opcode), DestReg)
+        .addReg(OffsetReg)
+        .addImm(0); // immediate offset
+  };
+
+  // Load the remaining lanes into register 0.
+  auto LanesToLoadToReg0 =
+      llvm::make_range(LoadToLaneInstrsAscending.begin() + 1,
+                       LoadToLaneInstrsAscending.begin() + NumLanes / 2);
+  auto PrevReg = SubregToReg->getOperand(0).getReg();
+  for (auto [Index, LoadInstr] : llvm::enumerate(LanesToLoadToReg0)) {
+    PrevReg = LoadLaneToRegister(LoadInstr, PrevReg, Index + 1,
+                                 LoadInstr->getOperand(3).getReg());
+    DelInstrs.push_back(LoadInstr);
+  }
+  auto LastLoadReg0 = PrevReg;
+
+  // First load into register 1. Perform an LDRSui to zero out the upper
+  // lanes in a single instruction.
+  auto Lane0Load = *LoadToLaneInstrsAscending.begin();
+  auto OriginalSplitLoad =
+      *std::next(LoadToLaneInstrsAscending.begin(), NumLanes / 2);
+  auto DestRegForMiddleIndex = MRI.createVirtualRegister(
+      MRI.getRegClass(Lane0Load->getOperand(0).getReg()));
+
+  MachineInstrBuilder MiddleIndexLoadInstr =
+      CreateLoadInstruction(NumLanes, DestRegForMiddleIndex,
+                            OriginalSplitLoad->getOperand(3).getReg());
+
+  InstrIdxForVirtReg.insert(
+      std::make_pair(DestRegForMiddleIndex, InsInstrs.size()));
+  InsInstrs.push_back(MiddleIndexLoadInstr);
+  DelInstrs.push_back(OriginalSplitLoad);
+
+  // SUBREG_TO_REG instruction for register 1.
+  auto DestRegForSubregToReg = MRI.createVirtualRegister(FPR128RegClass);
+  unsigned SubregType;
+  switch (NumLanes) {
+  case 4:
+    SubregType = AArch64::ssub;
+    break;
+  case 8:
+    SubregType = AArch64::hsub;
+    break;
+  case 16:
+    SubregType = AArch64::bsub;
+    break;
+  default:
+    llvm_unreachable(
+        "Got invalid NumLanes for machine-combiner gather pattern");
+  }
+
+  auto SubRegToRegInstr =
+      BuildMI(MF, MIMetadata(Root), TII->get(SubregToReg->getOpcode()),
+              DestRegForSubregToReg)
+          .addImm(0)
+          .addReg(DestRegForMiddleIndex, getKillRegState(true))
+          .addImm(SubregType);
+  InstrIdxForVirtReg.insert(
+      std::make_pair(DestRegForSubregToReg, InsInstrs.size()));
+  InsInstrs.push_back(SubRegToRegInstr);
+
+  // Load the remaining lanes into register 1.
+  auto LanesToLoadToReg1 =
+      llvm::make_range(LoadToLaneInstrsAscending.begin() + NumLanes / 2 + 1,
+                       LoadToLaneInstrsAscending.end());
+  PrevReg = SubRegToRegInstr->getOperand(0).getReg();
+  for (auto [Index, LoadInstr] : llvm::enumerate(LanesToLoadToReg1)) {
+    PrevReg = LoadLaneToRegister(LoadInstr, PrevReg, Index + 1,
+                                 LoadInstr->getOperand(3).getReg());
+    if (Index == NumLanes / 2 - 2) {
+      break;
+    }
+    DelInstrs.push_back(LoadInstr);
+  }
+  auto LastLoadReg1 = PrevReg;
+
+  // Create the final zip instruction to combine the results.
+  MachineInstrBuilder ZipInstr =
+      BuildMI(MF, MIMetadata(Root), TII->get(AArch64::ZIP1v2i64),
+              Root.getOperand(0).getReg())
+          .addReg(LastLoadReg0)
+          .addReg(LastLoadReg1);
+  InsInstrs.push_back(ZipInstr);
+}
+
 CombinerObjective
 AArch64InstrInfo::getCombinerObjective(unsigned Pattern) const {
   switch (Pattern) {
   case AArch64MachineCombinerPattern::SUBADD_OP1:
   case AArch64MachineCombinerPattern::SUBADD_OP2:
+  case AArch64MachineCombinerPattern::GATHER_LANE_i32:
+  case AArch64MachineCombinerPattern::GATHER_LANE_i16:
+  case AArch64MachineCombinerPattern::GATHER_LANE_i8:
     return CombinerObjective::MustReduceDepth;
   default:
     return TargetInstrInfo::getCombinerObjective(Pattern);
@@ -7425,6 +7671,10 @@ bool AArch64InstrInfo::getMachineCombinerPatterns(
   if (getMiscPatterns(Root, Patterns))
     return true;
 
+  // Load patterns
+  if (getLoadPatterns(Root, Patterns))
+    return true;
+
   return TargetInstrInfo::getMachineCombinerPatterns(Root, Patterns,
                                                      DoRegPressureReduce);
 }
@@ -8680,6 +8930,21 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     MUL = genFNegatedMAD(MF, MRI, TII, Root, InsInstrs);
     break;
   }
+  case AArch64MachineCombinerPattern::GATHER_LANE_i32: {
+    generateGatherPattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
+                          Pattern, 4);
+    break;
+  }
+  case AArch64MachineCombinerPattern::GATHER_LANE_i16: {
+    generateGatherPattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
+                          Pattern, 8);
+    break;
+  }
+  case AArch64MachineCombinerPattern::GATHER_LANE_i8: {
+    generateGatherPattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
+                          Pattern, 16);
+    break;
+  }
 
   } // end switch (Pattern)
   // Record MUL and ADD/SUB for deletion
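To make the shape of the emitted sequence concrete, here is a small standalone sketch (a toy model, not LLVM code; the function name and output wording are invented) of the schedule generateGatherPattern builds for N lanes: lanes 1 through N/2-1 chain onto register 0 after its lane-0 scalar load, lane N/2 starts register 1 with a scalar load that zeroes the upper lanes, the remaining lanes chain onto register 1, and a ZIP1 of the two registers' low 64-bit halves yields the final vector.

#include <cstdio>

// Print the two independent load chains plus the final ZIP1 that the
// combiner emits for an N-lane gather (N is 4, 8, or 16).
static void printGatherPlan(unsigned N) {
  std::printf("reg0: scalar load -> lane 0 (via SUBREG_TO_REG)\n");
  for (unsigned Lane = 1; Lane < N / 2; ++Lane)
    std::printf("reg0: LD1 lane load -> lane %u\n", Lane);
  std::printf("reg1: scalar load -> lane 0 (zeroes the upper lanes)\n");
  for (unsigned Lane = N / 2 + 1; Lane < N; ++Lane)
    std::printf("reg1: LD1 lane load -> lane %u\n", Lane - N / 2);
  std::printf("ZIP1 v2i64 of reg0 and reg1 -> full %u-lane vector\n", N);
}

int main() {
  printGatherPlan(4); // the GATHER_LANE_i32 case
  return 0;
}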

llvm/lib/Target/AArch64/AArch64InstrInfo.h

Lines changed: 4 additions & 0 deletions
@@ -172,6 +172,10 @@ enum AArch64MachineCombinerPattern : unsigned {
   FMULv8i16_indexed_OP2,
 
   FNMADD,
+
+  GATHER_LANE_i32,
+  GATHER_LANE_i16,
+  GATHER_LANE_i8
 };
 class AArch64InstrInfo final : public AArch64GenInstrInfo {
   const AArch64RegisterInfo RI;
