@@ -441,16 +441,55 @@ using IsExpiredFn = function_ref<bool(const MachineInstr &, int WaitStates)>;
441441using GetNumWaitStatesFn = function_ref<unsigned int (const MachineInstr &)>;
442442
443443// Search for a hazard in a block and its predecessors.
444- // StateT must implement getHashValue().
444+ // StateT must implement getHashValue() and isEqual() .
445445template <typename StateT>
446446static bool
447447hasHazard (StateT InitialState,
448448 function_ref<HazardFnResult(StateT &, const MachineInstr &)> IsHazard,
449449 function_ref<void(StateT &, const MachineInstr &)> UpdateState,
450450 const MachineBasicBlock *InitialMBB,
451451 MachineBasicBlock::const_reverse_instr_iterator InitialI) {
452- SmallVector<std::pair<const MachineBasicBlock *, StateT>> Worklist;
453- DenseSet<std::pair<const MachineBasicBlock *, unsigned >> Visited;
452+ SmallVector<std::pair<const MachineBasicBlock *, unsigned >> Worklist;
453+ SmallDenseSet<std::pair<const MachineBasicBlock *, unsigned >> Visited;
454+ SmallVector<std::pair<unsigned , unsigned >, 1 > Collisions;
455+ SmallDenseMap<unsigned , unsigned > StateHash2Idx;
456+ SmallVector<StateT> States;
457+
458+ // States contains a vector of unique state structures.
459+ // StateT is hashed via getHashValue() and StateHash2Idx maps each hash
460+ // to an index in the States vector.
461+ // In the unlikely event of a hash collision the Collision vector provides
462+ // additional hash to index associations which must be retrieved by a linear
463+ // scan.
464+
465+ // Retrieve unique constant index for a StateT structure in the States vector.
466+ auto ResolveStateIdx = [&](const StateT State) {
467+ unsigned StateHash = State.getHashValue ();
468+ unsigned StateIdx;
469+ if (!StateHash2Idx.contains (StateHash)) {
470+ StateIdx = States.size ();
471+ States.push_back (State);
472+ StateHash2Idx[StateHash] = StateIdx;
473+ } else {
474+ StateIdx = StateHash2Idx[StateHash];
475+ if (LLVM_UNLIKELY (!StateT::isEqual (State, States[StateIdx]))) {
476+ // Hash collision
477+ auto *Collision = llvm::find_if (Collisions, [&](auto &C) {
478+ return C.first == StateHash &&
479+ StateT::isEqual (State, States[C.second ]);
480+ });
481+ if (Collision) {
482+ StateIdx = Collision->second ;
483+ } else {
484+ StateIdx = States.size ();
485+ States.push_back (State);
486+ Collisions.emplace_back (StateHash, StateIdx);
487+ }
488+ }
489+ }
490+ return StateIdx;
491+ };
492+
454493 const MachineBasicBlock *MBB = InitialMBB;
455494 StateT State = InitialState;
456495 auto I = InitialI;
@@ -477,18 +516,20 @@ hasHazard(StateT InitialState,
477516 }
478517
479518 if (!Expired) {
480- unsigned StateHash = State. getHashValue ( );
519+ unsigned StateIdx = ResolveStateIdx (State );
481520 for (MachineBasicBlock *Pred : MBB->predecessors ()) {
482- if (!Visited.insert (std::pair (Pred, StateHash )).second )
521+ if (!Visited.insert (std::pair (Pred, StateIdx )).second )
483522 continue ;
484- Worklist.emplace_back (Pred, State );
523+ Worklist.emplace_back (Pred, StateIdx );
485524 }
486525 }
487526
488527 if (Worklist.empty ())
489528 break ;
490529
491- std::tie (MBB, State) = Worklist.pop_back_val ();
530+ unsigned StateIdx;
531+ std::tie (MBB, StateIdx) = Worklist.pop_back_val ();
532+ State = States[StateIdx];
492533 I = MBB->instr_rbegin ();
493534 }
494535
@@ -1658,6 +1699,10 @@ bool GCNHazardRecognizer::fixVALUPartialForwardingHazard(MachineInstr *MI) {
16581699 unsigned getHashValue () const {
16591700 return hash_combine (ExecPos, VALUs, hash_combine_range (DefPos));
16601701 }
1702+ static bool isEqual (const StateType &LHS, const StateType &RHS) {
1703+ return LHS.DefPos == RHS.DefPos && LHS.ExecPos == RHS.ExecPos &&
1704+ LHS.VALUs == RHS.VALUs ;
1705+ }
16611706 };
16621707
16631708 StateType State;
@@ -1796,6 +1841,9 @@ bool GCNHazardRecognizer::fixVALUTransUseHazard(MachineInstr *MI) {
17961841 int TRANS = 0 ;
17971842
17981843 unsigned getHashValue () const { return hash_combine (VALUs, TRANS); }
1844+ static bool isEqual (const StateType &LHS, const StateType &RHS) {
1845+ return LHS.VALUs == RHS.VALUs && LHS.TRANS == RHS.TRANS ;
1846+ }
17991847 };
18001848
18011849 StateType State;
0 commit comments