20 | 20 | #include "Utils/AArch64BaseInfo.h" |
21 | 21 | #include "llvm/ADT/ArrayRef.h" |
22 | 22 | #include "llvm/ADT/STLExtras.h" |
| 23 | +#include "llvm/ADT/SmallSet.h" |
23 | 24 | #include "llvm/ADT/SmallVector.h" |
24 | 25 | #include "llvm/CodeGen/CFIInstBuilder.h" |
25 | 26 | #include "llvm/CodeGen/LivePhysRegs.h" |
@@ -7412,11 +7413,347 @@ static bool getMiscPatterns(MachineInstr &Root, |
7412 | 7413 | return false; |
7413 | 7414 | } |
7414 | 7415 |
| 7416 | +/// Check if there are any stores or calls between two instructions in the same |
| 7417 | +/// basic block. |
| 7418 | +static bool hasInterveningStoreOrCall(const MachineInstr *First, |
| 7419 | + const MachineInstr *Last) { |
| 7420 | + if (!First || !Last || First == Last) |
| 7421 | + return false; |
| 7422 | + |
| 7423 | + // Both instructions must be in the same basic block. |
| 7424 | + if (First->getParent() != Last->getParent()) |
| 7425 | + return false; |
| 7426 | + |
| 7427 | + // Sanity check that First comes before Last. |
| 7428 | + const MachineBasicBlock *MBB = First->getParent(); |
| 7429 | + auto InstrIt = First->getIterator(); |
| 7430 | + auto LastIt = Last->getIterator(); |
| 7431 | + |
| 7432 | + for (; InstrIt != MBB->end(); ++InstrIt) { |
| 7433 | + if (InstrIt == LastIt) |
| 7434 | + break; |
| 7435 | + |
| 7436 | +    // Check for stores or calls that could interfere with reordering the loads. |
| 7437 | + if (InstrIt->mayStore() || InstrIt->isCall()) |
| 7438 | + return true; |
| 7439 | + } |
| 7440 | + |
| 7441 | +  // If we reached the end of the basic block without seeing Last, then First |
| 7442 | +  // and Last were not ordered correctly and the analysis is invalid. |
| 7443 | + assert(InstrIt != MBB->end() && |
| 7444 | + "Got bad machine instructions, First should come before Last!"); |
| 7445 | + return false; |
| 7446 | +} |
| 7447 | + |
| 7448 | +/// Check if the given instruction forms a gather load pattern that can be |
| 7449 | +/// optimized for better Memory-Level Parallelism (MLP). This function |
| 7450 | +/// identifies chains of NEON lane load instructions that load data from |
| 7451 | +/// different memory addresses into individual lanes of a 128-bit vector |
| 7452 | +/// register, then attempts to split the pattern into parallel loads to break |
| 7453 | +/// the serial dependency between instructions. |
| 7454 | +/// |
| 7455 | +/// Pattern Matched: |
| 7456 | +/// Initial scalar load -> SUBREG_TO_REG (lane 0) -> LD1i* (lane 1) -> |
| 7457 | +/// LD1i* (lane 2) -> ... -> LD1i* (lane N-1, Root) |
| 7458 | +/// |
| 7459 | +/// Transformed Into: |
| 7460 | +/// Two parallel vector loads using fewer lanes each, followed by ZIP1v2i64 |
| 7461 | +/// to combine the results, enabling better memory-level parallelism. |
| 7462 | +/// |
| 7463 | +/// Supported Element Types: |
| 7464 | +/// - 32-bit elements (LD1i32, 4 lanes total) |
| 7465 | +/// - 16-bit elements (LD1i16, 8 lanes total) |
| 7466 | +/// - 8-bit elements (LD1i8, 16 lanes total) |
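| | +/// |
| | +/// Illustrative example (schematic MIR for the 32-bit case; register and |
| | +/// pointer names are placeholders): |
| | +///   %s0 = LDRSui %ptr0, 0                ; scalar load of lane 0 |
| | +///   %q0 = SUBREG_TO_REG 0, %s0, ssub     ; place it in a 128-bit register |
| | +///   %q1 = LD1i32 %q0, 1, %ptr1 |
| | +///   %q2 = LD1i32 %q1, 2, %ptr2 |
| | +///   %q3 = LD1i32 %q2, 3, %ptr3           ; Root |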
| 7467 | +static bool getGatherPattern(MachineInstr &Root, |
| 7468 | + SmallVectorImpl<unsigned> &Patterns, |
| 7469 | + unsigned LoadLaneOpCode, unsigned NumLanes) { |
| 7470 | + const MachineFunction *MF = Root.getMF(); |
| 7471 | + |
| 7472 | + // Early exit if optimizing for size. |
| 7473 | + if (MF->getFunction().hasMinSize()) |
| 7474 | + return false; |
| 7475 | + |
| 7476 | + const MachineRegisterInfo &MRI = MF->getRegInfo(); |
| 7477 | + const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo(); |
| 7478 | + |
| 7479 | + // The root of the pattern must load into the last lane of the vector. |
| 7480 | + if (Root.getOperand(2).getImm() != NumLanes - 1) |
| 7481 | + return false; |
| 7482 | + |
| 7483 | +  // Check that we have a load into each lane except lane 0. |
| 7484 | + // For each load we also want to check that: |
| 7485 | + // 1. It has a single non-debug use (since we will be replacing the virtual |
| 7486 | + // register) |
| 7487 | +  //  2. Its addressing mode uses only a single pointer operand |
| 7488 | + auto *CurrInstr = MRI.getUniqueVRegDef(Root.getOperand(1).getReg()); |
| 7489 | + auto Range = llvm::seq<unsigned>(1, NumLanes - 1); |
| 7490 | + SmallSet<unsigned, 16> RemainingLanes(Range.begin(), Range.end()); |
| 7491 | + SmallSet<const MachineInstr *, 16> LoadInstrs = {}; |
| 7492 | + while (!RemainingLanes.empty() && CurrInstr && |
| 7493 | + CurrInstr->getOpcode() == LoadLaneOpCode && |
| 7494 | + MRI.hasOneNonDBGUse(CurrInstr->getOperand(0).getReg()) && |
| 7495 | + CurrInstr->getNumOperands() == 4) { |
| 7496 | + RemainingLanes.erase(CurrInstr->getOperand(2).getImm()); |
| 7497 | + LoadInstrs.insert(CurrInstr); |
| 7498 | + CurrInstr = MRI.getUniqueVRegDef(CurrInstr->getOperand(1).getReg()); |
| 7499 | + } |
| 7500 | + |
| 7501 | +  // Check that we have found a match for lanes N-1 down to 1. |
| 7502 | + if (!RemainingLanes.empty()) |
| 7503 | + return false; |
| 7504 | + |
| 7505 | + // Match the SUBREG_TO_REG sequence. |
| 7506 | + if (CurrInstr->getOpcode() != TargetOpcode::SUBREG_TO_REG) |
| 7507 | + return false; |
| 7508 | + |
| 7509 | +  // Verify that the SUBREG_TO_REG puts a single-lane-sized scalar in lane 0. |
| 7510 | + auto Lane0LoadReg = CurrInstr->getOperand(2).getReg(); |
| 7511 | + unsigned SingleLaneSizeInBits = 128 / NumLanes; |
| 7512 | + if (TRI->getRegSizeInBits(Lane0LoadReg, MRI) != SingleLaneSizeInBits) |
| 7513 | + return false; |
| 7514 | + |
| 7515 | +  // Verify that it also has a single non-debug use. |
| 7516 | + if (!MRI.hasOneNonDBGUse(Lane0LoadReg)) |
| 7517 | + return false; |
| 7518 | + |
| 7519 | + LoadInstrs.insert(MRI.getUniqueVRegDef(Lane0LoadReg)); |
| 7520 | + |
| 7521 | + // Check for intervening stores or calls between the first and last load. |
| 7522 | + // Sort load instructions by program order. |
| 7523 | + SmallVector<const MachineInstr *, 16> SortedLoads(LoadInstrs.begin(), |
| 7524 | + LoadInstrs.end()); |
| 7525 | + llvm::sort(SortedLoads, [](const MachineInstr *A, const MachineInstr *B) { |
| 7526 | + if (A->getParent() != B->getParent()) { |
| 7527 | +      // Loads in different blocks should not happen for gather patterns. |
| 7528 | + return false; |
| 7529 | + } |
| 7530 | + // Compare positions within the same basic block. |
| 7531 | + for (const MachineInstr &MI : *A->getParent()) { |
| 7532 | + if (&MI == A) |
| 7533 | + return true; |
| 7534 | + if (&MI == B) |
| 7535 | + return false; |
| 7536 | + } |
| 7537 | + return false; |
| 7538 | + }); |
| 7539 | + |
| 7540 | + const MachineInstr *FirstLoad = SortedLoads.front(); |
| 7541 | + const MachineInstr *LastLoad = SortedLoads.back(); |
| 7542 | + |
| 7543 | + if (hasInterveningStoreOrCall(FirstLoad, LastLoad)) |
| 7544 | + return false; |
| 7545 | + |
| 7546 | + switch (NumLanes) { |
| 7547 | + case 4: |
| 7548 | + Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i32); |
| 7549 | + break; |
| 7550 | + case 8: |
| 7551 | + Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i16); |
| 7552 | + break; |
| 7553 | + case 16: |
| 7554 | + Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i8); |
| 7555 | + break; |
| 7556 | + default: |
| 7557 | + llvm_unreachable("Got bad number of lanes for gather pattern."); |
| 7558 | + } |
| 7559 | + |
| 7560 | + return true; |
| 7561 | +} |
| 7562 | + |
| 7563 | +/// Search for patterns of LD instructions we can optimize. |
| 7564 | +static bool getLoadPatterns(MachineInstr &Root, |
| 7565 | + SmallVectorImpl<unsigned> &Patterns) { |
| 7566 | + |
| 7567 | +  // The patterns here match loads into single lanes of a vector register. |
| 7568 | + switch (Root.getOpcode()) { |
| 7569 | + case AArch64::LD1i32: |
| 7570 | + return getGatherPattern(Root, Patterns, Root.getOpcode(), 4); |
| 7571 | + case AArch64::LD1i16: |
| 7572 | + return getGatherPattern(Root, Patterns, Root.getOpcode(), 8); |
| 7573 | + case AArch64::LD1i8: |
| 7574 | + return getGatherPattern(Root, Patterns, Root.getOpcode(), 16); |
| 7575 | + default: |
| 7576 | + return false; |
| 7577 | + } |
| 7578 | +} |
| 7579 | + |
| 7580 | +/// Generate optimized instruction sequence for gather load patterns to improve |
| 7581 | +/// Memory-Level Parallelism (MLP). This function transforms a chain of |
| 7582 | +/// sequential NEON lane loads into parallel vector loads that can execute |
| 7583 | +/// concurrently. |
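| | +/// |
| | +/// The split: lanes [1, NumLanes/2) are re-emitted on top of the original |
| | +/// lane 0 load (register 0); lane NumLanes/2 becomes a fresh scalar load |
| | +/// (which zeroes the upper lanes) inserted via SUBREG_TO_REG (register 1); |
| | +/// lanes (NumLanes/2, NumLanes) are re-emitted on top of register 1; finally |
| | +/// ZIP1v2i64 concatenates the low 64 bits of the two registers. |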
| 7584 | +static void |
| 7585 | +generateGatherPattern(MachineInstr &Root, |
| 7586 | + SmallVectorImpl<MachineInstr *> &InsInstrs, |
| 7587 | + SmallVectorImpl<MachineInstr *> &DelInstrs, |
| 7588 | + DenseMap<Register, unsigned> &InstrIdxForVirtReg, |
| 7589 | + unsigned Pattern, unsigned NumLanes) { |
| 7590 | + MachineFunction &MF = *Root.getParent()->getParent(); |
| 7591 | + MachineRegisterInfo &MRI = MF.getRegInfo(); |
| 7592 | + const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo(); |
| 7593 | + |
| 7594 | + // Gather the initial load instructions to build the pattern. |
| 7595 | + SmallVector<MachineInstr *, 16> LoadToLaneInstrs; |
| 7596 | + MachineInstr *CurrInstr = &Root; |
| 7597 | + for (unsigned i = 0; i < NumLanes - 1; ++i) { |
| 7598 | + LoadToLaneInstrs.push_back(CurrInstr); |
| 7599 | + CurrInstr = MRI.getUniqueVRegDef(CurrInstr->getOperand(1).getReg()); |
| 7600 | + } |
| 7601 | + |
| 7602 | + // Sort the load instructions according to the lane. |
| 7603 | + llvm::sort(LoadToLaneInstrs, |
| 7604 | + [](const MachineInstr *A, const MachineInstr *B) { |
| 7605 | + return A->getOperand(2).getImm() > B->getOperand(2).getImm(); |
| 7606 | + }); |
| 7607 | + |
| 7608 | + MachineInstr *SubregToReg = CurrInstr; |
| 7609 | + LoadToLaneInstrs.push_back( |
| 7610 | + MRI.getUniqueVRegDef(SubregToReg->getOperand(2).getReg())); |
| 7611 | + auto LoadToLaneInstrsAscending = llvm::reverse(LoadToLaneInstrs); |
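| | +  // LoadToLaneInstrsAscending now holds the initial scalar load followed by |
| | +  // the LD1 lane loads in increasing lane order (lanes 1 .. NumLanes-1). |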
| 7612 | + |
| 7613 | + const TargetRegisterClass *FPR128RegClass = |
| 7614 | + MRI.getRegClass(Root.getOperand(0).getReg()); |
| 7615 | + |
| 7616 | + // Helper lambda to create a LD1 instruction. |
| 7617 | + auto CreateLD1Instruction = [&](MachineInstr *OriginalInstr, |
| 7618 | + Register SrcRegister, unsigned Lane, |
| 7619 | + Register OffsetRegister, |
| 7620 | + bool OffsetRegisterKillState) { |
| 7621 | + auto NewRegister = MRI.createVirtualRegister(FPR128RegClass); |
| 7622 | + MachineInstrBuilder LoadIndexIntoRegister = |
| 7623 | + BuildMI(MF, MIMetadata(*OriginalInstr), TII->get(Root.getOpcode()), |
| 7624 | + NewRegister) |
| 7625 | + .addReg(SrcRegister) |
| 7626 | + .addImm(Lane) |
| 7627 | + .addReg(OffsetRegister, getKillRegState(OffsetRegisterKillState)); |
| 7628 | + InstrIdxForVirtReg.insert(std::make_pair(NewRegister, InsInstrs.size())); |
| 7629 | + InsInstrs.push_back(LoadIndexIntoRegister); |
| 7630 | + return NewRegister; |
| 7631 | + }; |
| 7632 | + |
| 7633 | +  // Helper to create a scalar load instruction whose width matches the lane |
| 7634 | +  // size of the NEON register we are rewriting. |
| 7635 | + auto CreateLDRInstruction = [&](unsigned NumLanes, Register DestReg, |
| 7636 | + Register OffsetReg, |
| 7637 | + bool KillState) -> MachineInstrBuilder { |
| 7638 | + unsigned Opcode; |
| 7639 | + switch (NumLanes) { |
| 7640 | + case 4: |
| 7641 | + Opcode = AArch64::LDRSui; |
| 7642 | + break; |
| 7643 | + case 8: |
| 7644 | + Opcode = AArch64::LDRHui; |
| 7645 | + break; |
| 7646 | + case 16: |
| 7647 | + Opcode = AArch64::LDRBui; |
| 7648 | + break; |
| 7649 | + default: |
| 7650 | + llvm_unreachable( |
| 7651 | + "Got unsupported number of lanes in machine-combiner gather pattern"); |
| 7652 | + } |
| 7653 | + // Immediate offset load |
| 7654 | + return BuildMI(MF, MIMetadata(Root), TII->get(Opcode), DestReg) |
| 7655 | + .addReg(OffsetReg) |
| 7656 | + .addImm(0); |
| 7657 | + }; |
| 7658 | + |
| 7659 | +  // Load lanes 1 .. NumLanes/2-1 into register 0. |
| 7660 | + auto LanesToLoadToReg0 = |
| 7661 | + llvm::make_range(LoadToLaneInstrsAscending.begin() + 1, |
| 7662 | + LoadToLaneInstrsAscending.begin() + NumLanes / 2); |
| 7663 | + Register PrevReg = SubregToReg->getOperand(0).getReg(); |
| 7664 | + for (auto [Index, LoadInstr] : llvm::enumerate(LanesToLoadToReg0)) { |
| 7665 | + const MachineOperand &OffsetRegOperand = LoadInstr->getOperand(3); |
| 7666 | + PrevReg = CreateLD1Instruction(LoadInstr, PrevReg, Index + 1, |
| 7667 | + OffsetRegOperand.getReg(), |
| 7668 | + OffsetRegOperand.isKill()); |
| 7669 | + DelInstrs.push_back(LoadInstr); |
| 7670 | + } |
| 7671 | + Register LastLoadReg0 = PrevReg; |
| 7672 | + |
| 7673 | + // First load into register 1. Perform an integer load to zero out the upper |
| 7674 | + // lanes in a single instruction. |
| 7675 | + MachineInstr *Lane0Load = *LoadToLaneInstrsAscending.begin(); |
| 7676 | + MachineInstr *OriginalSplitLoad = |
| 7677 | + *std::next(LoadToLaneInstrsAscending.begin(), NumLanes / 2); |
| 7678 | + Register DestRegForMiddleIndex = MRI.createVirtualRegister( |
| 7679 | + MRI.getRegClass(Lane0Load->getOperand(0).getReg())); |
| 7680 | + |
| 7681 | + const MachineOperand &OriginalSplitToLoadOffsetOperand = |
| 7682 | + OriginalSplitLoad->getOperand(3); |
| 7683 | + MachineInstrBuilder MiddleIndexLoadInstr = |
| 7684 | + CreateLDRInstruction(NumLanes, DestRegForMiddleIndex, |
| 7685 | + OriginalSplitToLoadOffsetOperand.getReg(), |
| 7686 | + OriginalSplitToLoadOffsetOperand.isKill()); |
| 7687 | + |
| 7688 | + InstrIdxForVirtReg.insert( |
| 7689 | + std::make_pair(DestRegForMiddleIndex, InsInstrs.size())); |
| 7690 | + InsInstrs.push_back(MiddleIndexLoadInstr); |
| 7691 | + DelInstrs.push_back(OriginalSplitLoad); |
| 7692 | + |
| 7693 | +  // SUBREG_TO_REG instruction for register 1. |
| 7694 | + Register DestRegForSubregToReg = MRI.createVirtualRegister(FPR128RegClass); |
| 7695 | + unsigned SubregType; |
| 7696 | + switch (NumLanes) { |
| 7697 | + case 4: |
| 7698 | + SubregType = AArch64::ssub; |
| 7699 | + break; |
| 7700 | + case 8: |
| 7701 | + SubregType = AArch64::hsub; |
| 7702 | + break; |
| 7703 | + case 16: |
| 7704 | + SubregType = AArch64::bsub; |
| 7705 | + break; |
| 7706 | + default: |
| 7707 | + llvm_unreachable( |
| 7708 | + "Got invalid NumLanes for machine-combiner gather pattern"); |
| 7709 | + } |
| 7710 | + |
| 7711 | + auto SubRegToRegInstr = |
| 7712 | + BuildMI(MF, MIMetadata(Root), TII->get(SubregToReg->getOpcode()), |
| 7713 | + DestRegForSubregToReg) |
| 7714 | + .addImm(0) |
| 7715 | + .addReg(DestRegForMiddleIndex, getKillRegState(true)) |
| 7716 | + .addImm(SubregType); |
| 7717 | + InstrIdxForVirtReg.insert( |
| 7718 | + std::make_pair(DestRegForSubregToReg, InsInstrs.size())); |
| 7719 | + InsInstrs.push_back(SubRegToRegInstr); |
| 7720 | + |
| 7721 | + // Load remaining lanes into register 1. |
| 7722 | + auto LanesToLoadToReg1 = |
| 7723 | + llvm::make_range(LoadToLaneInstrsAscending.begin() + NumLanes / 2 + 1, |
| 7724 | + LoadToLaneInstrsAscending.end()); |
| 7725 | + PrevReg = SubRegToRegInstr->getOperand(0).getReg(); |
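| | +  // Lane indices restart at 1 here: register 1 is a separate vector whose |
| | +  // lane 0 was just filled by the scalar load above. |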
| 7726 | + for (auto [Index, LoadInstr] : llvm::enumerate(LanesToLoadToReg1)) { |
| 7727 | + const MachineOperand &OffsetRegOperand = LoadInstr->getOperand(3); |
| 7728 | + PrevReg = CreateLD1Instruction(LoadInstr, PrevReg, Index + 1, |
| 7729 | + OffsetRegOperand.getReg(), |
| 7730 | + OffsetRegOperand.isKill()); |
| 7731 | + |
| 7732 | +    // Do not add the last instruction (the Root); it is removed later. |
| 7733 | + if (Index == NumLanes / 2 - 2) { |
| 7734 | + break; |
| 7735 | + } |
| 7736 | + DelInstrs.push_back(LoadInstr); |
| 7737 | + } |
| 7738 | + Register LastLoadReg1 = PrevReg; |
| 7739 | + |
| 7740 | + // Create the final zip instruction to combine the results. |
| 7741 | + MachineInstrBuilder ZipInstr = |
| 7742 | + BuildMI(MF, MIMetadata(Root), TII->get(AArch64::ZIP1v2i64), |
| 7743 | + Root.getOperand(0).getReg()) |
| 7744 | + .addReg(LastLoadReg0) |
| 7745 | + .addReg(LastLoadReg1); |
| 7746 | + InsInstrs.push_back(ZipInstr); |
| 7747 | +} |
| 7748 | + |
7415 | 7749 | CombinerObjective |
7416 | 7750 | AArch64InstrInfo::getCombinerObjective(unsigned Pattern) const { |
7417 | 7751 | switch (Pattern) { |
7418 | 7752 | case AArch64MachineCombinerPattern::SUBADD_OP1: |
7419 | 7753 | case AArch64MachineCombinerPattern::SUBADD_OP2: |
| 7754 | + case AArch64MachineCombinerPattern::GATHER_LANE_i32: |
| 7755 | + case AArch64MachineCombinerPattern::GATHER_LANE_i16: |
| 7756 | + case AArch64MachineCombinerPattern::GATHER_LANE_i8: |
7420 | 7757 | return CombinerObjective::MustReduceDepth; |
7421 | 7758 | default: |
7422 | 7759 | return TargetInstrInfo::getCombinerObjective(Pattern); |
@@ -7446,6 +7783,10 @@ bool AArch64InstrInfo::getMachineCombinerPatterns( |
7446 | 7783 | if (getMiscPatterns(Root, Patterns)) |
7447 | 7784 | return true; |
7448 | 7785 |
| 7786 | + // Load patterns |
| 7787 | + if (getLoadPatterns(Root, Patterns)) |
| 7788 | + return true; |
| 7789 | + |
7449 | 7790 | return TargetInstrInfo::getMachineCombinerPatterns(Root, Patterns, |
7450 | 7791 | DoRegPressureReduce); |
7451 | 7792 | } |
@@ -8701,6 +9042,21 @@ void AArch64InstrInfo::genAlternativeCodeSequence( |
8701 | 9042 | MUL = genFNegatedMAD(MF, MRI, TII, Root, InsInstrs); |
8702 | 9043 | break; |
8703 | 9044 | } |
| 9045 | + case AArch64MachineCombinerPattern::GATHER_LANE_i32: { |
| 9046 | + generateGatherPattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg, |
| 9047 | + Pattern, 4); |
| 9048 | + break; |
| 9049 | + } |
| 9050 | + case AArch64MachineCombinerPattern::GATHER_LANE_i16: { |
| 9051 | + generateGatherPattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg, |
| 9052 | + Pattern, 8); |
| 9053 | + break; |
| 9054 | + } |
| 9055 | + case AArch64MachineCombinerPattern::GATHER_LANE_i8: { |
| 9056 | + generateGatherPattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg, |
| 9057 | + Pattern, 16); |
| 9058 | + break; |
| 9059 | + } |
8704 | 9060 |
8705 | 9061 | } // end switch (Pattern) |
8706 | 9062 | // Record MUL and ADD/SUB for deletion |