@@ -1117,47 +1117,25 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
11171117    Function *F = getAssociatedFunction ();
11181118    auto  &InfoCache = static_cast <AMDGPUInformationCache &>(A.getInfoCache ());
11191119
1120-     auto  TakeRange = [&](std::pair<unsigned , unsigned > R) {
1121-       auto  [Min, Max] = R;
1122-       ConstantRange Range (APInt (32 , Min), APInt (32 , Max + 1 ));
1123-       IntegerRangeState RangeState (Range);
1124-       clampStateAndIndicateChange (this ->getState (), RangeState);
1125-       indicateOptimisticFixpoint ();
1126-     };
1127- 
1128-     std::pair<unsigned , unsigned > MaxWavesPerEURange{
1129-         1U , InfoCache.getMaxWavesPerEU (*F)};
1130- 
11311120    //  If the attribute exists, we will honor it if it is not the default.
11321121    if  (auto  Attr = InfoCache.getWavesPerEUAttr (*F)) {
1122+       std::pair<unsigned , unsigned > MaxWavesPerEURange{
1123+           1U , InfoCache.getMaxWavesPerEU (*F)};
11331124      if  (*Attr != MaxWavesPerEURange) {
1134-         TakeRange (*Attr);
1125+         auto  [Min, Max] = *Attr;
1126+         ConstantRange Range (APInt (32 , Min), APInt (32 , Max + 1 ));
1127+         IntegerRangeState RangeState (Range);
1128+         this ->getState () = RangeState;
1129+         indicateOptimisticFixpoint ();
11351130        return ;
11361131      }
11371132    }
11381133
1139-     //  Unlike AAAMDFlatWorkGroupSize, it's getting trickier here. Since the
1140-     //  calculation of waves per EU involves flat work group size, we can't
1141-     //  simply use an assumed flat work group size as a start point, because the
1142-     //  update of flat work group size is in an inverse direction of waves per
1143-     //  EU. However, we can still do something if it is an entry function. Since
1144-     //  an entry function is a terminal node, and flat work group size either
1145-     //  from attribute or default will be used anyway, we can take that value and
1146-     //  calculate the waves per EU based on it. This result can't be updated by
1147-     //  no means, but that could still allow us to propagate it.
1148-     if  (AMDGPU::isEntryFunctionCC (F->getCallingConv ())) {
1149-       std::pair<unsigned , unsigned > FlatWorkGroupSize;
1150-       if  (auto  Attr = InfoCache.getFlatWorkGroupSizeAttr (*F))
1151-         FlatWorkGroupSize = *Attr;
1152-       else 
1153-         FlatWorkGroupSize = InfoCache.getDefaultFlatWorkGroupSize (*F);
1154-       TakeRange (InfoCache.getEffectiveWavesPerEU (*F, MaxWavesPerEURange,
1155-                                                  FlatWorkGroupSize));
1156-     }
1134+     if  (AMDGPU::isEntryFunctionCC (F->getCallingConv ()))
1135+       indicatePessimisticFixpoint ();
11571136  }
11581137
11591138  ChangeStatus updateImpl (Attributor &A) override  {
1160-     auto  &InfoCache = static_cast <AMDGPUInformationCache &>(A.getInfoCache ());
11611139    ChangeStatus Change = ChangeStatus::UNCHANGED;
11621140
11631141    auto  CheckCallSite = [&](AbstractCallSite CS) {
@@ -1166,24 +1144,21 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
11661144      LLVM_DEBUG (dbgs () << ' ['   << getName () << " ] Call "   << Caller->getName ()
11671145                        << " ->"   << Func->getName () << ' \n '  );
11681146
1169-       const  auto  *CallerInfo  = A.getAAFor <AAAMDWavesPerEU>(
1147+       const  auto  *CallerAA  = A.getAAFor <AAAMDWavesPerEU>(
11701148          *this , IRPosition::function (*Caller), DepClassTy::REQUIRED);
1171-       const  auto  *AssumedGroupSize = A.getAAFor <AAAMDFlatWorkGroupSize>(
1172-           *this , IRPosition::function (*Func), DepClassTy::REQUIRED);
1173-       if  (!CallerInfo || !AssumedGroupSize || !CallerInfo->isValidState () ||
1174-           !AssumedGroupSize->isValidState ())
1149+       if  (!CallerAA || !CallerAA->isValidState ())
11751150        return  false ;
11761151
1177-       unsigned  Min, Max ;
1178-       std::tie  (Min, Max) = InfoCache. getEffectiveWavesPerEU ( 
1179-           *Caller, 
1180-           {CallerInfo-> getAssumed (). getLower ().getZExtValue (),
1181-            CallerInfo ->getAssumed ().getUpper ().getZExtValue () -  1 }, 
1182-           {AssumedGroupSize-> getAssumed (). getLower (). getZExtValue (), 
1183-            AssumedGroupSize-> getAssumed (). getUpper (). getZExtValue () -  1 } );
1184-       ConstantRange  CallerRange ( APInt ( 32 , Min),  APInt ( 32 , Max +  1 )) ;
1185-       IntegerRangeState  CallerRangeState (CallerRange); 
1186-       Change |=  clampStateAndIndicateChange ( this -> getState (), CallerRangeState) ;
1152+       ConstantRange Assumed =  this -> getAssumed () ;
1153+       unsigned  Min =  std::max  (Assumed. getLower (). getZExtValue (), 
1154+                               CallerAA-> getAssumed (). getLower (). getZExtValue ()); 
1155+       unsigned  Max =  std::max (Assumed. getUpper ().getZExtValue (),
1156+                               CallerAA ->getAssumed ().getUpper ().getZExtValue ()); 
1157+       ConstantRange  Range ( APInt ( 32 , Min),  APInt ( 32 , Max)); 
1158+       IntegerRangeState  RangeState (Range );
1159+       this -> getState () = RangeState ;
1160+       Change |=  this -> getState () == Assumed ? ChangeStatus::UNCHANGED 
1161+                                             : ChangeStatus::CHANGED ;
11871162
11881163      return  true ;
11891164    };
@@ -1342,6 +1317,60 @@ static void addPreloadKernArgHint(Function &F, TargetMachine &TM) {
13421317  }
13431318}
13441319
1320+ // / The final check and update of the attribute 'amdgpu-waves-per-eu' based on
1321+ // / the determined 'amdgpu-flat-work-group-size' attribute. We can't do this
1322+ // / during attributor run because the two attributes grow in opposite direction,
1323+ // / we should not use any intermediate value to calculate waves per eu until we
1324+ // / have a determined flat workgroup size.
1325+ static  void  updateWavesPerEU (Module &M, TargetMachine &TM) {
1326+   for  (Function &F : M) {
1327+     if  (F.isDeclaration ())
1328+       continue ;
1329+ 
1330+     const  GCNSubtarget &ST = TM.getSubtarget <GCNSubtarget>(F);
1331+ 
1332+     std::optional<std::pair<unsigned , std::optional<unsigned >>>
1333+         FlatWgrpSizeAttr =
1334+             AMDGPU::getIntegerPairAttribute (F, " amdgpu-flat-work-group-size"  );
1335+ 
1336+     unsigned  MinWavesPerEU = ST.getMinWavesPerEU ();
1337+     unsigned  MaxWavesPerEU = ST.getMaxWavesPerEU ();
1338+ 
1339+     unsigned  MinFlatWgrpSize = ST.getMinFlatWorkGroupSize ();
1340+     unsigned  MaxFlatWgrpSize = ST.getMaxFlatWorkGroupSize ();
1341+     if  (FlatWgrpSizeAttr.has_value ()) {
1342+       MinFlatWgrpSize = FlatWgrpSizeAttr->first ;
1343+       MaxFlatWgrpSize = *(FlatWgrpSizeAttr->second );
1344+     }
1345+ 
1346+     //  Start with the max range.
1347+     unsigned  Min = MinWavesPerEU;
1348+     unsigned  Max = MinWavesPerEU;
1349+ 
1350+     //  Compute the range from flat workgroup size. `getWavesPerEU` will also
1351+     //  account for the 'amdgpu-waves-er-eu' attribute.
1352+     auto  [MinFromFlatWgrpSize, MaxFromFlatWgrpSize] =
1353+         ST.getWavesPerEU (F, {MinFlatWgrpSize, MaxFlatWgrpSize});
1354+ 
1355+     //  For the lower bound, we have to "tighten" it.
1356+     Min = std::max (Min, MinFromFlatWgrpSize);
1357+     //  For the upper bound, we have to "extend" it.
1358+     Max = std::max (Max, MaxFromFlatWgrpSize);
1359+ 
1360+     //  Clamp the range to the max range.
1361+     Min = std::max (Min, MinWavesPerEU);
1362+     Max = std::min (Max, MaxWavesPerEU);
1363+ 
1364+     //  Update the attribute if it is not the max.
1365+     if  (Min != MinWavesPerEU || Max != MaxWavesPerEU) {
1366+       SmallString<10 > Buffer;
1367+       raw_svector_ostream OS (Buffer);
1368+       OS << Min << ' ,'   << Max;
1369+       F.addFnAttr (" amdgpu-waves-per-eu"  , OS.str ());
1370+     }
1371+   }
1372+ }
1373+ 
13451374static  bool  runImpl (Module &M, AnalysisGetter &AG, TargetMachine &TM,
13461375                    AMDGPUAttributorOptions Options,
13471376                    ThinOrFullLTOPhase LTOPhase) {
@@ -1417,8 +1446,16 @@ static bool runImpl(Module &M, AnalysisGetter &AG, TargetMachine &TM,
14171446    }
14181447  }
14191448
1420-   ChangeStatus Change = A.run ();
1421-   return  Change == ChangeStatus::CHANGED;
1449+   bool  Changed = A.run () == ChangeStatus::CHANGED;
1450+ 
1451+   //  We only update the waves-per-eu attribute at the final stage to avoid
1452+   //  setting it with intermediate values.
1453+   if  (Changed && (LTOPhase == ThinOrFullLTOPhase::None ||
1454+                   LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
1455+                   LTOPhase == ThinOrFullLTOPhase::ThinLTOPostLink))
1456+     updateWavesPerEU (M, TM);
1457+ 
1458+   return  Changed;
14221459}
14231460
14241461class  AMDGPUAttributorLegacy  : public  ModulePass  {
0 commit comments