@@ -100,8 +100,15 @@ auto inst_counter_types(InstCounterType MaxCounter = NUM_INST_CNTS) {
100100// / Integer IDs used to track vector memory locations we may have to wait on.
101101// / Encoded as u16 chunks:
102102// /
103- // / [0, MAX_REGUNITS ): MCRegUnit
104- // / [FIRST_LDSDMA, LAST_LDSDMA ): LDS DMA IDs
103+ // / [0, REGUNITS_END ): MCRegUnit
104+ // / [LDSDMA_BEGIN, LDSDMA_END ) : LDS DMA IDs
105+ // /
106+ // / NOTE: The choice of encoding these as "u16 chunks" is arbitrary.
107+ // / It gives (1 << 16) entries per category, which is more than enough
108+ // / for all register units. MCPhysReg is u16 so we don't even support >u16
109+ // / physical register numbers at this time, let alone >u16 register units.
110+ // / In any case, an assertion in "WaitcntBrackets" ensures REGUNITS_END
111+ // / is enough for all register units.
105112using VMEMID = uint32_t ;
106113
107114enum : VMEMID {
@@ -593,7 +600,12 @@ class SIInsertWaitcnts {
593600// "s_waitcnt 0" before use.
594601class WaitcntBrackets {
595602public:
596- WaitcntBrackets (const SIInsertWaitcnts *Context) : Context(Context) {}
603+ WaitcntBrackets (const SIInsertWaitcnts *Context) : Context(Context) {
604+ static_assert (REGUNITS_BEGIN == 0 ,
605+ " REGUNITS_BEGIN must be zero; tracking depends on being able "
606+ " to convert a register unit ID to a VMEMID directly!" );
607+ assert (Context->TRI ->getNumRegUnits () < REGUNITS_END);
608+ }
597609
598610 bool isSmemCounter (InstCounterType T) const {
599611 return T == Context->SmemAccessCounter || T == X_CNT;
@@ -737,10 +749,10 @@ class WaitcntBrackets {
737749
738750 iterator_range<MCRegUnitIterator> regunits (MCPhysReg Reg) const {
739751 assert (Reg != AMDGPU::SCC && " Shouldn't be used on SCC" );
740- const TargetRegisterClass *RC = Context->TRI ->getPhysRegBaseClass (Reg);
741- unsigned Size = Context->TRI ->getRegSizeInBits (*RC);
742752 if (!Context->TRI ->isInAllocatableClass (Reg))
743753 return {{}, {}};
754+ const TargetRegisterClass *RC = Context->TRI ->getPhysRegBaseClass (Reg);
755+ unsigned Size = Context->TRI ->getRegSizeInBits (*RC);
744756 if (Size == 16 && Context->ST ->hasD16Writes32BitVgpr ())
745757 Reg = Context->TRI ->get32BitRegister (Reg);
746758 return Context->TRI ->regunits (Reg);
@@ -801,6 +813,10 @@ class WaitcntBrackets {
801813 // For the VMem case, if the key is within the range of LDS DMA IDs,
802814 // then the corresponding index into the `LDSDMAStores` vector below is:
803815 // Key - LDSDMA_BEGIN - 1
816+ // This is because LDSDMA_BEGIN is a generic entry and does not have an
817+ // associated MachineInstr.
818+ //
819+ // TODO: Could we track SCC alongside SGPRs so it's no longer a special case?
804820
805821 struct VGPRInfo {
806822 // Scores for all instruction counters.
@@ -827,7 +843,6 @@ class WaitcntBrackets {
827843
828844 // Store representative LDS DMA operations. The only useful info here is
829845 // alias info. One store is kept per unique AAInfo.
830- // Entry zero is the "generic" entry that applies to all LDSDMA stores.
831846 SmallVector<const MachineInstr *> LDSDMAStores;
832847};
833848
@@ -856,7 +871,6 @@ class SIInsertWaitcntsLegacy : public MachineFunctionPass {
856871
857872void WaitcntBrackets::setScoreByOperand (const MachineOperand &Op,
858873 InstCounterType CntTy, unsigned Score) {
859- assert (Op.isReg ());
860874 setRegScore (Op.getReg ().asMCReg (), CntTy, Score);
861875}
862876
@@ -1023,6 +1037,10 @@ void WaitcntBrackets::updateByEvent(WaitEventType E, MachineInstr &Inst) {
10231037 (TII->isDS (Inst) || TII->mayWriteLDSThroughDMA (Inst))) {
10241038 // MUBUF and FLAT LDS DMA operations need a wait on vmcnt before LDS
10251039 // written can be accessed. A load from LDS to VMEM does not need a wait.
1040+ //
1041+ // The "Slot" is the offset from LDSDMA_BEGIN. If it's non-zero, then
1042+ // there is a MachineInstr in LDSDMAStores used to track this LDSDMA
1043+ // store. The "Slot" is that entry's index in LDSDMAStores, plus one.
10261044 unsigned Slot = 0 ;
10271045 for (const auto *MemOp : Inst.memoperands ()) {
10281046 if (!MemOp->isStore () ||
@@ -1035,9 +1053,7 @@ void WaitcntBrackets::updateByEvent(WaitEventType E, MachineInstr &Inst) {
10351053 // original memory object and practically produced in the module LDS
10361054 // lowering pass. If there is no scope available we will not be able
10371055 // to disambiguate LDS aliasing as after the module lowering all LDS
1038- // is squashed into a single big object. Do not attempt to use one of
1039- // the limited LDSDMAStores for something we will not be able to use
1040- // anyway.
1056+ // is squashed into a single big object.
10411057 if (!AAI || !AAI.Scope )
10421058 break ;
10431059 for (unsigned I = 0 , E = LDSDMAStores.size (); I != E && !Slot; ++I) {
@@ -1050,8 +1066,8 @@ void WaitcntBrackets::updateByEvent(WaitEventType E, MachineInstr &Inst) {
10501066 }
10511067 if (Slot)
10521068 break ;
1069+ Slot = LDSDMAStores.size () + 1 ;
10531070 LDSDMAStores.push_back (&Inst);
1054- Slot = LDSDMAStores.size ();
10551071 break ;
10561072 }
10571073 setVMemScore (LDSDMA_BEGIN, T, CurrScore);
@@ -1110,8 +1126,11 @@ void WaitcntBrackets::print(raw_ostream &OS) const {
11101126 // Print vgpr scores.
11111127 unsigned LB = getScoreLB (T);
11121128
1113- for (auto &[ID, Info] : VMem) {
1114- unsigned RegScore = Info.Scores [T];
1129+ SmallVector<VMEMID> SortedVMEMIDs (VMem.keys ());
1130+ sort (SortedVMEMIDs);
1131+
1132+ for (auto ID : SortedVMEMIDs) {
1133+ unsigned RegScore = VMem.at (ID).Scores [T];
11151134 if (RegScore <= LB)
11161135 continue ;
11171136 unsigned RelScore = RegScore - LB - 1 ;
@@ -1126,8 +1145,10 @@ void WaitcntBrackets::print(raw_ostream &OS) const {
11261145
11271146 // Also need to print sgpr scores for lgkm_cnt or xcnt.
11281147 if (isSmemCounter (T)) {
1129- for (auto &[ID, Info] : SGPRs) {
1130- unsigned RegScore = Info.Scores [getSgprScoresIdx (T)];
1148+ SmallVector<MCRegUnit> SortedSMEMIDs (SGPRs.keys ());
1149+ sort (SortedSMEMIDs);
1150+ for (auto ID : SortedSMEMIDs) {
1151+ unsigned RegScore = SGPRs.at (ID).Scores [getSgprScoresIdx (T)];
11311152 if (RegScore <= LB)
11321153 continue ;
11331154 unsigned RelScore = RegScore - LB - 1 ;
@@ -2402,9 +2423,10 @@ bool WaitcntBrackets::merge(const WaitcntBrackets &Other) {
24022423 }
24032424
24042425 for (auto &[TID, Info] : Other.VMem ) {
2405- unsigned char NewVmemTypes = VMem[TID].VMEMTypes | Info.VMEMTypes ;
2406- StrictDom |= NewVmemTypes != VMem[TID].VMEMTypes ;
2407- VMem[TID].VMEMTypes = NewVmemTypes;
2426+ auto &Value = VMem[TID];
2427+ unsigned char NewVmemTypes = Value.VMEMTypes | Info.VMEMTypes ;
2428+ StrictDom |= NewVmemTypes != Value.VMEMTypes ;
2429+ Value.VMEMTypes = NewVmemTypes;
24082430 }
24092431
24102432 return StrictDom;
0 commit comments