Merged
llvm/lib/Target/AArch64/AArch64InstrInfo.cpp (265 additions, 0 deletions)
@@ -20,6 +20,7 @@
#include "Utils/AArch64BaseInfo.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGen/CFIInstBuilder.h"
#include "llvm/CodeGen/LivePhysRegs.h"
@@ -35,6 +36,7 @@
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/RegisterScavenging.h"
#include "llvm/CodeGen/StackMaps.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/DebugInfoMetadata.h"
@@ -7351,6 +7353,9 @@ bool AArch64InstrInfo::isThroughputPattern(unsigned Pattern) const {
case AArch64MachineCombinerPattern::MULSUBv2i32_indexed_OP2:
case AArch64MachineCombinerPattern::MULSUBv4i32_indexed_OP1:
case AArch64MachineCombinerPattern::MULSUBv4i32_indexed_OP2:
case AArch64MachineCombinerPattern::GATHER_LANE_i32:
case AArch64MachineCombinerPattern::GATHER_LANE_i16:
case AArch64MachineCombinerPattern::GATHER_LANE_i8:
return true;
} // end switch (Pattern)
return false;
@@ -7391,11 +7396,252 @@ static bool getMiscPatterns(MachineInstr &Root,
return false;
}

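/// Check whether \p Root, an LD1 lane load into the last lane of a 128-bit
/// vector, is fed by a chain of single-use LD1 lane loads covering every other
/// lane and terminated by a SUBREG_TO_REG whose lane-0 source has the expected
/// lane width and a single non-debug use. If so, record the corresponding
/// GATHER_LANE_* pattern in \p Patterns.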
static bool getGatherPattern(MachineInstr &Root,
SmallVectorImpl<unsigned> &Patterns,
unsigned LoadLaneOpCode, unsigned NumLanes) {
const MachineFunction *MF = Root.getMF();

// Early exit if optimizing for size.
if (MF->getFunction().hasMinSize())
return false;

const MachineRegisterInfo &MRI = MF->getRegInfo();
const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();

// The root of the pattern must load into the last lane of the vector.
if (Root.getOperand(2).getImm() != NumLanes - 1)
return false;

// Check that we have a load into all lanes except lane 0.
// For each load we also want to check that:
// 1. It has a single non-debug use (since we will be replacing the virtual
// register)
// 2. That the addressing mode only uses a single offset register.
auto *CurrInstr = MRI.getUniqueVRegDef(Root.getOperand(1).getReg());
auto Range = llvm::seq<unsigned>(1, NumLanes - 1);
SmallSet<unsigned, 4> RemainingLanes(Range.begin(), Range.end());
while (!RemainingLanes.empty() && CurrInstr &&
CurrInstr->getOpcode() == LoadLaneOpCode &&
MRI.hasOneNonDBGUse(CurrInstr->getOperand(0).getReg()) &&
CurrInstr->getNumOperands() == 4) {
RemainingLanes.erase(CurrInstr->getOperand(2).getImm());
CurrInstr = MRI.getUniqueVRegDef(CurrInstr->getOperand(1).getReg());
}

if (!RemainingLanes.empty())
return false;

// Match the SUBREG_TO_REG sequence.
if (CurrInstr->getOpcode() != TargetOpcode::SUBREG_TO_REG)
return false;

// Verify that the SUBREG_TO_REG loads an integer into the first lane.
auto Lane0LoadReg = CurrInstr->getOperand(2).getReg();
unsigned SingleLaneSizeInBits = 128 / NumLanes;
if (TRI->getRegSizeInBits(Lane0LoadReg, MRI) != SingleLaneSizeInBits)
return false;

// Verify that it also has a single non-debug use.
if (!MRI.hasOneNonDBGUse(Lane0LoadReg))
return false;

switch (NumLanes) {
case 4:
Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i32);
break;
case 8:
Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i16);
break;
case 16:
Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i8);
break;
default:
llvm_unreachable("Got bad number of lanes for gather pattern.");
}

return true;
}

/// Search for patterns where we use LD1 instructions to load into
/// separate lanes of a 128-bit Neon register. We can increase memory-level
/// parallelism by loading into two Neon registers instead.
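/// As a rough illustration (register and pointer names below are invented for
/// the sketch, not taken from the patch), the i32 case rewrites a single
/// dependent chain such as
///   %l0:fpr32  = LDRSui %p0, 0
///   %v0:fpr128 = SUBREG_TO_REG 0, %l0, %subreg.ssub
///   %v1:fpr128 = LD1i32 %v0, 1, %p1
///   %v2:fpr128 = LD1i32 %v1, 2, %p2
///   %v3:fpr128 = LD1i32 %v2, 3, %p3
/// into two shorter, independent chains whose results are combined at the end:
///   %a1:fpr128 = LD1i32 %v0, 1, %p1
///   %l2:fpr32  = LDRSui %p2, 0
///   %b0:fpr128 = SUBREG_TO_REG 0, %l2, %subreg.ssub
///   %b1:fpr128 = LD1i32 %b0, 1, %p3
///   %v3:fpr128 = ZIP1v2i64 %a1, %b1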
static bool getLoadPatterns(MachineInstr &Root,
SmallVectorImpl<unsigned> &Patterns) {

// The pattern searches for loads into single lanes.
switch (Root.getOpcode()) {
case AArch64::LD1i32:
return getGatherPattern(Root, Patterns, Root.getOpcode(), 4);
case AArch64::LD1i16:
return getGatherPattern(Root, Patterns, Root.getOpcode(), 8);
case AArch64::LD1i8:
return getGatherPattern(Root, Patterns, Root.getOpcode(), 16);
default:
return false;
}
}

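/// Rewrite the gather pattern rooted at \p Root: split the single chain of
/// LD1 lane loads into two independent half-width chains, start the second
/// chain with a scalar load plus SUBREG_TO_REG, and combine the two halves
/// with a ZIP1v2i64 into the original destination register.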
static void
generateGatherPattern(MachineInstr &Root,
SmallVectorImpl<MachineInstr *> &InsInstrs,
SmallVectorImpl<MachineInstr *> &DelInstrs,
DenseMap<Register, unsigned> &InstrIdxForVirtReg,
unsigned Pattern, unsigned NumLanes) {

MachineFunction &MF = *Root.getParent()->getParent();
MachineRegisterInfo &MRI = MF.getRegInfo();
const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();

// Gather the initial load instructions to build the pattern.
SmallVector<MachineInstr *, 16> LoadToLaneInstrs;
MachineInstr *CurrInstr = &Root;
for (unsigned i = 0; i < NumLanes - 1; ++i) {
LoadToLaneInstrs.push_back(CurrInstr);
CurrInstr = MRI.getUniqueVRegDef(CurrInstr->getOperand(1).getReg());
}

// Sort the load instructions according to the lane.
llvm::sort(LoadToLaneInstrs,
[](const MachineInstr *A, const MachineInstr *B) {
return A->getOperand(2).getImm() > B->getOperand(2).getImm();
});

MachineInstr *SubregToReg = CurrInstr;
LoadToLaneInstrs.push_back(
MRI.getUniqueVRegDef(SubregToReg->getOperand(2).getReg()));
auto LoadToLaneInstrsAscending = llvm::reverse(LoadToLaneInstrs);

const TargetRegisterClass *FPR128RegClass =
MRI.getRegClass(Root.getOperand(0).getReg());

auto LoadLaneToRegister = [&](MachineInstr *OriginalInstr,
Register SrcRegister, unsigned Lane,
Register OffsetRegister) {
auto NewRegister = MRI.createVirtualRegister(FPR128RegClass);
MachineInstrBuilder LoadIndexIntoRegister =
BuildMI(MF, MIMetadata(*OriginalInstr), TII->get(Root.getOpcode()),
NewRegister)
.addReg(SrcRegister)
.addImm(Lane)
.addReg(OffsetRegister, getKillRegState(true));
InstrIdxForVirtReg.insert(std::make_pair(NewRegister, InsInstrs.size()));
InsInstrs.push_back(LoadIndexIntoRegister);
return NewRegister;
};

// Helper to create the scalar load instruction based on the number of lanes.
auto CreateLoadInstruction = [&](unsigned NumLanes, Register DestReg,
Register OffsetReg) -> MachineInstrBuilder {
unsigned Opcode;
switch (NumLanes) {
case 4:
Opcode = AArch64::LDRSui;
break;
case 8:
Opcode = AArch64::LDRHui;
break;
case 16:
Opcode = AArch64::LDRBui;
Contributor:
For 8 and 16 lanes, it may be worth splitting into more than 2 separate chains of loads?

Contributor Author:
Like I wrote above to @davemgreen - it kicks in less often because of the additional use of load ports. I wanted to keep the pattern (relatively) simple, so decided not to do that right now. I can add it in a follow up PR if you think it's a good idea.

Contributor:
Yeah sounds good as follow-up

break;
default:
llvm_unreachable(
"Got unsupported number of lanes in machine-combiner gather pattern");
}
// Immediate offset load
return BuildMI(MF, MIMetadata(Root), TII->get(Opcode), DestReg)
.addReg(OffsetReg)
.addImm(0); // immediate offset
};

// Load the remaining lanes into register 0.
auto LanesToLoadToReg0 =
llvm::make_range(LoadToLaneInstrsAscending.begin() + 1,
LoadToLaneInstrsAscending.begin() + NumLanes / 2);
auto PrevReg = SubregToReg->getOperand(0).getReg();
for (auto [Index, LoadInstr] : llvm::enumerate(LanesToLoadToReg0)) {
PrevReg = LoadLaneToRegister(LoadInstr, PrevReg, Index + 1,
LoadInstr->getOperand(3).getReg());
DelInstrs.push_back(LoadInstr);
}
auto LastLoadReg0 = PrevReg;

// First load into register 1. Perform a scalar load (LDRSui/LDRHui/LDRBui) to
// zero out the upper lanes in a single instruction.
auto Lane0Load = *LoadToLaneInstrsAscending.begin();
auto OriginalSplitLoad =
*std::next(LoadToLaneInstrsAscending.begin(), NumLanes / 2);
auto DestRegForMiddleIndex = MRI.createVirtualRegister(
MRI.getRegClass(Lane0Load->getOperand(0).getReg()));

MachineInstrBuilder MiddleIndexLoadInstr =
CreateLoadInstruction(NumLanes, DestRegForMiddleIndex,
OriginalSplitLoad->getOperand(3).getReg());

InstrIdxForVirtReg.insert(
std::make_pair(DestRegForMiddleIndex, InsInstrs.size()));
InsInstrs.push_back(MiddleIndexLoadInstr);
DelInstrs.push_back(OriginalSplitLoad);

// SUBREG_TO_REG instruction for register 1.
auto DestRegForSubregToReg = MRI.createVirtualRegister(FPR128RegClass);
unsigned SubregType;
switch (NumLanes) {
case 4:
SubregType = AArch64::ssub;
break;
case 8:
SubregType = AArch64::hsub;
break;
case 16:
SubregType = AArch64::bsub;
break;
default:
llvm_unreachable(
"Got invalid NumLanes for machine-combiner gather pattern");
}

auto SubRegToRegInstr =
BuildMI(MF, MIMetadata(Root), TII->get(SubregToReg->getOpcode()),
DestRegForSubregToReg)
.addImm(0)
.addReg(DestRegForMiddleIndex, getKillRegState(true))
.addImm(SubregType);
InstrIdxForVirtReg.insert(
std::make_pair(DestRegForSubregToReg, InsInstrs.size()));
InsInstrs.push_back(SubRegToRegInstr);

// Load remaining lanes into register 1.
auto LanesToLoadToReg1 =
llvm::make_range(LoadToLaneInstrsAscending.begin() + NumLanes / 2 + 1,
LoadToLaneInstrsAscending.end());
PrevReg = SubRegToRegInstr->getOperand(0).getReg();
for (auto [Index, LoadInstr] : llvm::enumerate(LanesToLoadToReg1)) {
PrevReg = LoadLaneToRegister(LoadInstr, PrevReg, Index + 1,
LoadInstr->getOperand(3).getReg());
if (Index == NumLanes / 2 - 2) {
break;
}
DelInstrs.push_back(LoadInstr);
}
auto LastLoadReg1 = PrevReg;

// Create the final zip instruction to combine the results.
MachineInstrBuilder ZipInstr =
BuildMI(MF, MIMetadata(Root), TII->get(AArch64::ZIP1v2i64),
Root.getOperand(0).getReg())
.addReg(LastLoadReg0)
.addReg(LastLoadReg1);
InsInstrs.push_back(ZipInstr);
}

CombinerObjective
AArch64InstrInfo::getCombinerObjective(unsigned Pattern) const {
switch (Pattern) {
case AArch64MachineCombinerPattern::SUBADD_OP1:
case AArch64MachineCombinerPattern::SUBADD_OP2:
case AArch64MachineCombinerPattern::GATHER_LANE_i32:
case AArch64MachineCombinerPattern::GATHER_LANE_i16:
case AArch64MachineCombinerPattern::GATHER_LANE_i8:
return CombinerObjective::MustReduceDepth;
default:
return TargetInstrInfo::getCombinerObjective(Pattern);
@@ -7425,6 +7671,10 @@ bool AArch64InstrInfo::getMachineCombinerPatterns(
if (getMiscPatterns(Root, Patterns))
return true;

// Load patterns
if (getLoadPatterns(Root, Patterns))
return true;

return TargetInstrInfo::getMachineCombinerPatterns(Root, Patterns,
DoRegPressureReduce);
}
@@ -8680,6 +8930,21 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
MUL = genFNegatedMAD(MF, MRI, TII, Root, InsInstrs);
break;
}
case AArch64MachineCombinerPattern::GATHER_LANE_i32: {
generateGatherPattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
Pattern, 4);
break;
}
case AArch64MachineCombinerPattern::GATHER_LANE_i16: {
generateGatherPattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
Pattern, 8);
break;
}
case AArch64MachineCombinerPattern::GATHER_LANE_i8: {
generateGatherPattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
Pattern, 16);
break;
}

} // end switch (Pattern)
// Record MUL and ADD/SUB for deletion
llvm/lib/Target/AArch64/AArch64InstrInfo.h (4 additions, 0 deletions)
@@ -172,6 +172,10 @@ enum AArch64MachineCombinerPattern : unsigned {
FMULv8i16_indexed_OP2,

FNMADD,

GATHER_LANE_i32,
GATHER_LANE_i16,
GATHER_LANE_i8
};
class AArch64InstrInfo final : public AArch64GenInstrInfo {
const AArch64RegisterInfo RI;