@@ -100,8 +100,15 @@ auto inst_counter_types(InstCounterType MaxCounter = NUM_INST_CNTS) {
100100// / Integer IDs used to track vector memory locations we may have to wait on.
101101// / Encoded as u16 chunks:
102102// /
103- // / [0, MAX_REGUNITS ): MCRegUnit
104- // / [FIRST_LDSDMA, LAST_LDSDMA ): LDS DMA IDs
103+ // / [0, REGUNITS_END ): MCRegUnit
104+ // / [LDSDMA_BEGIN, LDSDMA_END ) : LDS DMA IDs
105+ // /
106+ // / NOTE: The choice of encoding these as "u16 chunks" is arbitrary.
107+ // / It gives (2 << 16) - 1 entries per category which is more than enough
108+ // / for all register units. MCPhysReg is u16 so we don't even support >u16
109+ // / physical register numbers at this time, let alone >u16 register units.
110+ // / In any case, an assertion in "WaitcntBrackets" ensures REGUNITS_END
111+ // / is enough for all register units.
105112using VMEMID = uint32_t ;
106113
107114enum : VMEMID {
@@ -586,7 +593,12 @@ class SIInsertWaitcnts {
586593// "s_waitcnt 0" before use.
587594class WaitcntBrackets {
588595public:
589- WaitcntBrackets (const SIInsertWaitcnts *Context) : Context(Context) {}
596+ WaitcntBrackets (const SIInsertWaitcnts *Context) : Context(Context) {
597+ static_assert (REGUNITS_BEGIN == 0 ,
598+ " REGUNITS_BEGIN must be zero; tracking depends on being able "
599+ " to convert a register unit ID to a VMEMID directly!" );
600+ assert (Context->TRI ->getNumRegUnits () < REGUNITS_END);
601+ }
590602
591603 bool isSmemCounter (InstCounterType T) const {
592604 return T == Context->SmemAccessCounter || T == X_CNT;
@@ -730,10 +742,10 @@ class WaitcntBrackets {
730742
731743 iterator_range<MCRegUnitIterator> regunits (MCPhysReg Reg) const {
732744 assert (Reg != AMDGPU::SCC && " Shouldn't be used on SCC" );
733- const TargetRegisterClass *RC = Context->TRI ->getPhysRegBaseClass (Reg);
734- unsigned Size = Context->TRI ->getRegSizeInBits (*RC);
735745 if (!Context->TRI ->isInAllocatableClass (Reg))
736746 return {{}, {}};
747+ const TargetRegisterClass *RC = Context->TRI ->getPhysRegBaseClass (Reg);
748+ unsigned Size = Context->TRI ->getRegSizeInBits (*RC);
737749 if (Size == 16 && Context->ST ->hasD16Writes32BitVgpr ())
738750 Reg = Context->TRI ->get32BitRegister (Reg);
739751 return Context->TRI ->regunits (Reg);
@@ -794,6 +806,10 @@ class WaitcntBrackets {
794806 // For the VMem case, if the key is within the range of LDS DMA IDs,
795807 // then the corresponding index into the `LDSDMAStores` vector below is:
796808 // Key - LDSDMA_BEGIN - 1
809+ // This is because LDSDMA_BEGIN is a generic entry and does not have an
810+ // associated MachineInstr.
811+ //
812+ // TODO: Could we track SCC alongside SGPRs so it's not longer a special case?
797813
798814 struct VGPRInfo {
799815 // Scores for all instruction counters.
@@ -820,7 +836,6 @@ class WaitcntBrackets {
820836
821837 // Store representative LDS DMA operations. The only useful info here is
822838 // alias info. One store is kept per unique AAInfo.
823- // Entry zero is the "generic" entry that applies to all LDSDMA stores.
824839 SmallVector<const MachineInstr *> LDSDMAStores;
825840};
826841
@@ -849,7 +864,6 @@ class SIInsertWaitcntsLegacy : public MachineFunctionPass {
849864
850865void WaitcntBrackets::setScoreByOperand (const MachineOperand &Op,
851866 InstCounterType CntTy, unsigned Score) {
852- assert (Op.isReg ());
853867 setRegScore (Op.getReg ().asMCReg (), CntTy, Score);
854868}
855869
@@ -1017,6 +1031,10 @@ void WaitcntBrackets::updateByEvent(WaitEventType E, MachineInstr &Inst) {
10171031 (TII->isDS (Inst) || TII->mayWriteLDSThroughDMA (Inst))) {
10181032 // MUBUF and FLAT LDS DMA operations need a wait on vmcnt before LDS
10191033 // written can be accessed. A load from LDS to VMEM does not need a wait.
1034+ //
1035+ // The "Slot" is the offset from LDSDMA_BEGIN. If it's non-zero, then
1036+ // there is a MachineInstr in LDSDMAStores used to track this LDSDMA
1037+ // store. The "Slot" is the index into LDSDMAStores + 1.
10201038 unsigned Slot = 0 ;
10211039 for (const auto *MemOp : Inst.memoperands ()) {
10221040 if (!MemOp->isStore () ||
@@ -1029,9 +1047,7 @@ void WaitcntBrackets::updateByEvent(WaitEventType E, MachineInstr &Inst) {
10291047 // original memory object and practically produced in the module LDS
10301048 // lowering pass. If there is no scope available we will not be able
10311049 // to disambiguate LDS aliasing as after the module lowering all LDS
1032- // is squashed into a single big object. Do not attempt to use one of
1033- // the limited LDSDMAStores for something we will not be able to use
1034- // anyway.
1050+ // is squashed into a single big object.
10351051 if (!AAI || !AAI.Scope )
10361052 break ;
10371053 for (unsigned I = 0 , E = LDSDMAStores.size (); I != E && !Slot; ++I) {
@@ -1044,15 +1060,14 @@ void WaitcntBrackets::updateByEvent(WaitEventType E, MachineInstr &Inst) {
10441060 }
10451061 if (Slot)
10461062 break ;
1047- // The slot may not be valid because it can be >= NUM_LDS_VGPRS which
1063+ // The slot may not be valid because it can be >= NUM_LDSDMA which
10481064 // means the scoreboard cannot track it. We still want to preserve the
10491065 // MI in order to check alias information, though.
10501066 LDSDMAStores.push_back (&Inst);
1051- Slot = LDSDMAStores.size ();
10521067 break ;
10531068 }
10541069 setVMemScore (LDSDMA_BEGIN, T, CurrScore);
1055- if (Slot && (LDSDMA_BEGIN + Slot) < LDSDMA_END )
1070+ if (Slot && Slot < NUM_LDSDMA )
10561071 setVMemScore (LDSDMA_BEGIN + Slot, T, CurrScore);
10571072 }
10581073
@@ -1107,8 +1122,11 @@ void WaitcntBrackets::print(raw_ostream &OS) const {
11071122 // Print vgpr scores.
11081123 unsigned LB = getScoreLB (T);
11091124
1110- for (auto &[ID, Info] : VMem) {
1111- unsigned RegScore = Info.Scores [T];
1125+ SmallVector<VMEMID> SortedVMEMIDs (VMem.keys ());
1126+ sort (SortedVMEMIDs);
1127+
1128+ for (auto ID : SortedVMEMIDs) {
1129+ unsigned RegScore = VMem.at (ID).Scores [T];
11121130 if (RegScore <= LB)
11131131 continue ;
11141132 unsigned RelScore = RegScore - LB - 1 ;
@@ -1123,8 +1141,10 @@ void WaitcntBrackets::print(raw_ostream &OS) const {
11231141
11241142 // Also need to print sgpr scores for lgkm_cnt or xcnt.
11251143 if (isSmemCounter (T)) {
1126- for (auto &[ID, Info] : SGPRs) {
1127- unsigned RegScore = Info.Scores [getSgprScoresIdx (T)];
1144+ SmallVector<MCRegUnit> SortedSMEMIDs (SGPRs.keys ());
1145+ sort (SortedSMEMIDs);
1146+ for (auto ID : SortedSMEMIDs) {
1147+ unsigned RegScore = SGPRs.at (ID).Scores [getSgprScoresIdx (T)];
11281148 if (RegScore <= LB)
11291149 continue ;
11301150 unsigned RelScore = RegScore - LB - 1 ;
@@ -2410,9 +2430,10 @@ bool WaitcntBrackets::merge(const WaitcntBrackets &Other) {
24102430 }
24112431
24122432 for (auto &[TID, Info] : Other.VMem ) {
2413- unsigned char NewVmemTypes = VMem[TID].VMEMTypes | Info.VMEMTypes ;
2414- StrictDom |= NewVmemTypes != VMem[TID].VMEMTypes ;
2415- VMem[TID].VMEMTypes = NewVmemTypes;
2433+ auto &Value = VMem[TID];
2434+ unsigned char NewVmemTypes = Value.VMEMTypes | Info.VMEMTypes ;
2435+ StrictDom |= NewVmemTypes != Value.VMEMTypes ;
2436+ Value.VMEMTypes = NewVmemTypes;
24162437 }
24172438
24182439 return StrictDom;
0 commit comments