Skip to content

Commit e30245a

Browse files
committed
Support additional data types
1 parent 8dd153f commit e30245a

File tree

3 files changed

+385
-98
lines changed

3 files changed

+385
-98
lines changed

llvm/lib/Target/AArch64/AArch64InstrInfo.cpp

Lines changed: 199 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -7279,7 +7279,9 @@ bool AArch64InstrInfo::isThroughputPattern(unsigned Pattern) const {
72797279
case AArch64MachineCombinerPattern::MULSUBv2i32_indexed_OP2:
72807280
case AArch64MachineCombinerPattern::MULSUBv4i32_indexed_OP1:
72817281
case AArch64MachineCombinerPattern::MULSUBv4i32_indexed_OP2:
7282-
case AArch64MachineCombinerPattern::SPLIT_LD:
7282+
case AArch64MachineCombinerPattern::GATHER_i32:
7283+
case AArch64MachineCombinerPattern::GATHER_i16:
7284+
case AArch64MachineCombinerPattern::GATHER_i8:
72837285
return true;
72847286
} // end switch (Pattern)
72857287
return false;
@@ -7320,32 +7322,24 @@ static bool getMiscPatterns(MachineInstr &Root,
73207322
return false;
73217323
}
73227324

7323-
/// Search for patterns where we use LD1i32 instructions to load into
7324-
/// 4 separate lanes of a 128 bit Neon register. We can increase ILP
7325-
/// by loading into 2 Neon registers instead.
7326-
static bool getLoadPatterns(MachineInstr &Root,
7327-
SmallVectorImpl<unsigned> &Patterns) {
7325+
static bool getGatherPattern(MachineInstr &Root,
7326+
SmallVectorImpl<unsigned> &Patterns,
7327+
unsigned LoadLaneOpCode,
7328+
unsigned NumLanes) {
73287329
const MachineRegisterInfo &MRI = Root.getMF()->getRegInfo();
73297330
const TargetRegisterInfo *TRI =
73307331
Root.getMF()->getSubtarget().getRegisterInfo();
7331-
// Enable this only on Darwin targets, where it should be profitable. Other
7332-
// targets can remove this check if it is profitable there as well.
7333-
if (!Root.getMF()->getTarget().getTargetTriple().isOSDarwin())
7334-
return false;
7335-
7336-
// The pattern searches for loads into single lanes.
7337-
if (Root.getOpcode() != AArch64::LD1i32)
7338-
return false;
73397332

73407333
// The root of the pattern must load into the last lane of the vector.
7341-
if (Root.getOperand(2).getImm() != 3)
7334+
if (Root.getOperand(2).getImm() != NumLanes - 1)
73427335
return false;
73437336

73447337
// Check that we have load into all lanes except lane 0.
73457338
auto *CurrInstr = MRI.getUniqueVRegDef(Root.getOperand(1).getReg());
7346-
SmallSet<unsigned, 4> RemainingLanes({1, 2});
7339+
auto Range = llvm::seq<unsigned>(1, NumLanes - 1);
7340+
SmallSet<unsigned, 4> RemainingLanes(Range.begin(), Range.end());
73477341
while (RemainingLanes.begin() != RemainingLanes.end() &&
7348-
CurrInstr->getOpcode() == AArch64::LD1i32 &&
7342+
CurrInstr->getOpcode() == LoadLaneOpCode &&
73497343
MRI.hasOneNonDBGUse(CurrInstr->getOperand(0).getReg()) &&
73507344
CurrInstr->getNumOperands() == 4) {
73517345
RemainingLanes.erase(CurrInstr->getOperand(2).getImm());
@@ -7361,23 +7355,194 @@ static bool getLoadPatterns(MachineInstr &Root,
73617355

73627356
// Verify that the subreg to reg loads an i32 into the first lane.
73637357
auto Lane0LoadReg = CurrInstr->getOperand(2).getReg();
7364-
if (TRI->getRegSizeInBits(Lane0LoadReg, MRI) != 32)
7358+
unsigned SingleLaneSizeInBits = 128 / NumLanes;
7359+
if (TRI->getRegSizeInBits(Lane0LoadReg, MRI) != SingleLaneSizeInBits)
73657360
return false;
73667361

73677362
// Verify that it also has a single non debug use.
73687363
if (!MRI.hasOneNonDBGUse(Lane0LoadReg))
73697364
return false;
73707365

7371-
Patterns.push_back(AArch64MachineCombinerPattern::SPLIT_LD);
7366+
switch (NumLanes) {
7367+
case 4:
7368+
Patterns.push_back(AArch64MachineCombinerPattern::GATHER_i32);
7369+
break;
7370+
case 8:
7371+
Patterns.push_back(AArch64MachineCombinerPattern::GATHER_i16);
7372+
break;
7373+
case 16:
7374+
Patterns.push_back(AArch64MachineCombinerPattern::GATHER_i8);
7375+
break;
7376+
default:
7377+
llvm_unreachable("Got bad number of lanes for gather pattern.");
7378+
}
7379+
73727380
return true;
73737381
}
73747382

7383+
/// Search for patterns where we use LD1i32 instructions to load into
7384+
/// 4 separate lanes of a 128 bit Neon register. We can increase ILP
7385+
/// by loading into 2 Neon registers instead.
7386+
static bool getLoadPatterns(MachineInstr &Root,
7387+
SmallVectorImpl<unsigned> &Patterns) {
7388+
// Enable this only on Darwin targets, where it should be profitable. Other
7389+
// targets can remove this check if it is profitable there as well.
7390+
if (!Root.getMF()->getTarget().getTargetTriple().isOSDarwin())
7391+
return false;
7392+
7393+
// The pattern searches for loads into single lanes.
7394+
switch (Root.getOpcode()) {
7395+
case AArch64::LD1i32:
7396+
return getGatherPattern(Root, Patterns, Root.getOpcode(), 4);
7397+
case AArch64::LD1i16:
7398+
return getGatherPattern(Root, Patterns, Root.getOpcode(), 8);
7399+
case AArch64::LD1i8:
7400+
return getGatherPattern(Root, Patterns, Root.getOpcode(), 16);
7401+
default:
7402+
return false;
7403+
}
7404+
}
7405+
7406+
/// Generate the replacement instructions for a gather pattern found by
/// getGatherPattern: instead of filling all lanes of one 128-bit Neon
/// register with a serial chain of LD1 lane loads, fill the low halves of
/// two registers independently and combine them with a ZIP1v2i64, shortening
/// the dependency chain through the loads.
///
/// \param Pattern   currently unused; \p NumLanes alone determines the
///                  rewrite (kept to match the other generate* helpers).
/// \param NumLanes  number of lanes in the vector: 4 (i32), 8 (i16) or
///                  16 (i8).
///
/// NOTE(review): this assumes the chained LD1 instructions write lanes in
/// ascending order; getGatherPattern only verifies that every lane is
/// written exactly once, not the order — confirm.
static void generateGatherPattern(
    MachineInstr &Root, SmallVectorImpl<MachineInstr *> &InsInstrs,
    SmallVectorImpl<MachineInstr *> &DelInstrs,
    DenseMap<Register, unsigned> &InstrIdxForVirtReg, unsigned Pattern,
    unsigned NumLanes) {

  MachineFunction &MF = *Root.getParent()->getParent();
  MachineRegisterInfo &MRI = MF.getRegInfo();
  const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();

  // Walk the use-def chain upwards from Root and collect the LD1 lane loads
  // that build up the vector; the walk ends at the SUBREG_TO_REG feeding
  // lane 0.
  SmallVector<MachineInstr *, 16> LoadToLaneInstrs;
  MachineInstr *CurrInstr = &Root;
  for (unsigned i = 0; i < NumLanes - 1; ++i) {
    LoadToLaneInstrs.push_back(CurrInstr);
    CurrInstr = MRI.getUniqueVRegDef(CurrInstr->getOperand(1).getReg());
  }

  // Look through the SUBREG_TO_REG to find the scalar load of lane 0.
  MachineInstr *SubregToReg = CurrInstr;
  LoadToLaneInstrs.push_back(
      MRI.getUniqueVRegDef(SubregToReg->getOperand(2).getReg()));
  // Iterate in program order: lane 0 load first, Root (last lane) last.
  auto OriginalLoadInstrs = llvm::reverse(LoadToLaneInstrs);

  const TargetRegisterClass *FPR128RegClass =
      MRI.getRegClass(Root.getOperand(0).getReg());

  // Build a lane load (same opcode as Root) into a fresh virtual register,
  // record it in InsInstrs/InstrIdxForVirtReg, and return the new register.
  auto LoadLaneToRegister = [&](MachineInstr *OriginalInstr,
                                Register SrcRegister, unsigned Lane,
                                Register OffsetRegister) {
    auto NewRegister = MRI.createVirtualRegister(FPR128RegClass);
    MachineInstrBuilder LoadIndexIntoRegister =
        BuildMI(MF, MIMetadata(*OriginalInstr), TII->get(Root.getOpcode()),
                NewRegister)
            .addReg(SrcRegister)
            .addImm(Lane)
            .addReg(OffsetRegister, getKillRegState(true));
    InstrIdxForVirtReg.insert(std::make_pair(NewRegister, InsInstrs.size()));
    InsInstrs.push_back(LoadIndexIntoRegister);
    return NewRegister;
  };

  // Helper to create a scalar load (immediate offset 0) of the element type
  // implied by the lane count.
  auto CreateLoadInstruction = [&](unsigned NumLanes, Register DestReg,
                                   Register OffsetReg) -> MachineInstrBuilder {
    unsigned Opcode;
    switch (NumLanes) {
    case 4:
      Opcode = AArch64::LDRSui;
      break;
    case 8:
      Opcode = AArch64::LDRHui;
      break;
    case 16:
      Opcode = AArch64::LDRBui;
      break;
    default:
      llvm_unreachable(
          "Got unsupported number of lanes in machine-combiner gather pattern");
    }
    // Immediate offset load.
    return BuildMI(MF, MIMetadata(Root), TII->get(Opcode), DestReg)
        .addReg(OffsetReg)
        .addImm(0); // immediate offset
  };

  // Fill lanes [1, NumLanes / 2) of register 0. Lane 0 comes from the
  // original SUBREG_TO_REG, which is kept.
  auto LanesToLoadToReg0 =
      llvm::make_range(OriginalLoadInstrs.begin() + 1,
                       OriginalLoadInstrs.begin() + NumLanes / 2);
  auto PrevReg = SubregToReg->getOperand(0).getReg();
  for (auto [Index, LoadInstr] : llvm::enumerate(LanesToLoadToReg0)) {
    PrevReg = LoadLaneToRegister(LoadInstr, PrevReg, Index + 1,
                                 LoadInstr->getOperand(3).getReg());
    DelInstrs.push_back(LoadInstr);
  }
  auto LastLoadReg0 = PrevReg;

  // Start register 1 by loading the middle lane (NumLanes / 2) with a scalar
  // load into lane 0 ...
  auto Lane0Load = *OriginalLoadInstrs.begin();
  auto OriginalSplitLoad =
      *std::next(OriginalLoadInstrs.begin(), NumLanes / 2);
  auto DestRegForMiddleIndex = MRI.createVirtualRegister(
      MRI.getRegClass(Lane0Load->getOperand(0).getReg()));

  MachineInstrBuilder MiddleIndexLoadInstr =
      CreateLoadInstruction(NumLanes, DestRegForMiddleIndex,
                            OriginalSplitLoad->getOperand(3).getReg());

  InstrIdxForVirtReg.insert(
      std::make_pair(DestRegForMiddleIndex, InsInstrs.size()));
  InsInstrs.push_back(MiddleIndexLoadInstr);
  DelInstrs.push_back(OriginalSplitLoad);

  // ... and widen the scalar result to FPR128 with a SUBREG_TO_REG.
  auto DestRegForSubregToReg = MRI.createVirtualRegister(FPR128RegClass);
  unsigned SubregType;
  switch (NumLanes) {
  case 4:
    SubregType = AArch64::ssub;
    break;
  case 8:
    SubregType = AArch64::hsub;
    break;
  case 16:
    SubregType = AArch64::bsub;
    break;
  default:
    llvm_unreachable(
        "Got invalid NumLanes for machine-combiner gather pattern");
  }
  auto SubRegToRegInstr =
      BuildMI(MF, MIMetadata(Root), TII->get(SubregToReg->getOpcode()),
              DestRegForSubregToReg)
          .addImm(0)
          .addReg(DestRegForMiddleIndex, getKillRegState(true))
          .addImm(SubregType);
  InstrIdxForVirtReg.insert(
      std::make_pair(DestRegForSubregToReg, InsInstrs.size()));
  InsInstrs.push_back(SubRegToRegInstr);

  // Fill lanes [1, NumLanes / 2) of register 1 with the remaining original
  // loads. Root itself is the last of them; the combiner erases Root
  // separately, so stop before pushing it onto DelInstrs.
  auto LanesToLoadToReg1 =
      llvm::make_range(OriginalLoadInstrs.begin() + NumLanes / 2 + 1,
                       OriginalLoadInstrs.end());
  PrevReg = SubRegToRegInstr->getOperand(0).getReg();
  for (auto [Index, LoadInstr] : llvm::enumerate(LanesToLoadToReg1)) {
    PrevReg = LoadLaneToRegister(LoadInstr, PrevReg, Index + 1,
                                 LoadInstr->getOperand(3).getReg());
    if (Index == NumLanes / 2 - 2) {
      break;
    }
    DelInstrs.push_back(LoadInstr);
  }
  auto LastLoadReg1 = PrevReg;

  // Combine the two half-filled registers into the original destination with
  // a ZIP1 over 64-bit elements.
  MachineInstrBuilder ZipInstr =
      BuildMI(MF, MIMetadata(Root), TII->get(AArch64::ZIP1v2i64),
              Root.getOperand(0).getReg())
          .addReg(LastLoadReg0)
          .addReg(LastLoadReg1);
  InsInstrs.push_back(ZipInstr);
}
7537+
73757538
CombinerObjective
73767539
AArch64InstrInfo::getCombinerObjective(unsigned Pattern) const {
73777540
switch (Pattern) {
73787541
case AArch64MachineCombinerPattern::SUBADD_OP1:
73797542
case AArch64MachineCombinerPattern::SUBADD_OP2:
7380-
case AArch64MachineCombinerPattern::SPLIT_LD:
7543+
case AArch64MachineCombinerPattern::GATHER_i32:
7544+
case AArch64MachineCombinerPattern::GATHER_i16:
7545+
case AArch64MachineCombinerPattern::GATHER_i8:
73817546
return CombinerObjective::MustReduceDepth;
73827547
default:
73837548
return TargetInstrInfo::getCombinerObjective(Pattern);
@@ -8741,82 +8906,20 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
87418906
MUL = genFNegatedMAD(MF, MRI, TII, Root, InsInstrs);
87428907
break;
87438908
}
8744-
case AArch64MachineCombinerPattern::SPLIT_LD: {
8745-
// Gather the initial load instructions to build the pattern
8746-
MachineInstr *Lane2Load = MRI.getUniqueVRegDef(Root.getOperand(1).getReg());
8747-
MachineInstr *Lane1Load =
8748-
MRI.getUniqueVRegDef(Lane2Load->getOperand(1).getReg());
8749-
MachineInstr *SubregToReg =
8750-
MRI.getUniqueVRegDef(Lane1Load->getOperand(1).getReg());
8751-
MachineInstr *Lane0Load =
8752-
MRI.getUniqueVRegDef(SubregToReg->getOperand(2).getReg());
8753-
8754-
const TargetRegisterClass *FPR128RegClass =
8755-
MRI.getRegClass(Root.getOperand(0).getReg());
8756-
8757-
auto LoadLaneToRegister = [&](MachineInstr *OriginalInstr,
8758-
Register SrcRegister, unsigned Lane,
8759-
Register OffsetRegister) {
8760-
auto NewRegister = MRI.createVirtualRegister(FPR128RegClass);
8761-
MachineInstrBuilder LoadIndexIntoRegister =
8762-
BuildMI(MF, MIMetadata(*OriginalInstr), TII->get(Root.getOpcode()),
8763-
NewRegister)
8764-
.addReg(SrcRegister)
8765-
.addImm(Lane)
8766-
.addReg(OffsetRegister, getKillRegState(true));
8767-
InstrIdxForVirtReg.insert(std::make_pair(NewRegister, InsInstrs.size()));
8768-
InsInstrs.push_back(LoadIndexIntoRegister);
8769-
return NewRegister;
8770-
};
8771-
8772-
// Helper to create load instruction based on opcode
8773-
auto CreateLoadInstruction = [&](unsigned Opcode, Register DestReg,
8774-
Register OffsetReg) -> MachineInstrBuilder {
8775-
return BuildMI(MF, MIMetadata(Root), TII->get(AArch64::LDRSui), DestReg)
8776-
.addReg(OffsetReg)
8777-
.addImm(0); // immediate offset
8778-
};
8779-
8780-
// Load index 1 into register 0 lane 1
8781-
Register Index1LoadReg =
8782-
LoadLaneToRegister(Lane1Load, SubregToReg->getOperand(0).getReg(), 1,
8783-
Lane1Load->getOperand(3).getReg());
8784-
DelInstrs.push_back(Lane1Load);
8785-
8786-
// Load index 2 into register 1 lane 0
8787-
auto DestRegForIndex2 = MRI.createVirtualRegister(
8788-
MRI.getRegClass(Lane0Load->getOperand(0).getReg()));
8789-
8790-
MachineInstrBuilder Index2LoadInstr = CreateLoadInstruction(
8791-
Lane0Load->getOpcode(), DestRegForIndex2,
8792-
Lane2Load->getOperand(3).getReg());
8793-
8794-
InstrIdxForVirtReg.insert(std::make_pair(DestRegForIndex2, InsInstrs.size()));
8795-
InsInstrs.push_back(Index2LoadInstr);
8796-
DelInstrs.push_back(Lane2Load);
8797-
8798-
// Convert fpr32 to fpr128 using subreg
8799-
auto DestRegForSubregToReg = MRI.createVirtualRegister(FPR128RegClass);
8800-
auto SubRegToRegInstr = BuildMI(MF, MIMetadata(Root),
8801-
TII->get(SubregToReg->getOpcode()),
8802-
DestRegForSubregToReg)
8803-
.addImm(0)
8804-
.addReg(DestRegForIndex2, getKillRegState(true))
8805-
.addImm(AArch64::ssub);
8806-
InstrIdxForVirtReg.insert(std::make_pair(DestRegForSubregToReg, InsInstrs.size()));
8807-
InsInstrs.push_back(SubRegToRegInstr);
8808-
8809-
// Load index 3 into register 1 lane 1
8810-
auto Index3LoadReg = LoadLaneToRegister(&Root, DestRegForSubregToReg, 1,
8811-
Root.getOperand(3).getReg());
8812-
8813-
// Create the final zip instruction to combine the results
8814-
MachineInstrBuilder ZipInstr =
8815-
BuildMI(MF, MIMetadata(Root), TII->get(AArch64::ZIP1v2i64),
8816-
Root.getOperand(0).getReg())
8817-
.addReg(Index1LoadReg)
8818-
.addReg(Index3LoadReg);
8819-
InsInstrs.push_back(ZipInstr);
8909+
case AArch64MachineCombinerPattern::GATHER_i32: {
8910+
generateGatherPattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg, Pattern, 4);
8911+
for (const auto Instr : DelInstrs) {
8912+
Instr->print(llvm::errs());
8913+
}
8914+
break;
8915+
}
8916+
case AArch64MachineCombinerPattern::GATHER_i16: {
8917+
generateGatherPattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg, Pattern, 8);
8918+
break;
8919+
}
8920+
case AArch64MachineCombinerPattern::GATHER_i8: {
8921+
generateGatherPattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg, Pattern, 16);
8922+
break;
88208923
}
88218924

88228925
} // end switch (Pattern)

llvm/lib/Target/AArch64/AArch64InstrInfo.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,9 @@ enum AArch64MachineCombinerPattern : unsigned {
173173

174174
FNMADD,
175175

176-
SPLIT_LD,
176+
GATHER_i32 = 890,
177+
GATHER_i16,
178+
GATHER_i8
177179
};
178180
class AArch64InstrInfo final : public AArch64GenInstrInfo {
179181
const AArch64RegisterInfo RI;

0 commit comments

Comments
 (0)