1111// ===----------------------------------------------------------------------===//
1212
1313#include " llvm/CodeGen/TargetInstrInfo.h"
14+ #include " llvm/ADT/SmallSet.h"
1415#include " llvm/ADT/StringExtras.h"
1516#include " llvm/BinaryFormat/Dwarf.h"
1617#include " llvm/CodeGen/MachineCombinerPattern.h"
@@ -42,6 +43,19 @@ static cl::opt<bool> DisableHazardRecognizer(
4243 " disable-sched-hazard" , cl::Hidden, cl::init(false ),
4344 cl::desc(" Disable hazard detection during preRA scheduling" ));
4445
46+ static cl::opt<bool > EnableAccReassociation (
47+ " acc-reassoc" , cl::Hidden, cl::init(true ),
48+ cl::desc(" Enable reassociation of accumulation chains" ));
49+
50+ static cl::opt<unsigned int >
51+ MinAccumulatorDepth (" acc-min-depth" , cl::Hidden, cl::init(8 ),
52+ cl::desc(" Minimum length of accumulator chains "
53+ " required for the optimization to kick in" ));
54+
55+ static cl::opt<unsigned int > MaxAccumulatorWidth (
56+ " acc-max-width" , cl::Hidden, cl::init(3 ),
57+ cl::desc(" Maximum number of branches in the accumulator tree" ));
58+
4559TargetInstrInfo::~TargetInstrInfo () = default ;
4660
4761const TargetRegisterClass*
@@ -899,6 +913,154 @@ bool TargetInstrInfo::isReassociationCandidate(const MachineInstr &Inst,
899913 hasReassociableSibling (Inst, Commuted);
900914}
901915
916+ // Utility routine that checks if \param MO is defined by an
917+ // \param CombineOpc instruction in the basic block \param MBB.
918+ // If \param CombineOpc is not provided, the OpCode check will
919+ // be skipped.
920+ static bool canCombine (MachineBasicBlock &MBB, MachineOperand &MO,
921+ unsigned CombineOpc = 0 ) {
922+ MachineRegisterInfo &MRI = MBB.getParent ()->getRegInfo ();
923+ MachineInstr *MI = nullptr ;
924+
925+ if (MO.isReg () && MO.getReg ().isVirtual ())
926+ MI = MRI.getUniqueVRegDef (MO.getReg ());
927+ // And it needs to be in the trace (otherwise, it won't have a depth).
928+ if (!MI || MI->getParent () != &MBB ||
929+ ((unsigned )MI->getOpcode () != CombineOpc && CombineOpc != 0 ))
930+ return false ;
931+ // Must only used by the user we combine with.
932+ if (!MRI.hasOneNonDBGUse (MI->getOperand (0 ).getReg ()))
933+ return false ;
934+
935+ return true ;
936+ }
937+
938+ // A chain of accumulation instructions will be selected IFF:
939+ // 1. All the accumulation instructions in the chain have the same opcode,
940+ // besides the first that has a slightly different opcode because it does
941+ // not accumulate into a register.
942+ // 2. All the instructions in the chain are combinable (have a single use
943+ // which itself is part of the chain).
944+ // 3. Meets the required minimum length.
945+ void TargetInstrInfo::getAccumulatorChain (
946+ MachineInstr *CurrentInstr, SmallVectorImpl<Register> &Chain) const {
947+ // Walk up the chain of accumulation instructions and collect them in the
948+ // vector.
949+ MachineBasicBlock &MBB = *CurrentInstr->getParent ();
950+ const MachineRegisterInfo &MRI = MBB.getParent ()->getRegInfo ();
951+ unsigned AccumulatorOpcode = CurrentInstr->getOpcode ();
952+ std::optional<unsigned > ChainStartOpCode =
953+ getAccumulationStartOpcode (AccumulatorOpcode);
954+
955+ if (!ChainStartOpCode.has_value ())
956+ return ;
957+
958+ // Push the first accumulator result to the start of the chain.
959+ Chain.push_back (CurrentInstr->getOperand (0 ).getReg ());
960+
961+ // Collect the accumulator input register from all instructions in the chain.
962+ while (CurrentInstr &&
963+ canCombine (MBB, CurrentInstr->getOperand (1 ), AccumulatorOpcode)) {
964+ Chain.push_back (CurrentInstr->getOperand (1 ).getReg ());
965+ CurrentInstr = MRI.getUniqueVRegDef (CurrentInstr->getOperand (1 ).getReg ());
966+ }
967+
968+ // Add the instruction at the top of the chain.
969+ if (CurrentInstr->getOpcode () == AccumulatorOpcode &&
970+ canCombine (MBB, CurrentInstr->getOperand (1 )))
971+ Chain.push_back (CurrentInstr->getOperand (1 ).getReg ());
972+ }
973+
974+ // / Find chains of accumulations that can be rewritten as a tree for increased
975+ // / ILP.
976+ bool TargetInstrInfo::getAccumulatorReassociationPatterns (
977+ MachineInstr &Root, SmallVectorImpl<unsigned > &Patterns) const {
978+ if (!EnableAccReassociation)
979+ return false ;
980+
981+ unsigned Opc = Root.getOpcode ();
982+ if (!isAccumulationOpcode (Opc))
983+ return false ;
984+
985+ // Verify that this is the end of the chain.
986+ MachineBasicBlock &MBB = *Root.getParent ();
987+ MachineRegisterInfo &MRI = MBB.getParent ()->getRegInfo ();
988+ if (!MRI.hasOneNonDBGUser (Root.getOperand (0 ).getReg ()))
989+ return false ;
990+
991+ auto User = MRI.use_instr_begin (Root.getOperand (0 ).getReg ());
992+ if (User->getOpcode () == Opc)
993+ return false ;
994+
995+ // Walk up the use chain and collect the reduction chain.
996+ SmallVector<Register, 32 > Chain;
997+ getAccumulatorChain (&Root, Chain);
998+
999+ // Reject chains which are too short to be worth modifying.
1000+ if (Chain.size () < MinAccumulatorDepth)
1001+ return false ;
1002+
1003+ // Check if the MBB this instruction is a part of contains any other chains.
1004+ // If so, don't apply it.
1005+ SmallSet<Register, 32 > ReductionChain (Chain.begin (), Chain.end ());
1006+ for (const auto &I : MBB) {
1007+ if (I.getOpcode () == Opc &&
1008+ !ReductionChain.contains (I.getOperand (0 ).getReg ()))
1009+ return false ;
1010+ }
1011+
1012+ Patterns.push_back (MachineCombinerPattern::ACC_CHAIN);
1013+ return true ;
1014+ }
1015+
1016+ // Reduce branches of the accumulator tree by adding them together.
1017+ void TargetInstrInfo::reduceAccumulatorTree (
1018+ SmallVectorImpl<Register> &RegistersToReduce,
1019+ SmallVectorImpl<MachineInstr *> &InsInstrs, MachineFunction &MF,
1020+ MachineInstr &Root, MachineRegisterInfo &MRI,
1021+ DenseMap<unsigned , unsigned > &InstrIdxForVirtReg,
1022+ Register ResultReg) const {
1023+ const TargetInstrInfo *TII = MF.getSubtarget ().getInstrInfo ();
1024+ SmallVector<Register, 8 > NewRegs;
1025+
1026+ // Get the opcode for the reduction instruction we will need to build.
1027+ // If for some reason it is not defined, early exit and don't apply this.
1028+ unsigned ReduceOpCode = getReduceOpcodeForAccumulator (Root.getOpcode ());
1029+
1030+ for (unsigned int i = 1 ; i <= (RegistersToReduce.size () / 2 ); i += 2 ) {
1031+ auto RHS = RegistersToReduce[i - 1 ];
1032+ auto LHS = RegistersToReduce[i];
1033+ Register Dest;
1034+ // If we are reducing 2 registers, reuse the original result register.
1035+ if (RegistersToReduce.size () == 2 )
1036+ Dest = ResultReg;
1037+ // Otherwise, create a new virtual register to hold the partial sum.
1038+ else {
1039+ auto NewVR = MRI.createVirtualRegister (
1040+ MRI.getRegClass (Root.getOperand (0 ).getReg ()));
1041+ Dest = NewVR;
1042+ NewRegs.push_back (Dest);
1043+ InstrIdxForVirtReg.insert (std::make_pair (Dest, InsInstrs.size ()));
1044+ }
1045+
1046+ // Create the new reduction instruction.
1047+ MachineInstrBuilder MIB =
1048+ BuildMI (MF, MIMetadata (Root), TII->get (ReduceOpCode), Dest)
1049+ .addReg (RHS, getKillRegState (true ))
1050+ .addReg (LHS, getKillRegState (true ));
1051+ // Copy any flags needed from the original instruction.
1052+ MIB->setFlags (Root.getFlags ());
1053+ InsInstrs.push_back (MIB);
1054+ }
1055+
1056+ // If the number of registers to reduce is odd, add the remaining register to
1057+ // the vector of registers to reduce.
1058+ if (RegistersToReduce.size () % 2 != 0 )
1059+ NewRegs.push_back (RegistersToReduce[RegistersToReduce.size () - 1 ]);
1060+
1061+ RegistersToReduce = NewRegs;
1062+ }
1063+
9021064// The concept of the reassociation pass is that these operations can benefit
9031065// from this kind of transformation:
9041066//
@@ -938,6 +1100,8 @@ bool TargetInstrInfo::getMachineCombinerPatterns(
9381100 }
9391101 return true ;
9401102 }
1103+ if (getAccumulatorReassociationPatterns (Root, Patterns))
1104+ return true ;
9411105
9421106 return false ;
9431107}
@@ -949,7 +1113,12 @@ bool TargetInstrInfo::isThroughputPattern(unsigned Pattern) const {
9491113
9501114CombinerObjective
9511115TargetInstrInfo::getCombinerObjective (unsigned Pattern) const {
952- return CombinerObjective::Default;
1116+ switch (Pattern) {
1117+ case MachineCombinerPattern::ACC_CHAIN:
1118+ return CombinerObjective::MustReduceDepth;
1119+ default :
1120+ return CombinerObjective::Default;
1121+ }
9531122}
9541123
9551124std::pair<unsigned , unsigned >
@@ -1252,19 +1421,101 @@ void TargetInstrInfo::genAlternativeCodeSequence(
12521421 SmallVectorImpl<MachineInstr *> &DelInstrs,
12531422 DenseMap<Register, unsigned > &InstIdxForVirtReg) const {
12541423 MachineRegisterInfo &MRI = Root.getMF ()->getRegInfo ();
1424+ MachineBasicBlock &MBB = *Root.getParent ();
1425+ MachineFunction &MF = *MBB.getParent ();
1426+ const TargetInstrInfo *TII = MF.getSubtarget ().getInstrInfo ();
12551427
1256- // Select the previous instruction in the sequence based on the input pattern.
1257- std::array<unsigned , 5 > OperandIndices;
1258- getReassociateOperandIndices (Root, Pattern, OperandIndices);
1259- MachineInstr *Prev =
1260- MRI.getUniqueVRegDef (Root.getOperand (OperandIndices[0 ]).getReg ());
1428+ switch (Pattern) {
1429+ case MachineCombinerPattern::REASSOC_AX_BY:
1430+ case MachineCombinerPattern::REASSOC_AX_YB:
1431+ case MachineCombinerPattern::REASSOC_XA_BY:
1432+ case MachineCombinerPattern::REASSOC_XA_YB: {
1433+ // Select the previous instruction in the sequence based on the input
1434+ // pattern.
1435+ std::array<unsigned , 5 > OperandIndices;
1436+ getReassociateOperandIndices (Root, Pattern, OperandIndices);
1437+ MachineInstr *Prev =
1438+ MRI.getUniqueVRegDef (Root.getOperand (OperandIndices[0 ]).getReg ());
1439+
1440+ // Don't reassociate if Prev and Root are in different blocks.
1441+ if (Prev->getParent () != Root.getParent ())
1442+ return ;
12611443
1262- // Don't reassociate if Prev and Root are in different blocks.
1263- if (Prev->getParent () != Root.getParent ())
1264- return ;
1444+ reassociateOps (Root, *Prev, Pattern, InsInstrs, DelInstrs, OperandIndices,
1445+ InstIdxForVirtReg);
1446+ break ;
1447+ }
1448+ case MachineCombinerPattern::ACC_CHAIN: {
1449+ SmallVector<Register, 32 > ChainRegs;
1450+ getAccumulatorChain (&Root, ChainRegs);
1451+ unsigned int Depth = ChainRegs.size ();
1452+ assert (MaxAccumulatorWidth > 1 &&
1453+ " Max accumulator width set to illegal value" );
1454+ unsigned int MaxWidth = Log2_32 (Depth) < MaxAccumulatorWidth
1455+ ? Log2_32 (Depth)
1456+ : MaxAccumulatorWidth;
1457+
1458+ // Walk down the chain and rewrite it as a tree.
1459+ for (auto IndexedReg : llvm::enumerate (llvm::reverse (ChainRegs))) {
1460+ // No need to rewrite the first node, it is already perfect as it is.
1461+ if (IndexedReg.index () == 0 )
1462+ continue ;
1463+
1464+ MachineInstr *Instr = MRI.getUniqueVRegDef (IndexedReg.value ());
1465+ MachineInstrBuilder MIB;
1466+ Register AccReg;
1467+ if (IndexedReg.index () < MaxWidth) {
1468+ // Now we need to create new instructions for the first row.
1469+ AccReg = Instr->getOperand (0 ).getReg ();
1470+ std::optional<unsigned > OpCode =
1471+ getAccumulationStartOpcode (Root.getOpcode ());
1472+ assert (OpCode.value () &&
1473+ " Missing opcode for accumulation instruction." );
1474+
1475+ MIB = BuildMI (MF, MIMetadata (*Instr), TII->get (OpCode.value ()), AccReg)
1476+ .addReg (Instr->getOperand (2 ).getReg (),
1477+ getKillRegState (Instr->getOperand (2 ).isKill ()))
1478+ .addReg (Instr->getOperand (3 ).getReg (),
1479+ getKillRegState (Instr->getOperand (3 ).isKill ()));
1480+ } else {
1481+ // For the remaining cases, we need to use an output register of one of
1482+ // the newly inserted instuctions as operand 1
1483+ AccReg = Instr->getOperand (0 ).getReg () == Root.getOperand (0 ).getReg ()
1484+ ? MRI.createVirtualRegister (
1485+ MRI.getRegClass (Root.getOperand (0 ).getReg ()))
1486+ : Instr->getOperand (0 ).getReg ();
1487+ assert (IndexedReg.index () - MaxWidth >= 0 );
1488+ auto AccumulatorInput =
1489+ ChainRegs[Depth - (IndexedReg.index () - MaxWidth) - 1 ];
1490+ MIB = BuildMI (MF, MIMetadata (*Instr), TII->get (Instr->getOpcode ()),
1491+ AccReg)
1492+ .addReg (AccumulatorInput, getKillRegState (true ))
1493+ .addReg (Instr->getOperand (2 ).getReg (),
1494+ getKillRegState (Instr->getOperand (2 ).isKill ()))
1495+ .addReg (Instr->getOperand (3 ).getReg (),
1496+ getKillRegState (Instr->getOperand (3 ).isKill ()));
1497+ }
12651498
1266- reassociateOps (Root, *Prev, Pattern, InsInstrs, DelInstrs, OperandIndices,
1267- InstIdxForVirtReg);
1499+ MIB->setFlags (Instr->getFlags ());
1500+ InstIdxForVirtReg.insert (std::make_pair (AccReg, InsInstrs.size ()));
1501+ InsInstrs.push_back (MIB);
1502+ DelInstrs.push_back (Instr);
1503+ }
1504+
1505+ SmallVector<Register, 8 > RegistersToReduce;
1506+ for (unsigned i = (InsInstrs.size () - MaxWidth); i < InsInstrs.size ();
1507+ ++i) {
1508+ auto Reg = InsInstrs[i]->getOperand (0 ).getReg ();
1509+ RegistersToReduce.push_back (Reg);
1510+ }
1511+
1512+ while (RegistersToReduce.size () > 1 )
1513+ reduceAccumulatorTree (RegistersToReduce, InsInstrs, MF, Root, MRI,
1514+ InstIdxForVirtReg, Root.getOperand (0 ).getReg ());
1515+
1516+ break ;
1517+ }
1518+ }
12681519}
12691520
12701521MachineTraceStrategy TargetInstrInfo::getMachineCombinerTraceStrategy () const {
0 commit comments