Skip to content

Commit e5d72bf

Browse files
committed
Comments
1 parent 7eceab2 commit e5d72bf

File tree

1 file changed

+40
-18
lines changed

1 file changed

+40
-18
lines changed

llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,15 @@ auto inst_counter_types(InstCounterType MaxCounter = NUM_INST_CNTS) {
100100
/// Integer IDs used to track vector memory locations we may have to wait on.
101101
/// Encoded as u16 chunks:
102102
///
103-
/// [0, MAX_REGUNITS ): MCRegUnit
104-
/// [FIRST_LDSDMA, LAST_LDSDMA ): LDS DMA IDs
103+
/// [0, REGUNITS_END ): MCRegUnit
104+
/// [LDSDMA_BEGIN, LDSDMA_END ) : LDS DMA IDs
105+
///
106+
/// NOTE: The choice of encoding these as "u16 chunks" is arbitrary.
107+
/// It gives (1 << 16) entries per category which is more than enough
108+
/// for all register units. MCPhysReg is u16 so we don't even support >u16
109+
/// physical register numbers at this time, let alone >u16 register units.
110+
/// In any case, an assertion in "WaitcntBrackets" ensures REGUNITS_END
111+
/// is enough for all register units.
105112
using VMEMID = uint32_t;
106113

107114
enum : VMEMID {
@@ -593,7 +600,12 @@ class SIInsertWaitcnts {
593600
// "s_waitcnt 0" before use.
594601
class WaitcntBrackets {
595602
public:
596-
WaitcntBrackets(const SIInsertWaitcnts *Context) : Context(Context) {}
603+
WaitcntBrackets(const SIInsertWaitcnts *Context) : Context(Context) {
604+
static_assert(REGUNITS_BEGIN == 0,
605+
"REGUNITS_BEGIN must be zero; tracking depends on being able "
606+
"to convert a register unit ID to a VMEMID directly!");
607+
assert(Context->TRI->getNumRegUnits() < REGUNITS_END);
608+
}
597609

598610
bool isSmemCounter(InstCounterType T) const {
599611
return T == Context->SmemAccessCounter || T == X_CNT;
@@ -737,10 +749,10 @@ class WaitcntBrackets {
737749

738750
iterator_range<MCRegUnitIterator> regunits(MCPhysReg Reg) const {
739751
assert(Reg != AMDGPU::SCC && "Shouldn't be used on SCC");
740-
const TargetRegisterClass *RC = Context->TRI->getPhysRegBaseClass(Reg);
741-
unsigned Size = Context->TRI->getRegSizeInBits(*RC);
742752
if (!Context->TRI->isInAllocatableClass(Reg))
743753
return {{}, {}};
754+
const TargetRegisterClass *RC = Context->TRI->getPhysRegBaseClass(Reg);
755+
unsigned Size = Context->TRI->getRegSizeInBits(*RC);
744756
if (Size == 16 && Context->ST->hasD16Writes32BitVgpr())
745757
Reg = Context->TRI->get32BitRegister(Reg);
746758
return Context->TRI->regunits(Reg);
@@ -801,6 +813,10 @@ class WaitcntBrackets {
801813
// For the VMem case, if the key is within the range of LDS DMA IDs,
802814
// then the corresponding index into the `LDSDMAStores` vector below is:
803815
// Key - LDSDMA_BEGIN - 1
816+
// This is because LDSDMA_BEGIN is a generic entry and does not have an
817+
// associated MachineInstr.
818+
//
819+
// TODO: Could we track SCC alongside SGPRs so it's no longer a special case?
804820

805821
struct VGPRInfo {
806822
// Scores for all instruction counters.
@@ -827,7 +843,6 @@ class WaitcntBrackets {
827843

828844
// Store representative LDS DMA operations. The only useful info here is
829845
// alias info. One store is kept per unique AAInfo.
830-
// Entry zero is the "generic" entry that applies to all LDSDMA stores.
831846
SmallVector<const MachineInstr *> LDSDMAStores;
832847
};
833848

@@ -856,7 +871,6 @@ class SIInsertWaitcntsLegacy : public MachineFunctionPass {
856871

857872
void WaitcntBrackets::setScoreByOperand(const MachineOperand &Op,
858873
InstCounterType CntTy, unsigned Score) {
859-
assert(Op.isReg());
860874
setRegScore(Op.getReg().asMCReg(), CntTy, Score);
861875
}
862876

@@ -1023,6 +1037,10 @@ void WaitcntBrackets::updateByEvent(WaitEventType E, MachineInstr &Inst) {
10231037
(TII->isDS(Inst) || TII->mayWriteLDSThroughDMA(Inst))) {
10241038
// MUBUF and FLAT LDS DMA operations need a wait on vmcnt before LDS
10251039
// written can be accessed. A load from LDS to VMEM does not need a wait.
1040+
//
1041+
// The "Slot" is the offset from LDSDMA_BEGIN. If it's non-zero, then
1042+
// there is a MachineInstr in LDSDMAStores used to track this LDSDMA
1043+
// store. The "Slot" is that instruction's index in LDSDMAStores, plus one.
10261044
unsigned Slot = 0;
10271045
for (const auto *MemOp : Inst.memoperands()) {
10281046
if (!MemOp->isStore() ||
@@ -1035,9 +1053,7 @@ void WaitcntBrackets::updateByEvent(WaitEventType E, MachineInstr &Inst) {
10351053
// original memory object and practically produced in the module LDS
10361054
// lowering pass. If there is no scope available we will not be able
10371055
// to disambiguate LDS aliasing as after the module lowering all LDS
1038-
// is squashed into a single big object. Do not attempt to use one of
1039-
// the limited LDSDMAStores for something we will not be able to use
1040-
// anyway.
1056+
// is squashed into a single big object.
10411057
if (!AAI || !AAI.Scope)
10421058
break;
10431059
for (unsigned I = 0, E = LDSDMAStores.size(); I != E && !Slot; ++I) {
@@ -1050,8 +1066,8 @@ void WaitcntBrackets::updateByEvent(WaitEventType E, MachineInstr &Inst) {
10501066
}
10511067
if (Slot)
10521068
break;
1069+
Slot = LDSDMAStores.size() + 1;
10531070
LDSDMAStores.push_back(&Inst);
1054-
Slot = LDSDMAStores.size();
10551071
break;
10561072
}
10571073
setVMemScore(LDSDMA_BEGIN, T, CurrScore);
@@ -1110,8 +1126,11 @@ void WaitcntBrackets::print(raw_ostream &OS) const {
11101126
// Print vgpr scores.
11111127
unsigned LB = getScoreLB(T);
11121128

1113-
for (auto &[ID, Info] : VMem) {
1114-
unsigned RegScore = Info.Scores[T];
1129+
SmallVector<VMEMID> SortedVMEMIDs(VMem.keys());
1130+
sort(SortedVMEMIDs);
1131+
1132+
for (auto ID : SortedVMEMIDs) {
1133+
unsigned RegScore = VMem.at(ID).Scores[T];
11151134
if (RegScore <= LB)
11161135
continue;
11171136
unsigned RelScore = RegScore - LB - 1;
@@ -1126,8 +1145,10 @@ void WaitcntBrackets::print(raw_ostream &OS) const {
11261145

11271146
// Also need to print sgpr scores for lgkm_cnt or xcnt.
11281147
if (isSmemCounter(T)) {
1129-
for (auto &[ID, Info] : SGPRs) {
1130-
unsigned RegScore = Info.Scores[getSgprScoresIdx(T)];
1148+
SmallVector<MCRegUnit> SortedSMEMIDs(SGPRs.keys());
1149+
sort(SortedSMEMIDs);
1150+
for (auto ID : SortedSMEMIDs) {
1151+
unsigned RegScore = SGPRs.at(ID).Scores[getSgprScoresIdx(T)];
11311152
if (RegScore <= LB)
11321153
continue;
11331154
unsigned RelScore = RegScore - LB - 1;
@@ -2402,9 +2423,10 @@ bool WaitcntBrackets::merge(const WaitcntBrackets &Other) {
24022423
}
24032424

24042425
for (auto &[TID, Info] : Other.VMem) {
2405-
unsigned char NewVmemTypes = VMem[TID].VMEMTypes | Info.VMEMTypes;
2406-
StrictDom |= NewVmemTypes != VMem[TID].VMEMTypes;
2407-
VMem[TID].VMEMTypes = NewVmemTypes;
2426+
auto &Value = VMem[TID];
2427+
unsigned char NewVmemTypes = Value.VMEMTypes | Info.VMEMTypes;
2428+
StrictDom |= NewVmemTypes != Value.VMEMTypes;
2429+
Value.VMEMTypes = NewVmemTypes;
24082430
}
24092431

24102432
return StrictDom;

0 commit comments

Comments
 (0)