Skip to content

Commit 2a18fbe

Browse files
committed
- Rework to use unified store of states
- Handle hashing collisions
1 parent 0230c8e commit 2a18fbe

File tree

1 file changed

+55
-7
lines changed

1 file changed

+55
-7
lines changed

llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -441,16 +441,55 @@ using IsExpiredFn = function_ref<bool(const MachineInstr &, int WaitStates)>;
441441
using GetNumWaitStatesFn = function_ref<unsigned int(const MachineInstr &)>;
442442

443443
// Search for a hazard in a block and its predecessors.
444-
// StateT must implement getHashValue().
444+
// StateT must implement getHashValue() and isEqual().
445445
template <typename StateT>
446446
static bool
447447
hasHazard(StateT InitialState,
448448
function_ref<HazardFnResult(StateT &, const MachineInstr &)> IsHazard,
449449
function_ref<void(StateT &, const MachineInstr &)> UpdateState,
450450
const MachineBasicBlock *InitialMBB,
451451
MachineBasicBlock::const_reverse_instr_iterator InitialI) {
452-
SmallVector<std::pair<const MachineBasicBlock *, StateT>> Worklist;
453-
DenseSet<std::pair<const MachineBasicBlock *, unsigned>> Visited;
452+
SmallVector<std::pair<const MachineBasicBlock *, unsigned>> Worklist;
453+
SmallDenseSet<std::pair<const MachineBasicBlock *, unsigned>> Visited;
454+
SmallVector<std::pair<unsigned, unsigned>, 1> Collisions;
455+
SmallDenseMap<unsigned, unsigned> StateHash2Idx;
456+
SmallVector<StateT> States;
457+
458+
// States contains a vector of unique state structures.
459+
// StateT is hashed via getHashValue() and StateHash2Idx maps each hash
460+
// to an index in the States vector.
461+
// In the unlikely event of a hash collision the Collision vector provides
462+
// additional hash to index associations which must be retrieved by a linear
463+
// scan.
464+
465+
// Retrieve unique constant index for a StateT structure in the States vector.
466+
auto ResolveStateIdx = [&](const StateT State) {
467+
unsigned StateHash = State.getHashValue();
468+
unsigned StateIdx;
469+
if (!StateHash2Idx.contains(StateHash)) {
470+
StateIdx = States.size();
471+
States.push_back(State);
472+
StateHash2Idx[StateHash] = StateIdx;
473+
} else {
474+
StateIdx = StateHash2Idx[StateHash];
475+
if (LLVM_UNLIKELY(!StateT::isEqual(State, States[StateIdx]))) {
476+
// Hash collision
477+
auto *Collision = llvm::find_if(Collisions, [&](auto &C) {
478+
return C.first == StateHash &&
479+
StateT::isEqual(State, States[C.second]);
480+
});
481+
if (Collision) {
482+
StateIdx = Collision->second;
483+
} else {
484+
StateIdx = States.size();
485+
States.push_back(State);
486+
Collisions.emplace_back(StateHash, StateIdx);
487+
}
488+
}
489+
}
490+
return StateIdx;
491+
};
492+
454493
const MachineBasicBlock *MBB = InitialMBB;
455494
StateT State = InitialState;
456495
auto I = InitialI;
@@ -477,18 +516,20 @@ hasHazard(StateT InitialState,
477516
}
478517

479518
if (!Expired) {
480-
unsigned StateHash = State.getHashValue();
519+
unsigned StateIdx = ResolveStateIdx(State);
481520
for (MachineBasicBlock *Pred : MBB->predecessors()) {
482-
if (!Visited.insert(std::pair(Pred, StateHash)).second)
521+
if (!Visited.insert(std::pair(Pred, StateIdx)).second)
483522
continue;
484-
Worklist.emplace_back(Pred, State);
523+
Worklist.emplace_back(Pred, StateIdx);
485524
}
486525
}
487526

488527
if (Worklist.empty())
489528
break;
490529

491-
std::tie(MBB, State) = Worklist.pop_back_val();
530+
unsigned StateIdx;
531+
std::tie(MBB, StateIdx) = Worklist.pop_back_val();
532+
State = States[StateIdx];
492533
I = MBB->instr_rbegin();
493534
}
494535

@@ -1658,6 +1699,10 @@ bool GCNHazardRecognizer::fixVALUPartialForwardingHazard(MachineInstr *MI) {
16581699
unsigned getHashValue() const {
16591700
return hash_combine(ExecPos, VALUs, hash_combine_range(DefPos));
16601701
}
1702+
static bool isEqual(const StateType &LHS, const StateType &RHS) {
1703+
return LHS.DefPos == RHS.DefPos && LHS.ExecPos == RHS.ExecPos &&
1704+
LHS.VALUs == RHS.VALUs;
1705+
}
16611706
};
16621707

16631708
StateType State;
@@ -1796,6 +1841,9 @@ bool GCNHazardRecognizer::fixVALUTransUseHazard(MachineInstr *MI) {
17961841
int TRANS = 0;
17971842

17981843
unsigned getHashValue() const { return hash_combine(VALUs, TRANS); }
1844+
static bool isEqual(const StateType &LHS, const StateType &RHS) {
1845+
return LHS.VALUs == RHS.VALUs && LHS.TRANS == RHS.TRANS;
1846+
}
17991847
};
18001848

18011849
StateType State;

0 commit comments

Comments
 (0)