@@ -831,7 +831,9 @@ struct AAAMDSizeRangeAttribute
831831 const std::string getAsStr (Attributor *) const override {
832832 std::string Str;
833833 raw_string_ostream OS (Str);
834- OS << getName () << ' [' ;
834+ OS << getName () << " Known[" ;
835+ OS << getKnown ().getLower () << ' ,' << getKnown ().getUpper () - 1 ;
836+ OS << " ] Assumed[" ;
835837 OS << getAssumed ().getLower () << ' ,' << getAssumed ().getUpper () - 1 ;
836838 OS << ' ]' ;
837839 return OS.str ();
@@ -1044,60 +1046,40 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
10441046 AAAMDWavesPerEU (const IRPosition &IRP, Attributor &A)
10451047 : AAAMDSizeRangeAttribute(IRP, A, " amdgpu-waves-per-eu" ) {}
10461048
1047- bool isValidState () const override {
1048- return !Assumed.isEmptySet () && IntegerRangeState::isValidState ();
1049- }
1050-
10511049 void initialize (Attributor &A) override {
10521050 Function *F = getAssociatedFunction ();
10531051 auto &InfoCache = static_cast <AMDGPUInformationCache &>(A.getInfoCache ());
10541052 
1055- if (const auto *AssumedGroupSize = A.getAAFor <AAAMDFlatWorkGroupSize>(
1056- *this , IRPosition::function (*F), DepClassTy::REQUIRED);
1057- AssumedGroupSize->isValidState ()) {
1053+ // We allow consistent WavesPerEU for all functions here, but for non-entry
1054+ // points we will verify consistency at the end.
1055+ unsigned ImpliedMin, ImpliedMax;
1056+ std::tie (ImpliedMin, ImpliedMax) =
1057+ InfoCache.getWavesPerEU (*F, InfoCache.getFlatWorkGroupSizes (*F));
10581058 
1059- unsigned Min, Max;
1060- std::tie (Min, Max) = InfoCache.getWavesPerEU (
1061- *F, {AssumedGroupSize->getAssumed ().getLower ().getZExtValue (),
1062- AssumedGroupSize->getAssumed ().getUpper ().getZExtValue () - 1 });
1063-
1064- ConstantRange Range (APInt (32 , Min), APInt (32 , Max + 1 ));
1065- intersectKnown (Range);
1066- }
1059+ ConstantRange Range (APInt (32 , ImpliedMin), APInt (32 , ImpliedMax + 1 ));
1060+ intersectKnown (Range);
10671061 
1068- if (AMDGPU::isEntryFunctionCC (F->getCallingConv ()))
1062+ // For entry functions we cannot derive anything better, so fix the state now.
1063+ if (AMDGPU::isEntryFunctionCC (getAssociatedFunction ()->getCallingConv ()))
10691064 indicatePessimisticFixpoint ();
10701065 }
10711066
10721067 ChangeStatus updateImpl (Attributor &A) override {
1073- auto &InfoCache = static_cast <AMDGPUInformationCache &>(A.getInfoCache ());
10741068 ChangeStatus Change = ChangeStatus::UNCHANGED;
10751069
10761070 auto CheckCallSite = [&](AbstractCallSite CS) {
10771071 Function *Caller = CS.getInstruction ()->getFunction ();
1078- Function *Func = getAssociatedFunction ();
1072+ [[maybe_unused]] Function *Func = getAssociatedFunction ();
10791073 LLVM_DEBUG (dbgs () << ' [' << getName () << " ] Call " << Caller->getName ()
10801074 << " ->" << Func->getName () << ' \n ' );
10811075
10821076 const auto *CallerInfo = A.getAAFor <AAAMDWavesPerEU>(
10831077 *this , IRPosition::function (*Caller), DepClassTy::REQUIRED);
1084- const auto *AssumedGroupSize = A.getAAFor <AAAMDFlatWorkGroupSize>(
1085- *this , IRPosition::function (*Func), DepClassTy::REQUIRED);
1086- if (!CallerInfo || !AssumedGroupSize || !CallerInfo->isValidState () ||
1087- !AssumedGroupSize->isValidState ())
1078+ if (!CallerInfo || !CallerInfo->isValidState ())
10881079 return false ;
10891080
1090- unsigned Min, Max;
1091- std::tie (Min, Max) = InfoCache.getEffectiveWavesPerEU (
1092- *Caller,
1093- {CallerInfo->getAssumed ().getLower ().getZExtValue (),
1094- CallerInfo->getAssumed ().getUpper ().getZExtValue () - 1 },
1095- {AssumedGroupSize->getAssumed ().getLower ().getZExtValue (),
1096- AssumedGroupSize->getAssumed ().getUpper ().getZExtValue () - 1 });
1097- ConstantRange CallerRange (APInt (32 , Min), APInt (32 , Max + 1 ));
1098- IntegerRangeState CallerRangeState (CallerRange);
1099- Change |= clampStateAndIndicateChange (this ->getState (), CallerRangeState);
1100-
1081+ Change |=
1082+ clampStateAndIndicateChange (this ->getState (), CallerInfo->getState ());
11011083 return true ;
11021084 };
11031085
@@ -1113,8 +1095,28 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
11131095 Attributor &A);
11141096
11151097 ChangeStatus manifest (Attributor &A) override {
1098+ unsigned ImpliedMin = getAssumed ().getLower ().getZExtValue ();
1099+ unsigned ImpliedMax = getAssumed ().getUpper ().getZExtValue () - 1 ;
1100+
11161101 Function *F = getAssociatedFunction ();
11171102 auto &InfoCache = static_cast <AMDGPUInformationCache &>(A.getInfoCache ());
1103+
1104+ // Make non-kernel functions locally consistent. NOTE(review): in the lines shown here ImpliedMin/ImpliedMax are not read again after this block -- confirm the clamped range actually reaches the emitted attribute.
1105+ if (!AMDGPU::isEntryFunctionCC (getAssociatedFunction ()->getCallingConv ())) {
1106+ const auto *AssumedGroupSize = A.getAAFor <AAAMDFlatWorkGroupSize>(
1107+ *this , getIRPosition (), DepClassTy::OPTIONAL);
1108+ std::pair<unsigned , unsigned > FlatWorkGroupSize;
1109+ if (!AssumedGroupSize || !AssumedGroupSize->isValidState ())
1110+ FlatWorkGroupSize = InfoCache.getFlatWorkGroupSizes (*F);
1111+ else
1112+ FlatWorkGroupSize = {
1113+ AssumedGroupSize->getAssumed ().getLower ().getZExtValue (),
1114+ AssumedGroupSize->getAssumed ().getUpper ().getZExtValue () - 1 };
1115+
1116+ std::tie (ImpliedMin, ImpliedMax) = InfoCache.getEffectiveWavesPerEU (
1117+ *F, {ImpliedMin, ImpliedMax}, FlatWorkGroupSize);
1118+ }
1119+
11181120 unsigned Max = InfoCache.getMaxWavesPerEU (*F);
11191121 return emitAttributeIfNotDefault (A, 1 , Max);
11201122 }
@@ -1295,10 +1297,10 @@ static bool runImpl(Module &M, AnalysisGetter &AG, TargetMachine &TM,
12951297 A.getOrCreateAAFor <AAUniformWorkGroupSize>(IRPosition::function (*F));
12961298 A.getOrCreateAAFor <AAAMDMaxNumWorkgroups>(IRPosition::function (*F));
12971299 A.getOrCreateAAFor <AAAMDGPUNoAGPR>(IRPosition::function (*F));
1300+ A.getOrCreateAAFor <AAAMDWavesPerEU>(IRPosition::function (*F));
12981301 CallingConv::ID CC = F->getCallingConv ();
12991302 if (!AMDGPU::isEntryFunctionCC (CC)) {
13001303 A.getOrCreateAAFor <AAAMDFlatWorkGroupSize>(IRPosition::function (*F));
1301- A.getOrCreateAAFor <AAAMDWavesPerEU>(IRPosition::function (*F));
13021304 } else if (CC == CallingConv::AMDGPU_KERNEL) {
13031305 addPreloadKernArgHint (*F, TM);
13041306 }
0 commit comments