@@ -572,9 +572,10 @@ MCRegister SIRegisterInfo::reservedPrivateSegmentBufferReg(
572572std::pair<unsigned , unsigned >
573573SIRegisterInfo::getMaxNumVectorRegs (const MachineFunction &MF) const {
574574 const SIMachineFunctionInfo *MFI = MF.getInfo <SIMachineFunctionInfo>();
575- unsigned MaxNumVGPRs = ST.getMaxNumVGPRs (MF);
576- unsigned MaxNumAGPRs = MaxNumVGPRs;
577- unsigned TotalNumVGPRs = AMDGPU::VGPR_32RegClass.getNumRegs ();
575+ const unsigned MaxVectorRegs = ST.getMaxNumVGPRs (MF);
576+
577+ unsigned MaxNumVGPRs = MaxVectorRegs;
578+ unsigned MaxNumAGPRs = 0 ;
578579
579580 // On GFX90A, the number of VGPRs and AGPRs need not be equal. Theoretically,
580581 // a wave may have up to 512 total vector registers combining together both
@@ -585,16 +586,49 @@ SIRegisterInfo::getMaxNumVectorRegs(const MachineFunction &MF) const {
585586 // TODO: it shall be possible to estimate maximum AGPR/VGPR pressure and split
586587 // register file accordingly.
587588 if (ST.hasGFX90AInsts ()) {
588- if (MFI->mayNeedAGPRs ()) {
589- MaxNumVGPRs /= 2 ;
590- MaxNumAGPRs = MaxNumVGPRs;
589+ unsigned MinNumAGPRs = 0 ;
590+ const unsigned TotalNumAGPRs = AMDGPU::AGPR_32RegClass.getNumRegs ();
591+ const unsigned TotalNumVGPRs = AMDGPU::VGPR_32RegClass.getNumRegs ();
592+
593+ const std::pair<unsigned , unsigned > DefaultNumAGPR = {~0u , ~0u };
594+
595+ // TODO: Replace amdgpu-no-agpr with amdgpu-agpr-alloc=0
596+ // TODO: Move this logic into subtarget on IR function
597+ //
598+ // TODO: The lower bound should probably force the number of required
599+ // registers up, overriding amdgpu-waves-per-eu.
600+ std::tie (MinNumAGPRs, MaxNumAGPRs) = AMDGPU::getIntegerPairAttribute (
601+ MF.getFunction (), " amdgpu-agpr-alloc" , DefaultNumAGPR,
602+ /* OnlyFirstRequired=*/ true );
603+
604+ if (MinNumAGPRs == DefaultNumAGPR.first ) {
605+ // Default to splitting half the registers if AGPRs are required.
606+
607+ if (MFI->mayNeedAGPRs ())
608+ MinNumAGPRs = MaxNumAGPRs = MaxVectorRegs / 2 ;
609+ else
610+ MinNumAGPRs = 0 ;
591611 } else {
592- if (MaxNumVGPRs > TotalNumVGPRs) {
593- MaxNumAGPRs = MaxNumVGPRs - TotalNumVGPRs;
594- MaxNumVGPRs = TotalNumVGPRs;
595- } else
596- MaxNumAGPRs = 0 ;
612+ // Align to accum_offset's allocation granularity.
613+ MinNumAGPRs = alignTo (MinNumAGPRs, 4 );
614+
615+ MinNumAGPRs = std::min (MinNumAGPRs, TotalNumAGPRs);
597616 }
617+
618+ // Clamp values to be inbounds of our limits, and ensure min <= max.
619+
620+ MaxNumAGPRs = std::min (std::max (MinNumAGPRs, MaxNumAGPRs), MaxVectorRegs);
621+ MinNumAGPRs = std::min (std::min (MinNumAGPRs, TotalNumAGPRs), MaxNumAGPRs);
622+
623+ MaxNumVGPRs = std::min (MaxVectorRegs - MinNumAGPRs, TotalNumVGPRs);
624+ MaxNumAGPRs = std::min (MaxVectorRegs - MaxNumVGPRs, MaxNumAGPRs);
625+
626+ assert (MaxNumVGPRs + MaxNumAGPRs <= MaxVectorRegs &&
627+ MaxNumAGPRs <= TotalNumAGPRs && MaxNumVGPRs <= TotalNumVGPRs &&
628+ " invalid register counts" );
629+ } else if (ST.hasMAIInsts ()) {
630+ // On gfx908 the number of AGPRs always equals the number of VGPRs.
631+ MaxNumAGPRs = MaxNumVGPRs = MaxVectorRegs;
598632 }
599633
600634 return std::pair (MaxNumVGPRs, MaxNumAGPRs);
0 commit comments