[AArch64][MachineCombiner] Combine sequences of gather patterns #152979
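This patch teaches the MachineCombiner to rewrite a serial chain of NEON lane loads into two half-length chains that can issue in parallel. A sketch of the i32 (4-lane) case in MIR-like notation; the virtual register names and pointer operands are illustrative, not taken from the patch:

Before (each LD1i32 depends on the previous load's result):

    %v0:fpr32  = LDRSui %p0, 0                      ; lane 0
    %v1:fpr128 = SUBREG_TO_REG 0, %v0, %subreg.ssub
    %v2:fpr128 = LD1i32 %v1, 1, %p1                 ; lane 1
    %v3:fpr128 = LD1i32 %v2, 2, %p2                 ; lane 2
    %v4:fpr128 = LD1i32 %v3, 3, %p3                 ; lane 3 (Root)

After (two independent two-lane chains, combined with a zip):

    %a0:fpr32  = LDRSui %p0, 0
    %a1:fpr128 = SUBREG_TO_REG 0, %a0, %subreg.ssub
    %a2:fpr128 = LD1i32 %a1, 1, %p1                 ; register 0: lanes 0-1
    %b0:fpr32  = LDRSui %p2, 0
    %b1:fpr128 = SUBREG_TO_REG 0, %b0, %subreg.ssub
    %b2:fpr128 = LD1i32 %b1, 1, %p3                 ; register 1: lanes 0-1
    %r:fpr128  = ZIP1v2i64 %a2, %b2                 ; final lanes 0-3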
@@ -20,7 +20,9 @@
#include "Utils/AArch64BaseInfo.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/CodeGen/CFIInstBuilder.h"
#include "llvm/CodeGen/LivePhysRegs.h"
#include "llvm/CodeGen/MachineBasicBlock.h"

@@ -83,6 +85,11 @@
static cl::opt<unsigned>
    BDisplacementBits("aarch64-b-offset-bits", cl::Hidden, cl::init(26),
                      cl::desc("Restrict range of B instructions (DEBUG)"));

static cl::opt<unsigned> GatherOptSearchLimit(
    "aarch64-search-limit", cl::Hidden, cl::init(2048),
    cl::desc("Restrict range of instructions to search for the "
             "machine-combiner gather pattern optimization"));

AArch64InstrInfo::AArch64InstrInfo(const AArch64Subtarget &STI)
    : AArch64GenInstrInfo(AArch64::ADJCALLSTACKDOWN, AArch64::ADJCALLSTACKUP,
                          AArch64::CATCHRET),
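A note on the new option: GatherOptSearchLimit is an ordinary cl::opt, so the search range can be tuned when experimenting. A usage sketch (assuming a default llc build; with clang the flag goes through -mllvm):

    llc -mtriple=aarch64 -aarch64-search-limit=512 gather.ll
    clang --target=aarch64-linux-gnu -O3 -mllvm -aarch64-search-limit=512 gather.c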

@@ -7412,11 +7419,335 @@ static bool getMiscPatterns(MachineInstr &Root,
  return false;
}

/// Check if a given MachineInstr `MIa` may alias with any of the instructions
/// in `MemInstrs`.
static bool mayAlias(const MachineInstr &MIa,
                     SmallVectorImpl<const MachineInstr *> &MemInstrs,
                     AliasAnalysis *AA) {
  for (const MachineInstr *MIb : MemInstrs) {
    if (MIa.mayAlias(AA, *MIb, /*UseTBAA=*/false))
      return true;
  }

  return false;
}

/// Check if the given instruction forms a gather load pattern that can be
/// optimized for better Memory-Level Parallelism (MLP). This function
/// identifies chains of NEON lane load instructions that load data from
/// different memory addresses into individual lanes of a 128-bit vector
/// register, then attempts to split the pattern into parallel loads to break
/// the serial dependency between instructions.
///
/// Pattern Matched:
///   Initial scalar load -> SUBREG_TO_REG (lane 0) -> LD1i* (lane 1) ->
///   LD1i* (lane 2) -> ... -> LD1i* (lane N-1, Root)
///
/// Transformed Into:
///   Two parallel vector loads using fewer lanes each, followed by ZIP1v2i64
///   to combine the results, enabling better memory-level parallelism.
///
/// Supported Element Types:
///   - 32-bit elements (LD1i32, 4 lanes total)
///   - 16-bit elements (LD1i16, 8 lanes total)
///   - 8-bit elements (LD1i8, 16 lanes total)
static bool getGatherLanePattern(MachineInstr &Root,
                                 SmallVectorImpl<unsigned> &Patterns,
                                 unsigned LoadLaneOpCode, unsigned NumLanes) {
  const MachineFunction *MF = Root.getMF();

  // Early exit if optimizing for size.
  if (MF->getFunction().hasMinSize())
    return false;

  const MachineRegisterInfo &MRI = MF->getRegInfo();
  const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();

  // The root of the pattern must load into the last lane of the vector.
  if (Root.getOperand(2).getImm() != NumLanes - 1)
    return false;

  // Check that we have a load into all lanes except lane 0.
  // For each load we also want to check that:
  //   1. It has a single non-debug use (since we will be replacing the
  //      virtual register), and
  //   2. its addressing mode uses only a single pointer operand.
  auto *CurrInstr = MRI.getUniqueVRegDef(Root.getOperand(1).getReg());
  auto Range = llvm::seq<unsigned>(1, NumLanes - 1);
  SmallSet<unsigned, 16> RemainingLanes(Range.begin(), Range.end());
  SmallVector<const MachineInstr *, 16> LoadInstrs;

  while (!RemainingLanes.empty() && CurrInstr &&
         CurrInstr->getOpcode() == LoadLaneOpCode &&
         MRI.hasOneNonDBGUse(CurrInstr->getOperand(0).getReg()) &&
         CurrInstr->getNumOperands() == 4) {
    RemainingLanes.erase(CurrInstr->getOperand(2).getImm());
    LoadInstrs.push_back(CurrInstr);
    CurrInstr = MRI.getUniqueVRegDef(CurrInstr->getOperand(1).getReg());
  }

  // Check that we have found a match for lanes N-1 .. 1.
  if (!RemainingLanes.empty())
    return false;

  // Match the SUBREG_TO_REG sequence.
  if (!CurrInstr || CurrInstr->getOpcode() != TargetOpcode::SUBREG_TO_REG)
    return false;

  // Verify that the SUBREG_TO_REG loads an integer into the first lane.
  auto Lane0LoadReg = CurrInstr->getOperand(2).getReg();
  unsigned SingleLaneSizeInBits = 128 / NumLanes;
  if (TRI->getRegSizeInBits(Lane0LoadReg, MRI) != SingleLaneSizeInBits)
    return false;

  // Verify that it also has a single non-debug use.
  if (!MRI.hasOneNonDBGUse(Lane0LoadReg))
    return false;

  LoadInstrs.push_back(MRI.getUniqueVRegDef(Lane0LoadReg));

  // If there is any chance of aliasing, do not apply the pattern.
  // Walk backward through the MBB starting from Root. Exit early if we have
  // encountered all the load instructions or hit the search limit.
  auto MBBItr = Root.getIterator();
  unsigned RemainingSteps = GatherOptSearchLimit;
  SmallSet<const MachineInstr *, 16> RemainingLoadInstrs;
  RemainingLoadInstrs.insert(LoadInstrs.begin(), LoadInstrs.end());
  const MachineBasicBlock *MBB = Root.getParent();

  for (; MBBItr != MBB->begin() && RemainingSteps > 0 &&
         !RemainingLoadInstrs.empty();
       --MBBItr, --RemainingSteps) {
    const MachineInstr &CurrInstr = *MBBItr;

    // Remove this instruction from remaining loads if it's one we're tracking.
    RemainingLoadInstrs.erase(&CurrInstr);

    // Check for potential aliasing with any of the load instructions to
    // optimize.
    if ((CurrInstr.mayLoadOrStore() || CurrInstr.isCall()) &&
        mayAlias(CurrInstr, LoadInstrs, nullptr))
      return false;
  }

  // If we hit the search limit without finding all load instructions,
  // don't match the pattern.
  if (RemainingSteps == 0 && !RemainingLoadInstrs.empty())
    return false;

  switch (NumLanes) {
  case 4:
    Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i32);
    break;
  case 8:
    Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i16);
    break;
  case 16:
    Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i8);
    break;
  default:
    llvm_unreachable("Got bad number of lanes for gather pattern.");
  }

  return true;
}
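A note on the alias walk above: an intervening store (or call) that may clobber one of the tracked load addresses blocks the match, because the rewrite effectively moves the second half of the chain past it. A hypothetical MIR fragment that would fail the check:

    %v2:fpr128 = LD1i32 %v1, 1, %p1
    STRWui %w0, %q0, 0                 ; mayLoadOrStore(), may alias %p2
    %v3:fpr128 = LD1i32 %v2, 2, %p2    ; mayAlias() returns true -> no match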

/// Search for patterns of LD instructions we can optimize.
static bool getLoadPatterns(MachineInstr &Root,
                            SmallVectorImpl<unsigned> &Patterns) {
  // The pattern searches for loads into single lanes.
  switch (Root.getOpcode()) {
  case AArch64::LD1i32:
    return getGatherLanePattern(Root, Patterns, Root.getOpcode(), 4);
  case AArch64::LD1i16:
    return getGatherLanePattern(Root, Patterns, Root.getOpcode(), 8);
  case AArch64::LD1i8:
    return getGatherLanePattern(Root, Patterns, Root.getOpcode(), 16);
  default:
    return false;
  }
}

/// Generate optimized instruction sequence for gather load patterns to improve
/// Memory-Level Parallelism (MLP). This function transforms a chain of
/// sequential NEON lane loads into parallel vector loads that can execute
/// concurrently.
static void
generateGatherLanePattern(MachineInstr &Root,
                          SmallVectorImpl<MachineInstr *> &InsInstrs,
                          SmallVectorImpl<MachineInstr *> &DelInstrs,
                          DenseMap<Register, unsigned> &InstrIdxForVirtReg,
                          unsigned Pattern, unsigned NumLanes) {
  MachineFunction &MF = *Root.getParent()->getParent();
  MachineRegisterInfo &MRI = MF.getRegInfo();
  const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();

  // Gather the initial load instructions to build the pattern.
  SmallVector<MachineInstr *, 16> LoadToLaneInstrs;
  MachineInstr *CurrInstr = &Root;
  for (unsigned i = 0; i < NumLanes - 1; ++i) {
    LoadToLaneInstrs.push_back(CurrInstr);
    CurrInstr = MRI.getUniqueVRegDef(CurrInstr->getOperand(1).getReg());
  }

  // Sort the load instructions according to the lane.
  llvm::sort(LoadToLaneInstrs,
             [](const MachineInstr *A, const MachineInstr *B) {
               return A->getOperand(2).getImm() > B->getOperand(2).getImm();
             });

  MachineInstr *SubregToReg = CurrInstr;
  LoadToLaneInstrs.push_back(
      MRI.getUniqueVRegDef(SubregToReg->getOperand(2).getReg()));
  auto LoadToLaneInstrsAscending = llvm::reverse(LoadToLaneInstrs);

  const TargetRegisterClass *FPR128RegClass =
      MRI.getRegClass(Root.getOperand(0).getReg());

  // Helper lambda to create a LD1 instruction.
  auto CreateLD1Instruction = [&](MachineInstr *OriginalInstr,
                                  Register SrcRegister, unsigned Lane,
                                  Register OffsetRegister,
                                  bool OffsetRegisterKillState) {
    auto NewRegister = MRI.createVirtualRegister(FPR128RegClass);
    MachineInstrBuilder LoadIndexIntoRegister =
        BuildMI(MF, MIMetadata(*OriginalInstr), TII->get(Root.getOpcode()),
                NewRegister)
            .addReg(SrcRegister)
            .addImm(Lane)
            .addReg(OffsetRegister, getKillRegState(OffsetRegisterKillState));
    InstrIdxForVirtReg.insert(std::make_pair(NewRegister, InsInstrs.size()));
    InsInstrs.push_back(LoadIndexIntoRegister);
    return NewRegister;
  };
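
  // For the i32 pattern, a single call to CreateLD1Instruction emits e.g.
  //   %new:fpr128 = LD1i32 %SrcRegister, Lane, %OffsetRegister
  // (illustrative MIR; the destination virtual register is freshly created).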

  // Helper to create a load instruction based on the NumLanes in the NEON
  // register we are rewriting.
  auto CreateLDRInstruction = [&](unsigned NumLanes, Register DestReg,
                                  Register OffsetReg,
                                  bool KillState) -> MachineInstrBuilder {
    unsigned Opcode;
    switch (NumLanes) {
    case 4:
      Opcode = AArch64::LDRSui;
      break;
    case 8:
      Opcode = AArch64::LDRHui;
      break;
    case 16:
      Opcode = AArch64::LDRBui;
      break;
    default:
      llvm_unreachable(
          "Got unsupported number of lanes in machine-combiner gather pattern");
    }
    // Immediate-offset load.
    return BuildMI(MF, MIMetadata(Root), TII->get(Opcode), DestReg)
        .addReg(OffsetReg)
        .addImm(0);
  };
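
  // E.g. for NumLanes == 8, CreateLDRInstruction builds
  //   %DestReg:fpr16 = LDRHui %OffsetReg, 0
  // (an illustrative sketch). The scalar FP load zero-fills the rest of the
  // vector register, which the SUBREG_TO_REG emitted below makes explicit.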

  // Load the remaining lanes into register 0.
  auto LanesToLoadToReg0 =
      llvm::make_range(LoadToLaneInstrsAscending.begin() + 1,
                       LoadToLaneInstrsAscending.begin() + NumLanes / 2);
  Register PrevReg = SubregToReg->getOperand(0).getReg();
  for (auto [Index, LoadInstr] : llvm::enumerate(LanesToLoadToReg0)) {
    const MachineOperand &OffsetRegOperand = LoadInstr->getOperand(3);
    PrevReg = CreateLD1Instruction(LoadInstr, PrevReg, Index + 1,
                                   OffsetRegOperand.getReg(),
                                   OffsetRegOperand.isKill());
    DelInstrs.push_back(LoadInstr);
  }
  Register LastLoadReg0 = PrevReg;

  // First load into register 1. Perform an integer load to zero out the upper
  // lanes in a single instruction.
  MachineInstr *Lane0Load = *LoadToLaneInstrsAscending.begin();
  MachineInstr *OriginalSplitLoad =
      *std::next(LoadToLaneInstrsAscending.begin(), NumLanes / 2);
  Register DestRegForMiddleIndex = MRI.createVirtualRegister(
      MRI.getRegClass(Lane0Load->getOperand(0).getReg()));

  const MachineOperand &OriginalSplitToLoadOffsetOperand =
      OriginalSplitLoad->getOperand(3);
  MachineInstrBuilder MiddleIndexLoadInstr =
      CreateLDRInstruction(NumLanes, DestRegForMiddleIndex,
                           OriginalSplitToLoadOffsetOperand.getReg(),
                           OriginalSplitToLoadOffsetOperand.isKill());

  InstrIdxForVirtReg.insert(
      std::make_pair(DestRegForMiddleIndex, InsInstrs.size()));
  InsInstrs.push_back(MiddleIndexLoadInstr);
  DelInstrs.push_back(OriginalSplitLoad);

  // SUBREG_TO_REG instruction for register 1.
  Register DestRegForSubregToReg = MRI.createVirtualRegister(FPR128RegClass);
  unsigned SubregType;
  switch (NumLanes) {
  case 4:
    SubregType = AArch64::ssub;
    break;
  case 8:
    SubregType = AArch64::hsub;
    break;
  case 16:
    SubregType = AArch64::bsub;
    break;
  default:
    llvm_unreachable(
        "Got invalid NumLanes for machine-combiner gather pattern");
  }

  auto SubRegToRegInstr =
      BuildMI(MF, MIMetadata(Root), TII->get(SubregToReg->getOpcode()),
              DestRegForSubregToReg)
          .addImm(0)
          .addReg(DestRegForMiddleIndex, getKillRegState(true))
          .addImm(SubregType);
  InstrIdxForVirtReg.insert(
      std::make_pair(DestRegForSubregToReg, InsInstrs.size()));
  InsInstrs.push_back(SubRegToRegInstr);

  // Load the remaining lanes into register 1.
  auto LanesToLoadToReg1 =
      llvm::make_range(LoadToLaneInstrsAscending.begin() + NumLanes / 2 + 1,
                       LoadToLaneInstrsAscending.end());
  PrevReg = SubRegToRegInstr->getOperand(0).getReg();
  for (auto [Index, LoadInstr] : llvm::enumerate(LanesToLoadToReg1)) {
    const MachineOperand &OffsetRegOperand = LoadInstr->getOperand(3);
    PrevReg = CreateLD1Instruction(LoadInstr, PrevReg, Index + 1,
                                   OffsetRegOperand.getReg(),
                                   OffsetRegOperand.isKill());

    // Do not add the last instruction to DelInstrs - it will be removed later.
    if (Index == NumLanes / 2 - 2)
      break;
    DelInstrs.push_back(LoadInstr);
  }
  Register LastLoadReg1 = PrevReg;

  // Create the final zip instruction to combine the results.
  MachineInstrBuilder ZipInstr =
      BuildMI(MF, MIMetadata(Root), TII->get(AArch64::ZIP1v2i64),
              Root.getOperand(0).getReg())
          .addReg(LastLoadReg0)
          .addReg(LastLoadReg1);
  InsInstrs.push_back(ZipInstr);
}
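For NumLanes == 8 (the LD1i16 case), the lane distribution above works out as follows (a summary of the code's intent, not verified compiler output):

    register 0: lane 0 from the original scalar load; lanes 1-3 via new LD1i16
                from the addresses of original lanes 1-3
    register 1: lane 0 via LDRHui + SUBREG_TO_REG from the address of original
                lane 4; lanes 1-3 via new LD1i16 from original lanes 5-7
    ZIP1v2i64:  takes the low 64 bits of each source register, so the result
                gets lanes 0-3 from register 0 and lanes 4-7 from register 1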

CombinerObjective
AArch64InstrInfo::getCombinerObjective(unsigned Pattern) const {
  switch (Pattern) {
  case AArch64MachineCombinerPattern::SUBADD_OP1:
  case AArch64MachineCombinerPattern::SUBADD_OP2:
  case AArch64MachineCombinerPattern::GATHER_LANE_i32:
  case AArch64MachineCombinerPattern::GATHER_LANE_i16:
  case AArch64MachineCombinerPattern::GATHER_LANE_i8:
    return CombinerObjective::MustReduceDepth;
  default:
    return TargetInstrInfo::getCombinerObjective(Pattern);

@@ -7446,6 +7777,10 @@ bool AArch64InstrInfo::getMachineCombinerPatterns(
  if (getMiscPatterns(Root, Patterns))
    return true;

  // Load patterns
  if (getLoadPatterns(Root, Patterns))
    return true;

  return TargetInstrInfo::getMachineCombinerPatterns(Root, Patterns,
                                                     DoRegPressureReduce);
}

@@ -8701,6 +9036,21 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
    MUL = genFNegatedMAD(MF, MRI, TII, Root, InsInstrs);
    break;
  }
  case AArch64MachineCombinerPattern::GATHER_LANE_i32: {
    generateGatherLanePattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
                              Pattern, 4);
    break;
  }
  case AArch64MachineCombinerPattern::GATHER_LANE_i16: {
    generateGatherLanePattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
                              Pattern, 8);
    break;
  }
  case AArch64MachineCombinerPattern::GATHER_LANE_i8: {
    generateGatherLanePattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
                              Pattern, 16);
    break;
  }

  } // end switch (Pattern)
  // Record MUL and ADD/SUB for deletion