|
20 | 20 | #include "Utils/AArch64BaseInfo.h" |
21 | 21 | #include "llvm/ADT/ArrayRef.h" |
22 | 22 | #include "llvm/ADT/STLExtras.h" |
| 23 | +#include "llvm/ADT/SmallSet.h" |
23 | 24 | #include "llvm/ADT/SmallVector.h" |
| 25 | +#include "llvm/Analysis/AliasAnalysis.h" |
24 | 26 | #include "llvm/CodeGen/CFIInstBuilder.h" |
25 | 27 | #include "llvm/CodeGen/LivePhysRegs.h" |
26 | 28 | #include "llvm/CodeGen/MachineBasicBlock.h" |
@@ -83,6 +85,11 @@ static cl::opt<unsigned> |
83 | 85 | BDisplacementBits("aarch64-b-offset-bits", cl::Hidden, cl::init(26), |
84 | 86 | cl::desc("Restrict range of B instructions (DEBUG)")); |
85 | 87 |
|
| 88 | +static cl::opt<unsigned> GatherOptSearchLimit( |
| 89 | + "aarch64-search-limit", cl::Hidden, cl::init(2048), |
| 90 | + cl::desc("Restrict range of instructions to search for the " |
| 91 | + "machine-combiner gather pattern optimization")); |
| 92 | + |
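The `GatherOptSearchLimit` option bounds how far backwards the alias scan in `getGatherLanePattern` may walk. Since it is an ordinary backend `cl::opt`, it can be tuned when experimenting, for example with `-mllvm -aarch64-search-limit=<N>` through clang or `-aarch64-search-limit=<N>` directly with `llc`; a value of 0 should effectively disable the gather-pattern rewrite, because the backward scan can then never confirm all of the tracked loads.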
86 | 93 | AArch64InstrInfo::AArch64InstrInfo(const AArch64Subtarget &STI) |
87 | 94 | : AArch64GenInstrInfo(AArch64::ADJCALLSTACKDOWN, AArch64::ADJCALLSTACKUP, |
88 | 95 | AArch64::CATCHRET), |
@@ -7485,11 +7492,335 @@ static bool getMiscPatterns(MachineInstr &Root, |
7485 | 7492 | return false; |
7486 | 7493 | } |
7487 | 7494 |
|
| 7495 | +/// Check if a given MachineInstr `MIa` may alias with any of the instructions |
| 7496 | +/// in `MemInstrs`. |
| 7497 | +static bool mayAlias(const MachineInstr &MIa, |
| 7498 | + SmallVectorImpl<const MachineInstr *> &MemInstrs, |
| 7499 | + AliasAnalysis *AA) { |
| 7500 | + for (const MachineInstr *MIb : MemInstrs) { |
| 7501 | + if (MIa.mayAlias(AA, *MIb, /*UseTBAA=*/false)) { |
| 7502 | + // A potential dependence was found; conservatively report aliasing. |
| 7503 | + return true; |
| 7504 | + } |
| 7505 | + } |
| 7506 | + |
| 7507 | + return false; |
| 7508 | +} |
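To make the intent of this conservative check concrete, here is a hypothetical source-level sketch (NEON intrinsics; the function and pointer names are invented for illustration) of a gather chain that the scan must reject: an intervening store may alias the address fed to a later lane load, so reordering the loads would not be safe.

```cpp
#include <arm_neon.h>

// Illustrative only: the store to *p2 sits between two of the lane loads, so
// any load moved past it must be proven not to alias. The mayAlias() helper
// above reports true for such a case and the pattern is not applied.
float32x4_t gather_with_intervening_store(const float *p0, const float *p1,
                                          float *p2, const float *p3) {
  float32x4_t v = vsetq_lane_f32(*p0, vdupq_n_f32(0.0f), 0);
  v = vld1q_lane_f32(p1, v, 1);
  *p2 = 42.0f;                  // store that aliases the next lane load
  v = vld1q_lane_f32(p2, v, 2);
  v = vld1q_lane_f32(p3, v, 3);
  return v;
}
```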
| 7509 | + |
| 7510 | +/// Check if the given instruction forms a gather load pattern that can be |
| 7511 | +/// optimized for better Memory-Level Parallelism (MLP). This function |
| 7512 | +/// identifies chains of NEON lane load instructions that load data from |
| 7513 | +/// different memory addresses into individual lanes of a 128-bit vector |
| 7514 | +/// register, then attempts to split the pattern into parallel loads to break |
| 7515 | +/// the serial dependency between instructions. |
| 7516 | +/// |
| 7517 | +/// Pattern Matched: |
| 7518 | +/// Initial scalar load -> SUBREG_TO_REG (lane 0) -> LD1i* (lane 1) -> |
| 7519 | +/// LD1i* (lane 2) -> ... -> LD1i* (lane N-1, Root) |
| 7520 | +/// |
| 7521 | +/// Transformed Into: |
| 7522 | +/// Two parallel vector loads using fewer lanes each, followed by ZIP1v2i64 |
| 7523 | +/// to combine the results, enabling better memory-level parallelism. |
| 7524 | +/// |
| 7525 | +/// Supported Element Types: |
| 7526 | +/// - 32-bit elements (LD1i32, 4 lanes total) |
| 7527 | +/// - 16-bit elements (LD1i16, 8 lanes total) |
| 7528 | +/// - 8-bit elements (LD1i8, 16 lanes total) |
| 7529 | +static bool getGatherLanePattern(MachineInstr &Root, |
| 7530 | + SmallVectorImpl<unsigned> &Patterns, |
| 7531 | + unsigned LoadLaneOpCode, unsigned NumLanes) { |
| 7532 | + const MachineFunction *MF = Root.getMF(); |
| 7533 | + |
| 7534 | + // Early exit if optimizing for size. |
| 7535 | + if (MF->getFunction().hasMinSize()) |
| 7536 | + return false; |
| 7537 | + |
| 7538 | + const MachineRegisterInfo &MRI = MF->getRegInfo(); |
| 7539 | + const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo(); |
| 7540 | + |
| 7541 | + // The root of the pattern must load into the last lane of the vector. |
| 7542 | + if (Root.getOperand(2).getImm() != NumLanes - 1) |
| 7543 | + return false; |
| 7544 | + |
| 7545 | + // Check that we have a load into all lanes except lane 0. |
| 7546 | + // For each load we also want to check that: |
| 7547 | + // 1. It has a single non-debug use (since we will be replacing the virtual |
| 7548 | + // register), and |
| 7549 | + // 2. Its addressing mode only uses a single pointer operand. |
| 7550 | + auto *CurrInstr = MRI.getUniqueVRegDef(Root.getOperand(1).getReg()); |
| 7551 | + auto Range = llvm::seq<unsigned>(1, NumLanes - 1); |
| 7552 | + SmallSet<unsigned, 16> RemainingLanes(Range.begin(), Range.end()); |
| 7553 | + SmallVector<const MachineInstr *, 16> LoadInstrs = {}; |
| 7554 | + while (!RemainingLanes.empty() && CurrInstr && |
| 7555 | + CurrInstr->getOpcode() == LoadLaneOpCode && |
| 7556 | + MRI.hasOneNonDBGUse(CurrInstr->getOperand(0).getReg()) && |
| 7557 | + CurrInstr->getNumOperands() == 4) { |
| 7558 | + RemainingLanes.erase(CurrInstr->getOperand(2).getImm()); |
| 7559 | + LoadInstrs.push_back(CurrInstr); |
| 7560 | + CurrInstr = MRI.getUniqueVRegDef(CurrInstr->getOperand(1).getReg()); |
| 7561 | + } |
| 7562 | + |
| 7563 | + // Check that we have found a match for lanes N-1 down to 1. |
| 7564 | + if (!RemainingLanes.empty()) |
| 7565 | + return false; |
| 7566 | + |
| 7567 | + // Match the SUBREG_TO_REG that inserts the initial scalar load. |
| 7568 | + if (CurrInstr->getOpcode() != TargetOpcode::SUBREG_TO_REG) |
| 7569 | + return false; |
| 7570 | + |
| 7571 | + // Verify that the subreg to reg loads an integer into the first lane. |
| 7572 | + auto Lane0LoadReg = CurrInstr->getOperand(2).getReg(); |
| 7573 | + unsigned SingleLaneSizeInBits = 128 / NumLanes; |
| 7574 | + if (TRI->getRegSizeInBits(Lane0LoadReg, MRI) != SingleLaneSizeInBits) |
| 7575 | + return false; |
| 7576 | + |
| 7577 | + // Verify that it also has a single non-debug use. |
| 7578 | + if (!MRI.hasOneNonDBGUse(Lane0LoadReg)) |
| 7579 | + return false; |
| 7580 | + |
| 7581 | + LoadInstrs.push_back(MRI.getUniqueVRegDef(Lane0LoadReg)); |
| 7582 | + |
| 7583 | + // If there is any chance of aliasing, do not apply the pattern. |
| 7584 | + // Walk backward through the MBB starting from Root. |
| 7585 | + // Exit early if we've encountered all load instructions or hit the search |
| 7586 | + // limit. |
| 7587 | + auto MBBItr = Root.getIterator(); |
| 7588 | + unsigned RemainingSteps = GatherOptSearchLimit; |
| 7589 | + SmallSet<const MachineInstr *, 16> RemainingLoadInstrs; |
| 7590 | + RemainingLoadInstrs.insert(LoadInstrs.begin(), LoadInstrs.end()); |
| 7591 | + const MachineBasicBlock *MBB = Root.getParent(); |
| 7592 | + |
| 7593 | + for (; MBBItr != MBB->begin() && RemainingSteps > 0 && |
| 7594 | + !RemainingLoadInstrs.empty(); |
| 7595 | + --MBBItr, --RemainingSteps) { |
| 7596 | + const MachineInstr &CurrInstr = *MBBItr; |
| 7597 | + |
| 7598 | + // Remove this instruction from remaining loads if it's one we're tracking. |
| 7599 | + RemainingLoadInstrs.erase(&CurrInstr); |
| 7600 | + |
| 7601 | + // Check for potential aliasing with any of the load instructions to |
| 7602 | + // optimize. |
| 7603 | + if ((CurrInstr.mayLoadOrStore() || CurrInstr.isCall()) && |
| 7604 | + mayAlias(CurrInstr, LoadInstrs, nullptr)) |
| 7605 | + return false; |
| 7606 | + } |
| 7607 | + |
| 7608 | + // If we hit the search limit without finding all load instructions, |
| 7609 | + // don't match the pattern. |
| 7610 | + if (RemainingSteps == 0 && !RemainingLoadInstrs.empty()) |
| 7611 | + return false; |
| 7612 | + |
| 7613 | + switch (NumLanes) { |
| 7614 | + case 4: |
| 7615 | + Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i32); |
| 7616 | + break; |
| 7617 | + case 8: |
| 7618 | + Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i16); |
| 7619 | + break; |
| 7620 | + case 16: |
| 7621 | + Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i8); |
| 7622 | + break; |
| 7623 | + default: |
| 7624 | + llvm_unreachable("Got bad number of lanes for gather pattern."); |
| 7625 | + } |
| 7626 | + |
| 7627 | + return true; |
| 7628 | +} |
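As a concrete illustration of the matched input, the following hypothetical NEON-intrinsics function (names invented for this sketch) typically lowers to the shape described above for the 32-bit case: a scalar load widened with `SUBREG_TO_REG`, followed by a serial chain of `LD1i32` lane loads whose last element becomes the combiner's Root. The exact MachineInstrs depend on instruction selection, so treat this as a sketch rather than a guaranteed reproduction.

```cpp
#include <arm_neon.h>

// Sketch of source that can produce the gather-lane chain for NumLanes == 4.
// Each pointer is independent, but the lane inserts form one serial
// dependency chain through the vector register.
float32x4_t gather4_serial(const float *p0, const float *p1, const float *p2,
                           const float *p3) {
  // Lane 0: a scalar load into a zeroed vector usually becomes "ldr s0, [p0]".
  float32x4_t v = vsetq_lane_f32(*p0, vdupq_n_f32(0.0f), 0);
  v = vld1q_lane_f32(p1, v, 1); // LD1i32, lane 1
  v = vld1q_lane_f32(p2, v, 2); // LD1i32, lane 2
  v = vld1q_lane_f32(p3, v, 3); // LD1i32, lane 3 -> Root of the pattern
  return v;
}
```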
| 7629 | + |
| 7630 | +/// Search for patterns of LD instructions we can optimize. |
| 7631 | +static bool getLoadPatterns(MachineInstr &Root, |
| 7632 | + SmallVectorImpl<unsigned> &Patterns) { |
| 7633 | + |
| 7634 | + // The pattern searches for loads into single lanes. |
| 7635 | + switch (Root.getOpcode()) { |
| 7636 | + case AArch64::LD1i32: |
| 7637 | + return getGatherLanePattern(Root, Patterns, Root.getOpcode(), 4); |
| 7638 | + case AArch64::LD1i16: |
| 7639 | + return getGatherLanePattern(Root, Patterns, Root.getOpcode(), 8); |
| 7640 | + case AArch64::LD1i8: |
| 7641 | + return getGatherLanePattern(Root, Patterns, Root.getOpcode(), 16); |
| 7642 | + default: |
| 7643 | + return false; |
| 7644 | + } |
| 7645 | +} |
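The lane counts passed to `getGatherLanePattern` are simply the 128-bit Q-register width divided by the element size of each LD1 variant; the compile-time check below restates that arithmetic and is illustrative only, not part of the patch.

```cpp
// 128-bit vector width divided by element width gives the lane count used
// for each opcode above.
static_assert(128 / 32 == 4, "LD1i32 operates on a 4-lane vector");
static_assert(128 / 16 == 8, "LD1i16 operates on an 8-lane vector");
static_assert(128 / 8 == 16, "LD1i8 operates on a 16-lane vector");
```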
| 7646 | + |
| 7647 | +/// Generate optimized instruction sequence for gather load patterns to improve |
| 7648 | +/// Memory-Level Parallelism (MLP). This function transforms a chain of |
| 7649 | +/// sequential NEON lane loads into parallel vector loads that can execute |
| 7650 | +/// concurrently. |
| 7651 | +static void |
| 7652 | +generateGatherLanePattern(MachineInstr &Root, |
| 7653 | + SmallVectorImpl<MachineInstr *> &InsInstrs, |
| 7654 | + SmallVectorImpl<MachineInstr *> &DelInstrs, |
| 7655 | + DenseMap<Register, unsigned> &InstrIdxForVirtReg, |
| 7656 | + unsigned Pattern, unsigned NumLanes) { |
| 7657 | + MachineFunction &MF = *Root.getParent()->getParent(); |
| 7658 | + MachineRegisterInfo &MRI = MF.getRegInfo(); |
| 7659 | + const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo(); |
| 7660 | + |
| 7661 | + // Gather the initial load instructions to build the pattern. |
| 7662 | + SmallVector<MachineInstr *, 16> LoadToLaneInstrs; |
| 7663 | + MachineInstr *CurrInstr = &Root; |
| 7664 | + for (unsigned i = 0; i < NumLanes - 1; ++i) { |
| 7665 | + LoadToLaneInstrs.push_back(CurrInstr); |
| 7666 | + CurrInstr = MRI.getUniqueVRegDef(CurrInstr->getOperand(1).getReg()); |
| 7667 | + } |
| 7668 | + |
| 7669 | + // Sort the load instructions according to the lane. |
| 7670 | + llvm::sort(LoadToLaneInstrs, |
| 7671 | + [](const MachineInstr *A, const MachineInstr *B) { |
| 7672 | + return A->getOperand(2).getImm() > B->getOperand(2).getImm(); |
| 7673 | + }); |
| 7674 | + |
| 7675 | + MachineInstr *SubregToReg = CurrInstr; |
| 7676 | + LoadToLaneInstrs.push_back( |
| 7677 | + MRI.getUniqueVRegDef(SubregToReg->getOperand(2).getReg())); |
| 7678 | + auto LoadToLaneInstrsAscending = llvm::reverse(LoadToLaneInstrs); |
| 7679 | + |
| 7680 | + const TargetRegisterClass *FPR128RegClass = |
| 7681 | + MRI.getRegClass(Root.getOperand(0).getReg()); |
| 7682 | + |
| 7683 | + // Helper lambda to create a LD1 instruction. |
| 7684 | + auto CreateLD1Instruction = [&](MachineInstr *OriginalInstr, |
| 7685 | + Register SrcRegister, unsigned Lane, |
| 7686 | + Register OffsetRegister, |
| 7687 | + bool OffsetRegisterKillState) { |
| 7688 | + auto NewRegister = MRI.createVirtualRegister(FPR128RegClass); |
| 7689 | + MachineInstrBuilder LoadIndexIntoRegister = |
| 7690 | + BuildMI(MF, MIMetadata(*OriginalInstr), TII->get(Root.getOpcode()), |
| 7691 | + NewRegister) |
| 7692 | + .addReg(SrcRegister) |
| 7693 | + .addImm(Lane) |
| 7694 | + .addReg(OffsetRegister, getKillRegState(OffsetRegisterKillState)); |
| 7695 | + InstrIdxForVirtReg.insert(std::make_pair(NewRegister, InsInstrs.size())); |
| 7696 | + InsInstrs.push_back(LoadIndexIntoRegister); |
| 7697 | + return NewRegister; |
| 7698 | + }; |
| 7699 | + |
| 7700 | + // Helper to create a scalar load instruction based on NumLanes, the |
| 7701 | + // number of lanes in the NEON register we are rewriting. |
| 7702 | + auto CreateLDRInstruction = [&](unsigned NumLanes, Register DestReg, |
| 7703 | + Register OffsetReg, |
| 7704 | + bool KillState) -> MachineInstrBuilder { |
| 7705 | + unsigned Opcode; |
| 7706 | + switch (NumLanes) { |
| 7707 | + case 4: |
| 7708 | + Opcode = AArch64::LDRSui; |
| 7709 | + break; |
| 7710 | + case 8: |
| 7711 | + Opcode = AArch64::LDRHui; |
| 7712 | + break; |
| 7713 | + case 16: |
| 7714 | + Opcode = AArch64::LDRBui; |
| 7715 | + break; |
| 7716 | + default: |
| 7717 | + llvm_unreachable( |
| 7718 | + "Got unsupported number of lanes in machine-combiner gather pattern"); |
| 7719 | + } |
| 7720 | + // Immediate offset load |
| 7721 | + return BuildMI(MF, MIMetadata(Root), TII->get(Opcode), DestReg) |
| 7722 | + .addReg(OffsetReg) |
| 7723 | + .addImm(0); |
| 7724 | + }; |
| 7725 | + |
| 7726 | + // Load the remaining lanes into register 0. |
| 7727 | + auto LanesToLoadToReg0 = |
| 7728 | + llvm::make_range(LoadToLaneInstrsAscending.begin() + 1, |
| 7729 | + LoadToLaneInstrsAscending.begin() + NumLanes / 2); |
| 7730 | + Register PrevReg = SubregToReg->getOperand(0).getReg(); |
| 7731 | + for (auto [Index, LoadInstr] : llvm::enumerate(LanesToLoadToReg0)) { |
| 7732 | + const MachineOperand &OffsetRegOperand = LoadInstr->getOperand(3); |
| 7733 | + PrevReg = CreateLD1Instruction(LoadInstr, PrevReg, Index + 1, |
| 7734 | + OffsetRegOperand.getReg(), |
| 7735 | + OffsetRegOperand.isKill()); |
| 7736 | + DelInstrs.push_back(LoadInstr); |
| 7737 | + } |
| 7738 | + Register LastLoadReg0 = PrevReg; |
| 7739 | + |
| 7740 | + // First load into register 1. Perform an integer load to zero out the upper |
| 7741 | + // lanes in a single instruction. |
| 7742 | + MachineInstr *Lane0Load = *LoadToLaneInstrsAscending.begin(); |
| 7743 | + MachineInstr *OriginalSplitLoad = |
| 7744 | + *std::next(LoadToLaneInstrsAscending.begin(), NumLanes / 2); |
| 7745 | + Register DestRegForMiddleIndex = MRI.createVirtualRegister( |
| 7746 | + MRI.getRegClass(Lane0Load->getOperand(0).getReg())); |
| 7747 | + |
| 7748 | + const MachineOperand &OriginalSplitToLoadOffsetOperand = |
| 7749 | + OriginalSplitLoad->getOperand(3); |
| 7750 | + MachineInstrBuilder MiddleIndexLoadInstr = |
| 7751 | + CreateLDRInstruction(NumLanes, DestRegForMiddleIndex, |
| 7752 | + OriginalSplitToLoadOffsetOperand.getReg(), |
| 7753 | + OriginalSplitToLoadOffsetOperand.isKill()); |
| 7754 | + |
| 7755 | + InstrIdxForVirtReg.insert( |
| 7756 | + std::make_pair(DestRegForMiddleIndex, InsInstrs.size())); |
| 7757 | + InsInstrs.push_back(MiddleIndexLoadInstr); |
| 7758 | + DelInstrs.push_back(OriginalSplitLoad); |
| 7759 | + |
| 7760 | + // Subreg To Reg instruction for register 1. |
| 7761 | + Register DestRegForSubregToReg = MRI.createVirtualRegister(FPR128RegClass); |
| 7762 | + unsigned SubregType; |
| 7763 | + switch (NumLanes) { |
| 7764 | + case 4: |
| 7765 | + SubregType = AArch64::ssub; |
| 7766 | + break; |
| 7767 | + case 8: |
| 7768 | + SubregType = AArch64::hsub; |
| 7769 | + break; |
| 7770 | + case 16: |
| 7771 | + SubregType = AArch64::bsub; |
| 7772 | + break; |
| 7773 | + default: |
| 7774 | + llvm_unreachable( |
| 7775 | + "Got invalid NumLanes for machine-combiner gather pattern"); |
| 7776 | + } |
| 7777 | + |
| 7778 | + auto SubRegToRegInstr = |
| 7779 | + BuildMI(MF, MIMetadata(Root), TII->get(SubregToReg->getOpcode()), |
| 7780 | + DestRegForSubregToReg) |
| 7781 | + .addImm(0) |
| 7782 | + .addReg(DestRegForMiddleIndex, getKillRegState(true)) |
| 7783 | + .addImm(SubregType); |
| 7784 | + InstrIdxForVirtReg.insert( |
| 7785 | + std::make_pair(DestRegForSubregToReg, InsInstrs.size())); |
| 7786 | + InsInstrs.push_back(SubRegToRegInstr); |
| 7787 | + |
| 7788 | + // Load remaining lanes into register 1. |
| 7789 | + auto LanesToLoadToReg1 = |
| 7790 | + llvm::make_range(LoadToLaneInstrsAscending.begin() + NumLanes / 2 + 1, |
| 7791 | + LoadToLaneInstrsAscending.end()); |
| 7792 | + PrevReg = SubRegToRegInstr->getOperand(0).getReg(); |
| 7793 | + for (auto [Index, LoadInstr] : llvm::enumerate(LanesToLoadToReg1)) { |
| 7794 | + const MachineOperand &OffsetRegOperand = LoadInstr->getOperand(3); |
| 7795 | + PrevReg = CreateLD1Instruction(LoadInstr, PrevReg, Index + 1, |
| 7796 | + OffsetRegOperand.getReg(), |
| 7797 | + OffsetRegOperand.isKill()); |
| 7798 | + |
| 7799 | + // Do not add the last instruction (Root) to DelInstrs; it is deleted later. |
| 7800 | + if (Index == NumLanes / 2 - 2) { |
| 7801 | + break; |
| 7802 | + } |
| 7803 | + DelInstrs.push_back(LoadInstr); |
| 7804 | + } |
| 7805 | + Register LastLoadReg1 = PrevReg; |
| 7806 | + |
| 7807 | + // Create the final zip instruction to combine the results. |
| 7808 | + MachineInstrBuilder ZipInstr = |
| 7809 | + BuildMI(MF, MIMetadata(Root), TII->get(AArch64::ZIP1v2i64), |
| 7810 | + Root.getOperand(0).getReg()) |
| 7811 | + .addReg(LastLoadReg0) |
| 7812 | + .addReg(LastLoadReg1); |
| 7813 | + InsInstrs.push_back(ZipInstr); |
| 7814 | +} |
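For intuition, the rewritten sequence for the 4-lane case corresponds roughly to the intrinsics-level sketch below (function names invented; the pass itself builds LDR*/LD1i*/SUBREG_TO_REG/ZIP1v2i64 MachineInstrs directly). The two half-gathers no longer depend on each other, so their loads can issue in parallel, and a single ZIP1 over the 64-bit halves reassembles the full vector.

```cpp
#include <arm_neon.h>

// Conceptual "after" shape for NumLanes == 4: two independent 2-lane chains
// combined with a ZIP1 over 64-bit elements (lanes 0-1 come from 'lo',
// lanes 2-3 from 'hi').
float32x4_t gather4_split(const float *p0, const float *p1, const float *p2,
                          const float *p3) {
  float32x4_t lo = vsetq_lane_f32(*p0, vdupq_n_f32(0.0f), 0);
  lo = vld1q_lane_f32(p1, lo, 1);       // chain 0: lanes 0 and 1
  float32x4_t hi = vsetq_lane_f32(*p2, vdupq_n_f32(0.0f), 0);
  hi = vld1q_lane_f32(p3, hi, 1);       // chain 1: independent of chain 0
  // ZIP1v2i64 equivalent: take the low 64 bits of each half.
  return vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(lo),
                                          vreinterpretq_f64_f32(hi)));
}
```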
| 7815 | + |
7488 | 7816 | CombinerObjective |
7489 | 7817 | AArch64InstrInfo::getCombinerObjective(unsigned Pattern) const { |
7490 | 7818 | switch (Pattern) { |
7491 | 7819 | case AArch64MachineCombinerPattern::SUBADD_OP1: |
7492 | 7820 | case AArch64MachineCombinerPattern::SUBADD_OP2: |
| 7821 | + case AArch64MachineCombinerPattern::GATHER_LANE_i32: |
| 7822 | + case AArch64MachineCombinerPattern::GATHER_LANE_i16: |
| 7823 | + case AArch64MachineCombinerPattern::GATHER_LANE_i8: |
7493 | 7824 | return CombinerObjective::MustReduceDepth; |
7494 | 7825 | default: |
7495 | 7826 | return TargetInstrInfo::getCombinerObjective(Pattern); |
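Classifying the gather patterns as `MustReduceDepth` keeps the rewrite gated on the MachineCombiner's critical-path check: the matched input is one serial chain of `NumLanes` dependent loads, while the emitted code is two independent chains of about `NumLanes / 2` loads joined by a single `ZIP1v2i64`, so the dependency depth is roughly halved (for `GATHER_LANE_i16`, for example, eight chained loads become two chains of four plus the zip).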
@@ -7519,6 +7850,10 @@ bool AArch64InstrInfo::getMachineCombinerPatterns( |
7519 | 7850 | if (getMiscPatterns(Root, Patterns)) |
7520 | 7851 | return true; |
7521 | 7852 |
|
| 7853 | + // Load patterns |
| 7854 | + if (getLoadPatterns(Root, Patterns)) |
| 7855 | + return true; |
| 7856 | + |
7522 | 7857 | return TargetInstrInfo::getMachineCombinerPatterns(Root, Patterns, |
7523 | 7858 | DoRegPressureReduce); |
7524 | 7859 | } |
@@ -8774,6 +9109,21 @@ void AArch64InstrInfo::genAlternativeCodeSequence( |
8774 | 9109 | MUL = genFNegatedMAD(MF, MRI, TII, Root, InsInstrs); |
8775 | 9110 | break; |
8776 | 9111 | } |
| 9112 | + case AArch64MachineCombinerPattern::GATHER_LANE_i32: { |
| 9113 | + generateGatherLanePattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg, |
| 9114 | + Pattern, 4); |
| 9115 | + break; |
| 9116 | + } |
| 9117 | + case AArch64MachineCombinerPattern::GATHER_LANE_i16: { |
| 9118 | + generateGatherLanePattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg, |
| 9119 | + Pattern, 8); |
| 9120 | + break; |
| 9121 | + } |
| 9122 | + case AArch64MachineCombinerPattern::GATHER_LANE_i8: { |
| 9123 | + generateGatherLanePattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg, |
| 9124 | + Pattern, 16); |
| 9125 | + break; |
| 9126 | + } |
8777 | 9127 |
|
8778 | 9128 | } // end switch (Pattern) |
8779 | 9129 | // Record MUL and ADD/SUB for deletion |
|