Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
334 changes: 334 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
#include "Utils/AArch64BaseInfo.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/CodeGen/CFIInstBuilder.h"
#include "llvm/CodeGen/LivePhysRegs.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
Expand Down Expand Up @@ -83,6 +85,11 @@ static cl::opt<unsigned>
BDisplacementBits("aarch64-b-offset-bits", cl::Hidden, cl::init(26),
cl::desc("Restrict range of B instructions (DEBUG)"));

static cl::opt<unsigned> GatherOptSearchLimit(
"aarch64-search-limit", cl::Hidden, cl::init(2048),
cl::desc("Restrict range of instructions to search for the "
"machine-combiner gather pattern optimization"));

AArch64InstrInfo::AArch64InstrInfo(const AArch64Subtarget &STI)
: AArch64GenInstrInfo(AArch64::ADJCALLSTACKDOWN, AArch64::ADJCALLSTACKUP,
AArch64::CATCHRET),
Expand Down Expand Up @@ -7412,11 +7419,319 @@ static bool getMiscPatterns(MachineInstr &Root,
return false;
}

/// Check if the given instruction forms a gather load pattern that can be
/// optimized for better Memory-Level Parallelism (MLP). This function
/// identifies chains of NEON lane load instructions that load data from
/// different memory addresses into individual lanes of a 128-bit vector
/// register, then attempts to split the pattern into parallel loads to break
/// the serial dependency between instructions.
///
/// Pattern Matched:
/// Initial scalar load -> SUBREG_TO_REG (lane 0) -> LD1i* (lane 1) ->
/// LD1i* (lane 2) -> ... -> LD1i* (lane N-1, Root)
///
/// Transformed Into:
/// Two parallel vector loads using fewer lanes each, followed by ZIP1v2i64
/// to combine the results, enabling better memory-level parallelism.
///
/// Supported Element Types:
/// - 32-bit elements (LD1i32, 4 lanes total)
/// - 16-bit elements (LD1i16, 8 lanes total)
/// - 8-bit elements (LD1i8, 16 lanes total)
static bool getGatherLanePattern(MachineInstr &Root,
SmallVectorImpl<unsigned> &Patterns,
unsigned LoadLaneOpCode, unsigned NumLanes) {
const MachineFunction *MF = Root.getMF();

// Early exit if optimizing for size.
if (MF->getFunction().hasMinSize())
return false;

const MachineRegisterInfo &MRI = MF->getRegInfo();
const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();

// The root of the pattern must load into the last lane of the vector.
if (Root.getOperand(2).getImm() != NumLanes - 1)
return false;

// Check that we have load into all lanes except lane 0.
// For each load we also want to check that:
// 1. It has a single non-debug use (since we will be replacing the virtual
// register)
// 2. That the addressing mode only uses a single pointer operand
auto *CurrInstr = MRI.getUniqueVRegDef(Root.getOperand(1).getReg());
auto Range = llvm::seq<unsigned>(1, NumLanes - 1);
SmallSet<unsigned, 16> RemainingLanes(Range.begin(), Range.end());
SmallVector<const MachineInstr *, 16> LoadInstrs;
while (!RemainingLanes.empty() && CurrInstr &&
CurrInstr->getOpcode() == LoadLaneOpCode &&
MRI.hasOneNonDBGUse(CurrInstr->getOperand(0).getReg()) &&
CurrInstr->getNumOperands() == 4) {
RemainingLanes.erase(CurrInstr->getOperand(2).getImm());
LoadInstrs.push_back(CurrInstr);
CurrInstr = MRI.getUniqueVRegDef(CurrInstr->getOperand(1).getReg());
}

// Check that we have found a match for lanes N-1.. 1.
if (!RemainingLanes.empty())
return false;

// Match the SUBREG_TO_REG sequence.
if (CurrInstr->getOpcode() != TargetOpcode::SUBREG_TO_REG)
return false;

// Verify that the subreg to reg loads an integer into the first lane.
auto Lane0LoadReg = CurrInstr->getOperand(2).getReg();
unsigned SingleLaneSizeInBits = 128 / NumLanes;
if (TRI->getRegSizeInBits(Lane0LoadReg, MRI) != SingleLaneSizeInBits)
return false;

// Verify that it also has a single non debug use.
if (!MRI.hasOneNonDBGUse(Lane0LoadReg))
return false;

LoadInstrs.push_back(MRI.getUniqueVRegDef(Lane0LoadReg));

// If there is any chance of aliasing, do not apply the pattern.
// Walk backward through the MBB starting from Root.
// Exit early if we've encountered all load instructions or hit the search
// limit.
auto MBBItr = Root.getIterator();
unsigned RemainingSteps = GatherOptSearchLimit;
SmallSet<const MachineInstr *, 16> RemainingLoadInstrs;
RemainingLoadInstrs.insert(LoadInstrs.begin(), LoadInstrs.end());
const MachineBasicBlock *MBB = Root.getParent();

for (; MBBItr != MBB->begin() && RemainingSteps > 0 &&
!RemainingLoadInstrs.empty();
--MBBItr, --RemainingSteps) {
const MachineInstr &CurrInstr = *MBBItr;

// Remove this instruction from remaining loads if it's one we're tracking.
RemainingLoadInstrs.erase(&CurrInstr);

// Check for potential aliasing with any of the load instructions to
// optimize.
if (CurrInstr.isLoadFoldBarrier())
return false;
}

// If we hit the search limit without finding all load instructions,
// don't match the pattern.
if (RemainingSteps == 0 && !RemainingLoadInstrs.empty())
return false;

switch (NumLanes) {
case 4:
Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i32);
break;
case 8:
Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i16);
break;
case 16:
Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i8);
break;
default:
llvm_unreachable("Got bad number of lanes for gather pattern.");
}

