@@ -57,7 +57,6 @@ AMDGPUSubtarget::getMaxLocalMemSizeWithWaveCount(unsigned NWaves,
5757
5858std::pair<unsigned , unsigned > AMDGPUSubtarget::getOccupancyWithWorkGroupSizes (
5959 uint32_t LDSBytes, std::pair<unsigned , unsigned > FlatWorkGroupSizes) const {
60-
6160 // FIXME: We should take into account the LDS allocation granularity.
6261 const unsigned MaxWGsLDS = getLocalMemorySize () / std::max (LDSBytes, 1u );
6362
@@ -133,6 +132,58 @@ std::pair<unsigned, unsigned> AMDGPUSubtarget::getOccupancyWithWorkGroupSizes(
133132 // wavefronts are spread across all EUs as evenly as possible.
134133 return {std::clamp (MinWavesPerCU / getEUsPerCU (), 1U , WavesPerEU),
135134 std::clamp (divideCeil (MaxWavesPerCU, getEUsPerCU ()), 1U , WavesPerEU)};
135+ }
136+
137+ // FIXME: Should return min,max range.
138+ //
139+ // Returns the maximum occupancy, in number of waves per SIMD / EU, that can
140+ // be achieved when only the given function is running on the machine; and
141+ // taking into account the overall number of wave slots, the (maximum) workgroup
142+ // size, and the per-workgroup LDS allocation size.
143+ unsigned
144+ AMDGPUSubtarget::getOccupancyWithLocalMemSize (uint32_t Bytes,
145+ const Function &F) const {
146+ const unsigned MaxWorkGroupSize = getFlatWorkGroupSizes (F).second ;
147+ const unsigned MaxWorkGroupsPerCu = getMaxWorkGroupsPerCU (MaxWorkGroupSize);
148+ if (!MaxWorkGroupsPerCu)
149+ return 0 ;
150+
151+ const unsigned WaveSize = getWavefrontSize ();
152+
153+ // FIXME: Do we need to account for alignment requirement of LDS rounding the
154+ // size up?
155+ // Compute restriction based on LDS usage
156+ unsigned NumGroups = getLocalMemorySize () / (Bytes ? Bytes : 1u );
157+
158+ // This can be queried with more LDS than is possible, so just assume the
159+ // worst.
160+ if (NumGroups == 0 )
161+ return 1 ;
162+
163+ NumGroups = std::min (MaxWorkGroupsPerCu, NumGroups);
164+
165+ // Round to the number of waves per CU.
166+ const unsigned MaxGroupNumWaves = divideCeil (MaxWorkGroupSize, WaveSize);
167+ unsigned MaxWaves = NumGroups * MaxGroupNumWaves;
168+
169+ // Number of waves per EU (SIMD).
170+ MaxWaves = divideCeil (MaxWaves, getEUsPerCU ());
171+
172+ // Clamp to the maximum possible number of waves.
173+ MaxWaves = std::min (MaxWaves, getMaxWavesPerEU ());
174+
175+ // FIXME: Needs to be a multiple of the group size?
176+ // MaxWaves = MaxGroupNumWaves * (MaxWaves / MaxGroupNumWaves);
177+
178+ assert (MaxWaves > 0 && MaxWaves <= getMaxWavesPerEU () &&
179+ " computed invalid occupancy" );
180+ return MaxWaves;
181+ }
182+
183+ unsigned
184+ AMDGPUSubtarget::getOccupancyWithLocalMemSize (const MachineFunction &MF) const {
185+ const auto *MFI = MF.getInfo <SIMachineFunctionInfo>();
186+ return getOccupancyWithLocalMemSize (MFI->getLDSSize (), MF.getFunction ());
136187}
137188
138189std::pair<unsigned , unsigned > AMDGPUSubtarget::getOccupancyWithWorkGroupSizes (
0 commit comments