Skip to content

Commit c81dddc

Browse files
committed
Comments
1 parent f041563 commit c81dddc

File tree

1 file changed

+41
-20
lines changed

1 file changed

+41
-20
lines changed

llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,15 @@ auto inst_counter_types(InstCounterType MaxCounter = NUM_INST_CNTS) {
100100
/// Integer IDs used to track vector memory locations we may have to wait on.
101101
/// Encoded as u16 chunks:
102102
///
103-
/// [0, MAX_REGUNITS ): MCRegUnit
104-
/// [FIRST_LDSDMA, LAST_LDSDMA ): LDS DMA IDs
103+
/// [0, REGUNITS_END ): MCRegUnit
104+
/// [LDSDMA_BEGIN, LDSDMA_END ) : LDS DMA IDs
105+
///
106+
/// NOTE: The choice of encoding these as "u16 chunks" is arbitrary.
107+
/// It gives (2 << 16) - 1 entries per category which is more than enough
108+
/// for all register units. MCPhysReg is u16 so we don't even support >u16
109+
/// physical register numbers at this time, let alone >u16 register units.
110+
/// In any case, an assertion in "WaitcntBrackets" ensures REGUNITS_END
111+
/// is enough for all register units.
105112
using VMEMID = uint32_t;
106113

107114
enum : VMEMID {
@@ -586,7 +593,12 @@ class SIInsertWaitcnts {
586593
// "s_waitcnt 0" before use.
587594
class WaitcntBrackets {
588595
public:
589-
WaitcntBrackets(const SIInsertWaitcnts *Context) : Context(Context) {}
596+
WaitcntBrackets(const SIInsertWaitcnts *Context) : Context(Context) {
597+
static_assert(REGUNITS_BEGIN == 0,
598+
"REGUNITS_BEGIN must be zero; tracking depends on being able "
599+
"to convert a register unit ID to a VMEMID directly!");
600+
assert(Context->TRI->getNumRegUnits() < REGUNITS_END);
601+
}
590602

591603
bool isSmemCounter(InstCounterType T) const {
592604
return T == Context->SmemAccessCounter || T == X_CNT;
@@ -730,10 +742,10 @@ class WaitcntBrackets {
730742

731743
iterator_range<MCRegUnitIterator> regunits(MCPhysReg Reg) const {
732744
assert(Reg != AMDGPU::SCC && "Shouldn't be used on SCC");
733-
const TargetRegisterClass *RC = Context->TRI->getPhysRegBaseClass(Reg);
734-
unsigned Size = Context->TRI->getRegSizeInBits(*RC);
735745
if (!Context->TRI->isInAllocatableClass(Reg))
736746
return {{}, {}};
747+
const TargetRegisterClass *RC = Context->TRI->getPhysRegBaseClass(Reg);
748+
unsigned Size = Context->TRI->getRegSizeInBits(*RC);
737749
if (Size == 16 && Context->ST->hasD16Writes32BitVgpr())
738750
Reg = Context->TRI->get32BitRegister(Reg);
739751
return Context->TRI->regunits(Reg);
@@ -794,6 +806,10 @@ class WaitcntBrackets {
794806
// For the VMem case, if the key is within the range of LDS DMA IDs,
795807
// then the corresponding index into the `LDSDMAStores` vector below is:
796808
// Key - LDSDMA_BEGIN - 1
809+
// This is because LDSDMA_BEGIN is a generic entry and does not have an
810+
// associated MachineInstr.
811+
//
812+
// TODO: Could we track SCC alongside SGPRs so it's not longer a special case?
797813

798814
struct VGPRInfo {
799815
// Scores for all instruction counters.
@@ -820,7 +836,6 @@ class WaitcntBrackets {
820836

821837
// Store representative LDS DMA operations. The only useful info here is
822838
// alias info. One store is kept per unique AAInfo.
823-
// Entry zero is the "generic" entry that applies to all LDSDMA stores.
824839
SmallVector<const MachineInstr *> LDSDMAStores;
825840
};
826841

@@ -849,7 +864,6 @@ class SIInsertWaitcntsLegacy : public MachineFunctionPass {
849864

850865
void WaitcntBrackets::setScoreByOperand(const MachineOperand &Op,
851866
InstCounterType CntTy, unsigned Score) {
852-
assert(Op.isReg());
853867
setRegScore(Op.getReg().asMCReg(), CntTy, Score);
854868
}
855869

@@ -1017,6 +1031,10 @@ void WaitcntBrackets::updateByEvent(WaitEventType E, MachineInstr &Inst) {
10171031
(TII->isDS(Inst) || TII->mayWriteLDSThroughDMA(Inst))) {
10181032
// MUBUF and FLAT LDS DMA operations need a wait on vmcnt before LDS
10191033
// written can be accessed. A load from LDS to VMEM does not need a wait.
1034+
//
1035+
// The "Slot" is the offset from LDSDMA_BEGIN. If it's non-zero, then
1036+
// there is a MachineInstr in LDSDMAStores used to track this LDSDMA
1037+
// store. The "Slot" is the index into LDSDMAStores + 1.
10201038
unsigned Slot = 0;
10211039
for (const auto *MemOp : Inst.memoperands()) {
10221040
if (!MemOp->isStore() ||
@@ -1029,9 +1047,7 @@ void WaitcntBrackets::updateByEvent(WaitEventType E, MachineInstr &Inst) {
10291047
// original memory object and practically produced in the module LDS
10301048
// lowering pass. If there is no scope available we will not be able
10311049
// to disambiguate LDS aliasing as after the module lowering all LDS
1032-
// is squashed into a single big object. Do not attempt to use one of
1033-
// the limited LDSDMAStores for something we will not be able to use
1034-
// anyway.
1050+
// is squashed into a single big object.
10351051
if (!AAI || !AAI.Scope)
10361052
break;
10371053
for (unsigned I = 0, E = LDSDMAStores.size(); I != E && !Slot; ++I) {
@@ -1044,15 +1060,14 @@ void WaitcntBrackets::updateByEvent(WaitEventType E, MachineInstr &Inst) {
10441060
}
10451061
if (Slot)
10461062
break;
1047-
// The slot may not be valid because it can be >= NUM_LDS_VGPRS which
1063+
// The slot may not be valid because it can be >= NUM_LDSDMA which
10481064
// means the scoreboard cannot track it. We still want to preserve the
10491065
// MI in order to check alias information, though.
10501066
LDSDMAStores.push_back(&Inst);
1051-
Slot = LDSDMAStores.size();
10521067
break;
10531068
}
10541069
setVMemScore(LDSDMA_BEGIN, T, CurrScore);
1055-
if (Slot && (LDSDMA_BEGIN + Slot) < LDSDMA_END)
1070+
if (Slot && Slot < NUM_LDSDMA)
10561071
setVMemScore(LDSDMA_BEGIN + Slot, T, CurrScore);
10571072
}
10581073

@@ -1107,8 +1122,11 @@ void WaitcntBrackets::print(raw_ostream &OS) const {
11071122
// Print vgpr scores.
11081123
unsigned LB = getScoreLB(T);
11091124

1110-
for (auto &[ID, Info] : VMem) {
1111-
unsigned RegScore = Info.Scores[T];
1125+
SmallVector<VMEMID> SortedVMEMIDs(VMem.keys());
1126+
sort(SortedVMEMIDs);
1127+
1128+
for (auto ID : SortedVMEMIDs) {
1129+
unsigned RegScore = VMem.at(ID).Scores[T];
11121130
if (RegScore <= LB)
11131131
continue;
11141132
unsigned RelScore = RegScore - LB - 1;
@@ -1123,8 +1141,10 @@ void WaitcntBrackets::print(raw_ostream &OS) const {
11231141

11241142
// Also need to print sgpr scores for lgkm_cnt or xcnt.
11251143
if (isSmemCounter(T)) {
1126-
for (auto &[ID, Info] : SGPRs) {
1127-
unsigned RegScore = Info.Scores[getSgprScoresIdx(T)];
1144+
SmallVector<MCRegUnit> SortedSMEMIDs(SGPRs.keys());
1145+
sort(SortedSMEMIDs);
1146+
for (auto ID : SortedSMEMIDs) {
1147+
unsigned RegScore = SGPRs.at(ID).Scores[getSgprScoresIdx(T)];
11281148
if (RegScore <= LB)
11291149
continue;
11301150
unsigned RelScore = RegScore - LB - 1;
@@ -2410,9 +2430,10 @@ bool WaitcntBrackets::merge(const WaitcntBrackets &Other) {
24102430
}
24112431

24122432
for (auto &[TID, Info] : Other.VMem) {
2413-
unsigned char NewVmemTypes = VMem[TID].VMEMTypes | Info.VMEMTypes;
2414-
StrictDom |= NewVmemTypes != VMem[TID].VMEMTypes;
2415-
VMem[TID].VMEMTypes = NewVmemTypes;
2433+
auto &Value = VMem[TID];
2434+
unsigned char NewVmemTypes = Value.VMEMTypes | Info.VMEMTypes;
2435+
StrictDom |= NewVmemTypes != Value.VMEMTypes;
2436+
Value.VMEMTypes = NewVmemTypes;
24162437
}
24172438

24182439
return StrictDom;

0 commit comments

Comments
 (0)