@@ -1109,74 +1109,46 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
11091109    Function *F = getAssociatedFunction ();
11101110    auto  &InfoCache = static_cast <AMDGPUInformationCache &>(A.getInfoCache ());
11111111
1112-     auto  TakeRange = [&](std::pair<unsigned , unsigned > R) {
1113-       auto  [Min, Max] = R;
1114-       ConstantRange Range (APInt (32 , Min), APInt (32 , Max + 1 ));
1115-       IntegerRangeState RangeState (Range);
1116-       clampStateAndIndicateChange (this ->getState (), RangeState);
1117-       indicateOptimisticFixpoint ();
1118-     };
1119- 
1120-     std::pair<unsigned , unsigned > MaxWavesPerEURange{
1121-         1U , InfoCache.getMaxWavesPerEU (*F)};
1122- 
11231112    //  If the attribute exists, we will honor it if it is not the default.
11241113    if  (auto  Attr = InfoCache.getWavesPerEUAttr (*F)) {
1114+       std::pair<unsigned , unsigned > MaxWavesPerEURange{
1115+           1U , InfoCache.getMaxWavesPerEU (*F)};
11251116      if  (*Attr != MaxWavesPerEURange) {
1126-         TakeRange (*Attr);
1117+         auto  [Min, Max] = *Attr;
1118+         ConstantRange Range (APInt (32 , Min), APInt (32 , Max + 1 ));
1119+         IntegerRangeState RangeState (Range);
1120+         clampStateAndIndicateChange (this ->getState (), RangeState);
1121+         indicateOptimisticFixpoint ();
11271122        return ;
11281123      }
11291124    }
11301125
1131-     //  Unlike AAAMDFlatWorkGroupSize, it's getting trickier here. Since the
1132-     //  calculation of waves per EU involves flat work group size, we can't
1133-     //  simply use an assumed flat work group size as a start point, because the
1134-     //  update of flat work group size is in an inverse direction of waves per
1135-     //  EU. However, we can still do something if it is an entry function. Since
1136-     //  an entry function is a terminal node, and flat work group size either
1137-     //  from attribute or default will be used anyway, we can take that value and
1138-     //  calculate the waves per EU based on it. This result can't be updated by
1139-     //  no means, but that could still allow us to propagate it.
1140-     if  (AMDGPU::isEntryFunctionCC (F->getCallingConv ())) {
1141-       std::pair<unsigned , unsigned > FlatWorkGroupSize;
1142-       if  (auto  Attr = InfoCache.getFlatWorkGroupSizeAttr (*F))
1143-         FlatWorkGroupSize = *Attr;
1144-       else 
1145-         FlatWorkGroupSize = InfoCache.getDefaultFlatWorkGroupSize (*F);
1146-       TakeRange (InfoCache.getEffectiveWavesPerEU (*F, MaxWavesPerEURange,
1147-                                                  FlatWorkGroupSize));
1148-     }
1126+     if  (AMDGPU::isEntryFunctionCC (F->getCallingConv ()))
1127+       indicatePessimisticFixpoint ();
11491128  }
11501129
11511130  ChangeStatus updateImpl (Attributor &A) override  {
1152-     auto  &InfoCache = static_cast <AMDGPUInformationCache &>(A.getInfoCache ());
11531131    ChangeStatus Change = ChangeStatus::UNCHANGED;
11541132
11551133    auto  CheckCallSite = [&](AbstractCallSite CS) {
11561134      Function *Caller = CS.getInstruction ()->getFunction ();
1157-       Function *Func = getAssociatedFunction ();
1158-       LLVM_DEBUG (dbgs () << ' ['   << getName () << " ] Call "   << Caller->getName ()
1159-                         << " ->"   << Func->getName () << ' \n '  );
1160- 
11611135      const  auto  *CallerInfo = A.getAAFor <AAAMDWavesPerEU>(
11621136          *this , IRPosition::function (*Caller), DepClassTy::REQUIRED);
1163-       const  auto  *AssumedGroupSize = A.getAAFor <AAAMDFlatWorkGroupSize>(
1164-           *this , IRPosition::function (*Func), DepClassTy::REQUIRED);
1165-       if  (!CallerInfo || !AssumedGroupSize || !CallerInfo->isValidState () ||
1166-           !AssumedGroupSize->isValidState ())
1137+       if  (!CallerInfo || !CallerInfo->isValidState ())
11671138        return  false ;
1168- 
1169-       unsigned  Min, Max;
1170-       std::tie (Min, Max) = InfoCache.getEffectiveWavesPerEU (
1171-           *Caller,
1172-           {CallerInfo->getAssumed ().getLower ().getZExtValue (),
1173-            CallerInfo->getAssumed ().getUpper ().getZExtValue () - 1 },
1174-           {AssumedGroupSize->getAssumed ().getLower ().getZExtValue (),
1175-            AssumedGroupSize->getAssumed ().getUpper ().getZExtValue () - 1 });
1176-       ConstantRange CallerRange (APInt (32 , Min), APInt (32 , Max + 1 ));
1139+       unsigned  Min = CallerInfo->getAssumed ().getLower ().getZExtValue ();
1140+       unsigned  Max = CallerInfo->getAssumed ().getUpper ().getZExtValue ();
1141+       auto  CurrentAssumed = this ->getState ().getAssumed ();
1142+       Min = std::max (
1143+           Min, static_cast <unsigned >(CurrentAssumed.getLower ().getZExtValue ()));
1144+       Max = std::max (
1145+           Max, static_cast <unsigned >(CurrentAssumed.getUpper ().getZExtValue ()));
1146+       ConstantRange CallerRange (APInt (32 , Min), APInt (32 , Max));
11771147      IntegerRangeState CallerRangeState (CallerRange);
1178-       Change |= clampStateAndIndicateChange (this ->getState (), CallerRangeState);
1179- 
1148+       this ->getState () = CallerRangeState;
1149+       Change |= CurrentAssumed == this ->getState ().getAssumed ()
1150+                     ? ChangeStatus::UNCHANGED
1151+                     : ChangeStatus::CHANGED;
11801152      return  true ;
11811153    };
11821154
@@ -1329,6 +1301,59 @@ static void addPreloadKernArgHint(Function &F, TargetMachine &TM) {
13291301  }
13301302}
13311303
1304+ static  void  checkWavesPerEU (Module &M, TargetMachine &TM) {
1305+   for  (Function &F : M) {
1306+     const  GCNSubtarget &ST = TM.getSubtarget <GCNSubtarget>(F);
1307+ 
1308+     auto  FlatWgrpSizeAttr =
1309+         AMDGPU::getIntegerPairAttribute (F, " amdgpu-flat-work-group-size"  );
1310+     auto  WavesPerEUAttr = AMDGPU::getIntegerPairAttribute (
1311+         F, " amdgpu-waves-per-eu"  , /* OnlyFirstRequired=*/ true );
1312+ 
1313+     unsigned  MinWavesPerEU = ST.getMinWavesPerEU ();
1314+     unsigned  MaxWavesPerEU = ST.getMaxWavesPerEU ();
1315+ 
1316+     unsigned  MinFlatWgrpSize = 1U ;
1317+     unsigned  MaxFlatWgrpSize = 1024U ;
1318+     if  (FlatWgrpSizeAttr.has_value ()) {
1319+       MinFlatWgrpSize = FlatWgrpSizeAttr->first ;
1320+       MaxFlatWgrpSize = *(FlatWgrpSizeAttr->second );
1321+     }
1322+ 
1323+     //  Start with the max range.
1324+     unsigned  Min = MinWavesPerEU;
1325+     unsigned  Max = MaxWavesPerEU;
1326+ 
1327+     //  If the attribute exists, set them to the value from the attribute.
1328+     if  (WavesPerEUAttr.has_value ()) {
1329+       Min = WavesPerEUAttr->first ;
1330+       if  (WavesPerEUAttr->second .has_value ())
1331+         Max = *(WavesPerEUAttr->second );
1332+     }
1333+ 
1334+     //  Compute the range from flat workgroup size.
1335+     auto  [MinFromFlatWgrpSize, MaxFromFlatWgrpSize] =
1336+         ST.getWavesPerEU (F, std::make_pair (MinFlatWgrpSize, MaxFlatWgrpSize));
1337+ 
1338+     //  For the lower bound, we have to "tighten" it.
1339+     Min = std::max (Min, MinFromFlatWgrpSize);
1340+     //  For the upper bound, we have to "extend" it.
1341+     Max = std::max (Max, MaxFromFlatWgrpSize);
1342+ 
1343+     //  Clamp the range to the max range.
1344+     Min = std::max (Min, MinWavesPerEU);
1345+     Max = std::min (Max, MaxWavesPerEU);
1346+ 
1347+     //  Update the attribute if it is not the max.
1348+     if  (Min != MinWavesPerEU || Max != MaxWavesPerEU) {
1349+       SmallString<10 > Buffer;
1350+       raw_svector_ostream OS (Buffer);
1351+       OS << Min << ' ,'   << Max;
1352+       F.addFnAttr (" amdgpu-waves-per-eu"  , OS.str ());
1353+     }
1354+   }
1355+ }
1356+ 
13321357static  bool  runImpl (Module &M, AnalysisGetter &AG, TargetMachine &TM,
13331358                    AMDGPUAttributorOptions Options,
13341359                    ThinOrFullLTOPhase LTOPhase) {
@@ -1421,8 +1446,14 @@ static bool runImpl(Module &M, AnalysisGetter &AG, TargetMachine &TM,
14211446    }
14221447  }
14231448
1424-   ChangeStatus Change = A.run ();
1425-   return  Change == ChangeStatus::CHANGED;
1449+   bool  Changed = A.run () == ChangeStatus::CHANGED;
1450+ 
1451+   if  (Changed && (LTOPhase == ThinOrFullLTOPhase::None ||
1452+                   LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
1453+                   LTOPhase == ThinOrFullLTOPhase::ThinLTOPostLink))
1454+     checkWavesPerEU (M, TM);
1455+ 
1456+   return  Changed;
14261457}
14271458
14281459class  AMDGPUAttributorLegacy  : public  ModulePass  {
0 commit comments