@@ -571,10 +571,10 @@ MCRegister SIRegisterInfo::reservedPrivateSegmentBufferReg(
571571
572572std::pair<unsigned , unsigned >
573573SIRegisterInfo::getMaxNumVectorRegs (const MachineFunction &MF) const {
574- const SIMachineFunctionInfo *MFI = MF. getInfo <SIMachineFunctionInfo>( );
575- unsigned MaxNumVGPRs = ST. getMaxNumVGPRs (MF);
576- unsigned MaxNumAGPRs = MaxNumVGPRs ;
577- unsigned TotalNumVGPRs = AMDGPU::VGPR_32RegClass. getNumRegs () ;
574+ const unsigned MaxVectorRegs = ST. getMaxNumVGPRs (MF );
575+
576+ unsigned MaxNumVGPRs = MaxVectorRegs ;
577+ unsigned MaxNumAGPRs = 0 ;
578578
579579 // On GFX90A, the number of VGPRs and AGPRs need not be equal. Theoretically,
580580 // a wave may have up to 512 total vector registers combining together both
@@ -585,16 +585,44 @@ SIRegisterInfo::getMaxNumVectorRegs(const MachineFunction &MF) const {
585585 // TODO: it shall be possible to estimate maximum AGPR/VGPR pressure and split
586586 // register file accordingly.
587587 if (ST.hasGFX90AInsts ()) {
588- if (MFI->mayNeedAGPRs ()) {
589- MaxNumVGPRs /= 2 ;
590- MaxNumAGPRs = MaxNumVGPRs;
588+ unsigned MinNumAGPRs = 0 ;
589+ const unsigned TotalNumAGPRs = AMDGPU::AGPR_32RegClass.getNumRegs ();
590+ const unsigned TotalNumVGPRs = AMDGPU::VGPR_32RegClass.getNumRegs ();
591+
592+ const std::pair<unsigned , unsigned > DefaultNumAGPR = {~0u , ~0u };
593+
594+ // TODO: Move this logic into subtarget on IR function
595+ //
596+ // TODO: The lower bound should probably force the number of required
597+ // registers up, overriding amdgpu-waves-per-eu.
598+ std::tie (MinNumAGPRs, MaxNumAGPRs) = AMDGPU::getIntegerPairAttribute (
599+ MF.getFunction (), " amdgpu-agpr-alloc" , DefaultNumAGPR,
600+ /* OnlyFirstRequired=*/ true );
601+
602+ if (MinNumAGPRs == DefaultNumAGPR.first ) {
603+ // Default to splitting half the registers if AGPRs are required.
604+ MinNumAGPRs = MaxNumAGPRs = MaxVectorRegs / 2 ;
591605 } else {
592- if (MaxNumVGPRs > TotalNumVGPRs) {
593- MaxNumAGPRs = MaxNumVGPRs - TotalNumVGPRs;
594- MaxNumVGPRs = TotalNumVGPRs;
595- } else
596- MaxNumAGPRs = 0 ;
606+ // Align to accum_offset's allocation granularity.
607+ MinNumAGPRs = alignTo (MinNumAGPRs, 4 );
608+
609+ MinNumAGPRs = std::min (MinNumAGPRs, TotalNumAGPRs);
597610 }
611+
612+ // Clamp values to be inbounds of our limits, and ensure min <= max.
613+
614+ MaxNumAGPRs = std::min (std::max (MinNumAGPRs, MaxNumAGPRs), MaxVectorRegs);
615+ MinNumAGPRs = std::min (std::min (MinNumAGPRs, TotalNumAGPRs), MaxNumAGPRs);
616+
617+ MaxNumVGPRs = std::min (MaxVectorRegs - MinNumAGPRs, TotalNumVGPRs);
618+ MaxNumAGPRs = std::min (MaxVectorRegs - MaxNumVGPRs, MaxNumAGPRs);
619+
620+ assert (MaxNumVGPRs + MaxNumAGPRs <= MaxVectorRegs &&
621+ MaxNumAGPRs <= TotalNumAGPRs && MaxNumVGPRs <= TotalNumVGPRs &&
622+ " invalid register counts" );
623+ } else if (ST.hasMAIInsts ()) {
624+ // On gfx908 the number of AGPRs always equals the number of VGPRs.
625+ MaxNumAGPRs = MaxNumVGPRs = MaxVectorRegs;
598626 }
599627
600628 return std::pair (MaxNumVGPRs, MaxNumAGPRs);
0 commit comments