return true;
}

/// Search for patterns of LD instructions we can optimize.
static bool getLoadPatterns(MachineInstr &Root,
SmallVectorImpl<unsigned> &Patterns) {

// The pattern searches for loads into single lanes.
switch (Root.getOpcode()) {
case AArch64::LD1i32:
return getGatherLanePattern(Root, Patterns, Root.getOpcode(), 4);
case AArch64::LD1i16:
return getGatherLanePattern(Root, Patterns, Root.getOpcode(), 8);
case AArch64::LD1i8:
return getGatherLanePattern(Root, Patterns, Root.getOpcode(), 16);
default:
return false;
}
}

/// Generate optimized instruction sequence for gather load patterns to improve
/// Memory-Level Parallelism (MLP). This function transforms a chain of
/// sequential NEON lane loads into parallel vector loads that can execute
/// concurrently.
static void
generateGatherLanePattern(MachineInstr &Root,
SmallVectorImpl<MachineInstr *> &InsInstrs,
SmallVectorImpl<MachineInstr *> &DelInstrs,
DenseMap<Register, unsigned> &InstrIdxForVirtReg,
unsigned Pattern, unsigned NumLanes) {
MachineFunction &MF = *Root.getParent()->getParent();
MachineRegisterInfo &MRI = MF.getRegInfo();
const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();

// Gather the initial load instructions to build the pattern.
SmallVector<MachineInstr *, 16> LoadToLaneInstrs;
MachineInstr *CurrInstr = &Root;
for (unsigned i = 0; i < NumLanes - 1; ++i) {
LoadToLaneInstrs.push_back(CurrInstr);
CurrInstr = MRI.getUniqueVRegDef(CurrInstr->getOperand(1).getReg());
}

// Sort the load instructions according to the lane.
llvm::sort(LoadToLaneInstrs,
[](const MachineInstr *A, const MachineInstr *B) {
return A->getOperand(2).getImm() > B->getOperand(2).getImm();
});

MachineInstr *SubregToReg = CurrInstr;
LoadToLaneInstrs.push_back(
MRI.getUniqueVRegDef(SubregToReg->getOperand(2).getReg()));
auto LoadToLaneInstrsAscending = llvm::reverse(LoadToLaneInstrs);

const TargetRegisterClass *FPR128RegClass =
MRI.getRegClass(Root.getOperand(0).getReg());

// Helper lambda to create a LD1 instruction.
auto CreateLD1Instruction = [&](MachineInstr *OriginalInstr,
Register SrcRegister, unsigned Lane,
Register OffsetRegister,
bool OffsetRegisterKillState) {
auto NewRegister = MRI.createVirtualRegister(FPR128RegClass);
MachineInstrBuilder LoadIndexIntoRegister =
BuildMI(MF, MIMetadata(*OriginalInstr), TII->get(Root.getOpcode()),
NewRegister)
.addReg(SrcRegister)
.addImm(Lane)
.addReg(OffsetRegister, getKillRegState(OffsetRegisterKillState));
InstrIdxForVirtReg.insert(std::make_pair(NewRegister, InsInstrs.size()));
InsInstrs.push_back(LoadIndexIntoRegister);
return NewRegister;
};

// Helper to create load instruction based on the NumLanes in the NEON
// register we are rewriting.
auto CreateLDRInstruction = [&](unsigned NumLanes, Register DestReg,
Register OffsetReg,
bool KillState) -> MachineInstrBuilder {
unsigned Opcode;
switch (NumLanes) {
case 4:
Opcode = AArch64::LDRSui;
break;
case 8:
Opcode = AArch64::LDRHui;
break;
case 16:
Opcode = AArch64::LDRBui;
break;
default:
llvm_unreachable(
"Got unsupported number of lanes in machine-combiner gather pattern");
}
// Immediate offset load
return BuildMI(MF, MIMetadata(Root), TII->get(Opcode), DestReg)
.addReg(OffsetReg)
.addImm(0);
};

// Load the remaining lanes into register 0.
auto LanesToLoadToReg0 =
llvm::make_range(LoadToLaneInstrsAscending.begin() + 1,
LoadToLaneInstrsAscending.begin() + NumLanes / 2);
Register PrevReg = SubregToReg->getOperand(0).getReg();
for (auto [Index, LoadInstr] : llvm::enumerate(LanesToLoadToReg0)) {
const MachineOperand &OffsetRegOperand = LoadInstr->getOperand(3);
PrevReg = CreateLD1Instruction(LoadInstr, PrevReg, Index + 1,
OffsetRegOperand.getReg(),
OffsetRegOperand.isKill());
DelInstrs.push_back(LoadInstr);
}
Register LastLoadReg0 = PrevReg;

// First load into register 1. Perform an integer load to zero out the upper
// lanes in a single instruction.
MachineInstr *Lane0Load = *LoadToLaneInstrsAscending.begin();
MachineInstr *OriginalSplitLoad =
*std::next(LoadToLaneInstrsAscending.begin(), NumLanes / 2);
Register DestRegForMiddleIndex = MRI.createVirtualRegister(
MRI.getRegClass(Lane0Load->getOperand(0).getReg()));

const MachineOperand &OriginalSplitToLoadOffsetOperand =
OriginalSplitLoad->getOperand(3);
MachineInstrBuilder MiddleIndexLoadInstr =
CreateLDRInstruction(NumLanes, DestRegForMiddleIndex,
OriginalSplitToLoadOffsetOperand.getReg(),
OriginalSplitToLoadOffsetOperand.isKill());

InstrIdxForVirtReg.insert(
std::make_pair(DestRegForMiddleIndex, InsInstrs.size()));
InsInstrs.push_back(MiddleIndexLoadInstr);
DelInstrs.push_back(OriginalSplitLoad);

// Subreg To Reg instruction for register 1.
Register DestRegForSubregToReg = MRI.createVirtualRegister(FPR128RegClass);
unsigned SubregType;
switch (NumLanes) {
case 4:
SubregType = AArch64::ssub;
break;
case 8:
SubregType = AArch64::hsub;
break;
case 16:
SubregType = AArch64::bsub;
break;
default:
llvm_unreachable(
"Got invalid NumLanes for machine-combiner gather pattern");
}

auto SubRegToRegInstr =
BuildMI(MF, MIMetadata(Root), TII->get(SubregToReg->getOpcode()),
DestRegForSubregToReg)
.addImm(0)
.addReg(DestRegForMiddleIndex, getKillRegState(true))
.addImm(SubregType);
InstrIdxForVirtReg.insert(
std::make_pair(DestRegForSubregToReg, InsInstrs.size()));
InsInstrs.push_back(SubRegToRegInstr);

// Load remaining lanes into register 1.
auto LanesToLoadToReg1 =
llvm::make_range(LoadToLaneInstrsAscending.begin() + NumLanes / 2 + 1,
LoadToLaneInstrsAscending.end());
PrevReg = SubRegToRegInstr->getOperand(0).getReg();
for (auto [Index, LoadInstr] : llvm::enumerate(LanesToLoadToReg1)) {
const MachineOperand &OffsetRegOperand = LoadInstr->getOperand(3);
PrevReg = CreateLD1Instruction(LoadInstr, PrevReg, Index + 1,
OffsetRegOperand.getReg(),
OffsetRegOperand.isKill());

// Do not add the last reg to DelInstrs - it will be removed later.
if (Index == NumLanes / 2 - 2) {
break;
}
DelInstrs.push_back(LoadInstr);
}
Register LastLoadReg1 = PrevReg;

// Create the final zip instruction to combine the results.
MachineInstrBuilder ZipInstr =
BuildMI(MF, MIMetadata(Root), TII->get(AArch64::ZIP1v2i64),
Root.getOperand(0).getReg())
.addReg(LastLoadReg0)
.addReg(LastLoadReg1);
InsInstrs.push_back(ZipInstr);
}

