@@ -1113,47 +1113,25 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
11131113 Function *F = getAssociatedFunction ();
11141114 auto &InfoCache = static_cast <AMDGPUInformationCache &>(A.getInfoCache ());
11151115
1116- auto TakeRange = [&](std::pair<unsigned , unsigned > R) {
1117- auto [Min, Max] = R;
1118- ConstantRange Range (APInt (32 , Min), APInt (32 , Max + 1 ));
1119- IntegerRangeState RangeState (Range);
1120- clampStateAndIndicateChange (this ->getState (), RangeState);
1121- indicateOptimisticFixpoint ();
1122- };
1123-
1124- std::pair<unsigned , unsigned > MaxWavesPerEURange{
1125- 1U , InfoCache.getMaxWavesPerEU (*F)};
1126-
11271116 // If the attribute exists, we will honor it if it is not the default.
11281117 if (auto Attr = InfoCache.getWavesPerEUAttr (*F)) {
1118+ std::pair<unsigned , unsigned > MaxWavesPerEURange{
1119+ 1U , InfoCache.getMaxWavesPerEU (*F)};
11291120 if (*Attr != MaxWavesPerEURange) {
1130- TakeRange (*Attr);
1121+ auto [Min, Max] = *Attr;
1122+ ConstantRange Range (APInt (32 , Min), APInt (32 , Max + 1 ));
1123+ IntegerRangeState RangeState (Range);
1124+ this ->getState () = RangeState;
1125+ indicateOptimisticFixpoint ();
11311126 return ;
11321127 }
11331128 }
11341129
1135- // Unlike AAAMDFlatWorkGroupSize, it's getting trickier here. Since the
1136- // calculation of waves per EU involves flat work group size, we can't
1137- // simply use an assumed flat work group size as a start point, because the
1138- // update of flat work group size is in an inverse direction of waves per
1139- // EU. However, we can still do something if it is an entry function. Since
1140- // an entry function is a terminal node, and flat work group size either
1141- // from attribute or default will be used anyway, we can take that value and
1142- // calculate the waves per EU based on it. This result can't be updated by
1143- // no means, but that could still allow us to propagate it.
1144- if (AMDGPU::isEntryFunctionCC (F->getCallingConv ())) {
1145- std::pair<unsigned , unsigned > FlatWorkGroupSize;
1146- if (auto Attr = InfoCache.getFlatWorkGroupSizeAttr (*F))
1147- FlatWorkGroupSize = *Attr;
1148- else
1149- FlatWorkGroupSize = InfoCache.getDefaultFlatWorkGroupSize (*F);
1150- TakeRange (InfoCache.getEffectiveWavesPerEU (*F, MaxWavesPerEURange,
1151- FlatWorkGroupSize));
1152- }
1130+ if (AMDGPU::isEntryFunctionCC (F->getCallingConv ()))
1131+ indicatePessimisticFixpoint ();
11531132 }
11541133
11551134 ChangeStatus updateImpl (Attributor &A) override {
1156- auto &InfoCache = static_cast <AMDGPUInformationCache &>(A.getInfoCache ());
11571135 ChangeStatus Change = ChangeStatus::UNCHANGED;
11581136
11591137 auto CheckCallSite = [&](AbstractCallSite CS) {
@@ -1162,24 +1140,21 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
11621140 LLVM_DEBUG (dbgs () << ' [' << getName () << " ] Call " << Caller->getName ()
11631141 << " ->" << Func->getName () << ' \n ' );
11641142
1165- const auto *CallerInfo = A.getAAFor <AAAMDWavesPerEU>(
1143+ const auto *CallerAA = A.getAAFor <AAAMDWavesPerEU>(
11661144 *this , IRPosition::function (*Caller), DepClassTy::REQUIRED);
1167- const auto *AssumedGroupSize = A.getAAFor <AAAMDFlatWorkGroupSize>(
1168- *this , IRPosition::function (*Func), DepClassTy::REQUIRED);
1169- if (!CallerInfo || !AssumedGroupSize || !CallerInfo->isValidState () ||
1170- !AssumedGroupSize->isValidState ())
1145+ if (!CallerAA || !CallerAA->isValidState ())
11711146 return false ;
11721147
1173- unsigned Min, Max ;
1174- std::tie (Min, Max) = InfoCache. getEffectiveWavesPerEU (
1175- *Caller,
1176- {CallerInfo-> getAssumed (). getLower ().getZExtValue (),
1177- CallerInfo ->getAssumed ().getUpper ().getZExtValue () - 1 },
1178- {AssumedGroupSize-> getAssumed (). getLower (). getZExtValue (),
1179- AssumedGroupSize-> getAssumed (). getUpper (). getZExtValue () - 1 } );
1180- ConstantRange CallerRange ( APInt ( 32 , Min), APInt ( 32 , Max + 1 )) ;
1181- IntegerRangeState CallerRangeState (CallerRange);
1182- Change |= clampStateAndIndicateChange ( this -> getState (), CallerRangeState) ;
1148+ ConstantRange Assumed = getAssumed () ;
1149+ unsigned Min = std::max (Assumed. getLower (). getZExtValue (),
1150+ CallerAA-> getAssumed (). getLower (). getZExtValue ());
1151+ unsigned Max = std::max (Assumed. getUpper ().getZExtValue (),
1152+ CallerAA ->getAssumed ().getUpper ().getZExtValue ());
1153+ ConstantRange Range ( APInt ( 32 , Min), APInt ( 32 , Max));
1154+ IntegerRangeState RangeState (Range );
1155+ getState () = RangeState ;
1156+ Change |= getState () == Assumed ? ChangeStatus::UNCHANGED
1157+ : ChangeStatus::CHANGED ;
11831158
11841159 return true ;
11851160 };
@@ -1323,6 +1298,73 @@ struct AAAMDGPUNoAGPR
13231298
13241299const char AAAMDGPUNoAGPR::ID = 0 ;
13251300
1301+ // / Performs the final check and updates the 'amdgpu-waves-per-eu' attribute
1302+ // / based on the finalized 'amdgpu-flat-work-group-size' attribute.
1303+ // / Both attributes start with narrow ranges that expand during iteration.
1304+ // / However, a narrower flat-workgroup-size leads to a wider waves-per-eu range,
1305+ // / preventing optimal updates later. Therefore, waves-per-eu can't be updated
1306+ // / with intermediate values during the attributor run. We defer the calculation
1307+ // / of waves-per-eu until after the flat-workgroup-size is finalized.
1308+ // / TODO: Remove this and move similar logic back into the attributor run once
1309+ // / we have a better representation for waves-per-eu.
1310+ static bool updateWavesPerEU (Module &M, TargetMachine &TM) {
1311+ bool Changed = false ;
1312+
1313+ LLVMContext &Ctx = M.getContext ();
1314+
1315+ for (Function &F : M) {
1316+ if (F.isDeclaration ())
1317+ continue ;
1318+
1319+ const GCNSubtarget &ST = TM.getSubtarget <GCNSubtarget>(F);
1320+
1321+ std::optional<std::pair<unsigned , std::optional<unsigned >>>
1322+ FlatWgrpSizeAttr =
1323+ AMDGPU::getIntegerPairAttribute (F, " amdgpu-flat-work-group-size" );
1324+
1325+ unsigned MinWavesPerEU = ST.getMinWavesPerEU ();
1326+ unsigned MaxWavesPerEU = ST.getMaxWavesPerEU ();
1327+
1328+ unsigned MinFlatWgrpSize = ST.getMinFlatWorkGroupSize ();
1329+ unsigned MaxFlatWgrpSize = ST.getMaxFlatWorkGroupSize ();
1330+ if (FlatWgrpSizeAttr.has_value ()) {
1331+ MinFlatWgrpSize = FlatWgrpSizeAttr->first ;
1332+ MaxFlatWgrpSize = *(FlatWgrpSizeAttr->second );
1333+ }
1334+
1335+ // Start with the "best" range.
1336+ unsigned Min = MinWavesPerEU;
1337+ unsigned Max = MinWavesPerEU;
1338+
1339+ // Compute the range from flat workgroup size. `getWavesPerEU` will also
1340+ // account for the 'amdgpu-waves-er-eu' attribute.
1341+ auto [MinFromFlatWgrpSize, MaxFromFlatWgrpSize] =
1342+ ST.getWavesPerEU (F, {MinFlatWgrpSize, MaxFlatWgrpSize});
1343+
1344+ // For the lower bound, we have to "tighten" it.
1345+ Min = std::max (Min, MinFromFlatWgrpSize);
1346+ // For the upper bound, we have to "extend" it.
1347+ Max = std::max (Max, MaxFromFlatWgrpSize);
1348+
1349+ // Clamp the range to the max range.
1350+ Min = std::max (Min, MinWavesPerEU);
1351+ Max = std::min (Max, MaxWavesPerEU);
1352+
1353+ // Update the attribute if it is not the max.
1354+ if (Min != MinWavesPerEU || Max != MaxWavesPerEU) {
1355+ SmallString<10 > Buffer;
1356+ raw_svector_ostream OS (Buffer);
1357+ OS << Min << ' ,' << Max;
1358+ Attribute OldAttr = F.getFnAttribute (" amdgpu-waves-per-eu" );
1359+ Attribute NewAttr = Attribute::get (Ctx, " amdgpu-waves-per-eu" , OS.str ());
1360+ F.addFnAttr (NewAttr);
1361+ Changed |= OldAttr == NewAttr;
1362+ }
1363+ }
1364+
1365+ return Changed;
1366+ }
1367+
13261368static bool runImpl (Module &M, AnalysisGetter &AG, TargetMachine &TM,
13271369 AMDGPUAttributorOptions Options,
13281370 ThinOrFullLTOPhase LTOPhase) {
@@ -1396,8 +1438,11 @@ static bool runImpl(Module &M, AnalysisGetter &AG, TargetMachine &TM,
13961438 }
13971439 }
13981440
1399- ChangeStatus Change = A.run ();
1400- return Change == ChangeStatus::CHANGED;
1441+ bool Changed = A.run () == ChangeStatus::CHANGED;
1442+
1443+ Changed |= updateWavesPerEU (M, TM);
1444+
1445+ return Changed;
14011446}
14021447
14031448class AMDGPUAttributorLegacy : public ModulePass {
0 commit comments