@@ -1111,47 +1111,25 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
11111111    Function *F = getAssociatedFunction ();
11121112    auto  &InfoCache = static_cast <AMDGPUInformationCache &>(A.getInfoCache ());
11131113
1114-     auto  TakeRange = [&](std::pair<unsigned , unsigned > R) {
1115-       auto  [Min, Max] = R;
1116-       ConstantRange Range (APInt (32 , Min), APInt (32 , Max + 1 ));
1117-       IntegerRangeState RangeState (Range);
1118-       clampStateAndIndicateChange (this ->getState (), RangeState);
1119-       indicateOptimisticFixpoint ();
1120-     };
1121- 
1122-     std::pair<unsigned , unsigned > MaxWavesPerEURange{
1123-         1U , InfoCache.getMaxWavesPerEU (*F)};
1124- 
11251114    //  If the attribute exists, we will honor it if it is not the default.
11261115    if  (auto  Attr = InfoCache.getWavesPerEUAttr (*F)) {
1116+       std::pair<unsigned , unsigned > MaxWavesPerEURange{
1117+           1U , InfoCache.getMaxWavesPerEU (*F)};
11271118      if  (*Attr != MaxWavesPerEURange) {
1128-         TakeRange (*Attr);
1119+         auto  [Min, Max] = *Attr;
1120+         ConstantRange Range (APInt (32 , Min), APInt (32 , Max + 1 ));
1121+         IntegerRangeState RangeState (Range);
1122+         this ->getState () = RangeState;
1123+         indicateOptimisticFixpoint ();
11291124        return ;
11301125      }
11311126    }
11321127
1133-     //  Unlike AAAMDFlatWorkGroupSize, it's getting trickier here. Since the
1134-     //  calculation of waves per EU involves flat work group size, we can't
1135-     //  simply use an assumed flat work group size as a start point, because the
1136-     //  update of flat work group size is in an inverse direction of waves per
1137-     //  EU. However, we can still do something if it is an entry function. Since
1138-     //  an entry function is a terminal node, and flat work group size either
1139-     //  from attribute or default will be used anyway, we can take that value and
1140-     //  calculate the waves per EU based on it. This result can't be updated by
1141-     //  no means, but that could still allow us to propagate it.
1142-     if  (AMDGPU::isEntryFunctionCC (F->getCallingConv ())) {
1143-       std::pair<unsigned , unsigned > FlatWorkGroupSize;
1144-       if  (auto  Attr = InfoCache.getFlatWorkGroupSizeAttr (*F))
1145-         FlatWorkGroupSize = *Attr;
1146-       else 
1147-         FlatWorkGroupSize = InfoCache.getDefaultFlatWorkGroupSize (*F);
1148-       TakeRange (InfoCache.getEffectiveWavesPerEU (*F, MaxWavesPerEURange,
1149-                                                  FlatWorkGroupSize));
1150-     }
1128+     if  (AMDGPU::isEntryFunctionCC (F->getCallingConv ()))
1129+       indicatePessimisticFixpoint ();
11511130  }
11521131
11531132  ChangeStatus updateImpl (Attributor &A) override  {
1154-     auto  &InfoCache = static_cast <AMDGPUInformationCache &>(A.getInfoCache ());
11551133    ChangeStatus Change = ChangeStatus::UNCHANGED;
11561134
11571135    auto  CheckCallSite = [&](AbstractCallSite CS) {
@@ -1160,24 +1138,21 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
11601138      LLVM_DEBUG (dbgs () << ' ['   << getName () << " ] Call "   << Caller->getName ()
11611139                        << " ->"   << Func->getName () << ' \n '  );
11621140
1163-       const  auto  *CallerInfo  = A.getAAFor <AAAMDWavesPerEU>(
1141+       const  auto  *CallerAA  = A.getAAFor <AAAMDWavesPerEU>(
11641142          *this , IRPosition::function (*Caller), DepClassTy::REQUIRED);
1165-       const  auto  *AssumedGroupSize = A.getAAFor <AAAMDFlatWorkGroupSize>(
1166-           *this , IRPosition::function (*Func), DepClassTy::REQUIRED);
1167-       if  (!CallerInfo || !AssumedGroupSize || !CallerInfo->isValidState () ||
1168-           !AssumedGroupSize->isValidState ())
1143+       if  (!CallerAA || !CallerAA->isValidState ())
11691144        return  false ;
11701145
1171-       unsigned  Min, Max ;
1172-       std::tie  (Min, Max) = InfoCache. getEffectiveWavesPerEU ( 
1173-           *Caller, 
1174-           {CallerInfo-> getAssumed (). getLower ().getZExtValue (),
1175-            CallerInfo ->getAssumed ().getUpper ().getZExtValue () -  1 }, 
1176-           {AssumedGroupSize-> getAssumed (). getLower (). getZExtValue (), 
1177-            AssumedGroupSize-> getAssumed (). getUpper (). getZExtValue () -  1 } );
1178-       ConstantRange  CallerRange ( APInt ( 32 , Min),  APInt ( 32 , Max +  1 )) ;
1179-       IntegerRangeState  CallerRangeState (CallerRange); 
1180-       Change |=  clampStateAndIndicateChange ( this -> getState (), CallerRangeState) ;
1146+       auto  Assumed =  this -> getAssumed () ;
1147+       unsigned  Min =  std::max  (Assumed. getLower (). getZExtValue (), 
1148+                               CallerAA-> getAssumed (). getLower (). getZExtValue ()); 
1149+       unsigned  Max =  std::max (Assumed. getUpper ().getZExtValue (),
1150+                               CallerAA ->getAssumed ().getUpper ().getZExtValue ()); 
1151+       ConstantRange  Range ( APInt ( 32 , Min),  APInt ( 32 , Max)); 
1152+       IntegerRangeState  RangeState (Range );
1153+       this -> getState () = RangeState ;
1154+       Change |=  this -> getState () == Assumed ? ChangeStatus::UNCHANGED 
1155+                                             : ChangeStatus::CHANGED ;
11811156
11821157      return  true ;
11831158    };
@@ -1336,6 +1311,59 @@ static void addPreloadKernArgHint(Function &F, TargetMachine &TM) {
13361311  }
13371312}
13381313
1314+ static  void  checkWavesPerEU (Module &M, TargetMachine &TM) {
1315+   for  (Function &F : M) {
1316+     const  GCNSubtarget &ST = TM.getSubtarget <GCNSubtarget>(F);
1317+ 
1318+     auto  FlatWgrpSizeAttr =
1319+         AMDGPU::getIntegerPairAttribute (F, " amdgpu-flat-work-group-size"  );
1320+     auto  WavesPerEUAttr = AMDGPU::getIntegerPairAttribute (
1321+         F, " amdgpu-waves-per-eu"  , /* OnlyFirstRequired=*/ true );
1322+ 
1323+     unsigned  MinWavesPerEU = ST.getMinWavesPerEU ();
1324+     unsigned  MaxWavesPerEU = ST.getMaxWavesPerEU ();
1325+ 
1326+     unsigned  MinFlatWgrpSize = 1U ;
1327+     unsigned  MaxFlatWgrpSize = 1024U ;
1328+     if  (FlatWgrpSizeAttr.has_value ()) {
1329+       MinFlatWgrpSize = FlatWgrpSizeAttr->first ;
1330+       MaxFlatWgrpSize = *(FlatWgrpSizeAttr->second );
1331+     }
1332+ 
1333+     //  Start with the max range.
1334+     unsigned  Min = MinWavesPerEU;
1335+     unsigned  Max = MaxWavesPerEU;
1336+ 
1337+     //  If the attribute exists, set them to the value from the attribute.
1338+     if  (WavesPerEUAttr.has_value ()) {
1339+       Min = WavesPerEUAttr->first ;
1340+       if  (WavesPerEUAttr->second .has_value ())
1341+         Max = *(WavesPerEUAttr->second );
1342+     }
1343+ 
1344+     //  Compute the range from flat workgroup size.
1345+     auto  [MinFromFlatWgrpSize, MaxFromFlatWgrpSize] =
1346+         ST.getWavesPerEU (F, std::make_pair (MinFlatWgrpSize, MaxFlatWgrpSize));
1347+ 
1348+     //  For the lower bound, we have to "tighten" it.
1349+     Min = std::max (Min, MinFromFlatWgrpSize);
1350+     //  For the upper bound, we have to "extend" it.
1351+     Max = std::max (Max, MaxFromFlatWgrpSize);
1352+ 
1353+     //  Clamp the range to the max range.
1354+     Min = std::max (Min, MinWavesPerEU);
1355+     Max = std::min (Max, MaxWavesPerEU);
1356+ 
1357+     //  Update the attribute if it is not the max.
1358+     if  (Min != MinWavesPerEU || Max != MaxWavesPerEU) {
1359+       SmallString<10 > Buffer;
1360+       raw_svector_ostream OS (Buffer);
1361+       OS << Min << ' ,'   << Max;
1362+       F.addFnAttr (" amdgpu-waves-per-eu"  , OS.str ());
1363+     }
1364+   }
1365+ }
1366+ 
13391367static  bool  runImpl (Module &M, AnalysisGetter &AG, TargetMachine &TM,
13401368                    AMDGPUAttributorOptions Options,
13411369                    ThinOrFullLTOPhase LTOPhase) {
@@ -1425,8 +1453,14 @@ static bool runImpl(Module &M, AnalysisGetter &AG, TargetMachine &TM,
14251453    }
14261454  }
14271455
1428-   ChangeStatus Change = A.run ();
1429-   return  Change == ChangeStatus::CHANGED;
1456+   bool  Changed = A.run () == ChangeStatus::CHANGED;
1457+ 
1458+   if  (Changed && (LTOPhase == ThinOrFullLTOPhase::None ||
1459+                   LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
1460+                   LTOPhase == ThinOrFullLTOPhase::ThinLTOPostLink))
1461+     checkWavesPerEU (M, TM);
1462+ 
1463+   return  Changed;
14301464}
14311465
14321466class  AMDGPUAttributorLegacy  : public  ModulePass  {
0 commit comments