CombinerObjective
AArch64InstrInfo::getCombinerObjective(unsigned Pattern) const {
switch (Pattern) {
case AArch64MachineCombinerPattern::SUBADD_OP1:
case AArch64MachineCombinerPattern::SUBADD_OP2:
case AArch64MachineCombinerPattern::GATHER_LANE_i32:
case AArch64MachineCombinerPattern::GATHER_LANE_i16:
case AArch64MachineCombinerPattern::GATHER_LANE_i8:
return CombinerObjective::MustReduceDepth;
default:
return TargetInstrInfo::getCombinerObjective(Pattern);
Expand Down Expand Up @@ -7446,6 +7761,10 @@ bool AArch64InstrInfo::getMachineCombinerPatterns(
if (getMiscPatterns(Root, Patterns))
return true;

// Load patterns
if (getLoadPatterns(Root, Patterns))
return true;

return TargetInstrInfo::getMachineCombinerPatterns(Root, Patterns,
DoRegPressureReduce);
}
Expand Down Expand Up @@ -8701,6 +9020,21 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
MUL = genFNegatedMAD(MF, MRI, TII, Root, InsInstrs);
break;
}
case AArch64MachineCombinerPattern::GATHER_LANE_i32: {
generateGatherLanePattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
Pattern, 4);
break;
}
case AArch64MachineCombinerPattern::GATHER_LANE_i16: {
generateGatherLanePattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
Pattern, 8);
break;
}
case AArch64MachineCombinerPattern::GATHER_LANE_i8: {
generateGatherLanePattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
Pattern, 16);
break;
}

} // end switch (Pattern)
// Record MUL and ADD/SUB for deletion
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ enum AArch64MachineCombinerPattern : unsigned {
FMULv8i16_indexed_OP2,

FNMADD,

GATHER_LANE_i32,
GATHER_LANE_i16,
GATHER_LANE_i8
};
class AArch64InstrInfo final : public AArch64GenInstrInfo {
const AArch64RegisterInfo RI;
Expand Down
Loading