@@ -198,6 +198,16 @@ class AMDGPUInformationCache : public InformationCache {
198198 return ST.getWavesPerEU (F, FlatWorkGroupSize);
199199 }
200200
// Queries the waves-per-EU integer-pair attribute on F through
// AMDGPU::getIntegerPairAttribute and returns the (possibly adjusted) result.
// The returned optional is whatever the helper produced; only the second
// element of the pair is ever modified here.
201+ std::optional<std::pair<unsigned , unsigned >>
202+ getWavesPerEUAttr (const Function &F) {
203+ auto Val = AMDGPU::getIntegerPairAttribute (F, " amdgpu-waves-per-eu" );
// A second element equal to 0 is replaced with the subtarget's
// getMaxWavesPerEU() value.
// NOTE(review): presumably 0 means the attribute carried no explicit upper
// bound -- confirm against getIntegerPairAttribute.
204+ if (Val && Val->second == 0 ) {
205+ const GCNSubtarget &ST = TM.getSubtarget <GCNSubtarget>(F);
206+ Val->second = ST.getMaxWavesPerEU ();
207+ }
208+ return Val;
209+ }
210+
201211 std::pair<unsigned , unsigned >
202212 getEffectiveWavesPerEU (const Function &F,
203213 std::pair<unsigned , unsigned > WavesPerEU,
@@ -768,22 +778,6 @@ struct AAAMDSizeRangeAttribute
768778 /* ForceReplace=*/ true );
769779 }
770780
// (Removed side of the hunk.) Manifests the assumed range as the string
// attribute AttrName unless the range equals the caller-supplied default.
// Min and Max are compared against getLower() and getUpper() - 1
// respectively, i.e. Max is treated as an inclusive bound.
// Returns ChangeStatus::UNCHANGED when the range matches the default,
// otherwise the result of A.manifestAttrs.
771- ChangeStatus emitAttributeIfNotDefault (Attributor &A, unsigned Min,
772- unsigned Max) {
773- // Don't add the attribute if it's the implied default.
774- if (getAssumed ().getLower () == Min && getAssumed ().getUpper () - 1 == Max)
775- return ChangeStatus::UNCHANGED;
776-
777- Function *F = getAssociatedFunction ();
778- LLVMContext &Ctx = F->getContext ();
// Render the attribute value as the lower bound, a separator, and the
// upper bound minus one into a small stack buffer.
779- SmallString<10 > Buffer;
780- raw_svector_ostream OS (Buffer);
781- OS << getAssumed ().getLower () << ' ,' << getAssumed ().getUpper () - 1 ;
// ForceReplace is passed as true for the new attribute.
782- return A.manifestAttrs (getIRPosition (),
783- {Attribute::get (Ctx, AttrName, OS.str ())},
784- /* ForceReplace=*/ true );
785- }
786-
787781 const std::string getAsStr (Attributor *) const override {
788782 std::string Str;
789783 raw_string_ostream OS (Str);
@@ -868,29 +862,44 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
// Constructor: forwards IRP and A to the AAAMDSizeRangeAttribute base together
// with the attribute-name string this abstract attribute handles.
868862 AAAMDWavesPerEU (const IRPosition &IRP, Attributor &A)
869863 : AAAMDSizeRangeAttribute(IRP, A, " amdgpu-waves-per-eu" ) {}
870864
// (Removed side of the hunk.) The state counted as valid only when the
// assumed range was not the empty set and the base
// IntegerRangeState::isValidState() check also held.
871- bool isValidState () const override {
872- return !Assumed.isEmptySet () && IntegerRangeState::isValidState ();
873- }
874-
// Seeds the waves-per-EU range state for the associated function.
//
// Previous implementation ('-' lines): queried AAAMDFlatWorkGroupSize with a
// REQUIRED dependency, intersected the known range with
// InfoCache.getWavesPerEU() computed from the assumed flat work group size,
// and indicated a pessimistic fixpoint for entry functions.
//
// New implementation ('+' lines): if getWavesPerEUAttr() yields a value, that
// range is taken and the state is fixed; otherwise, for entry functions only,
// the range from getEffectiveWavesPerEU() is taken and fixed. For any other
// function initialize() leaves the state untouched.
875865 void initialize (Attributor &A) override {
876866 Function *F = getAssociatedFunction ();
877867 auto &InfoCache = static_cast <AMDGPUInformationCache &>(A.getInfoCache ());
878868
879- if (const auto *AssumedGroupSize = A.getAAFor <AAAMDFlatWorkGroupSize>(
880- *this , IRPosition::function (*F), DepClassTy::REQUIRED);
881- AssumedGroupSize->isValidState ()) {
882-
883- unsigned Min, Max;
884- std::tie (Min, Max) = InfoCache.getWavesPerEU (
885- *F, {AssumedGroupSize->getAssumed ().getLower ().getZExtValue (),
886- AssumedGroupSize->getAssumed ().getUpper ().getZExtValue () - 1 });
887-
// TakeRange: turns a (Min, Max) pair into a 32-bit ConstantRange whose upper
// argument is Max + 1, clamps this attribute's state with it, and then
// indicates an optimistic fixpoint.
869+ auto TakeRange = [&](std::pair<unsigned , unsigned > R) {
870+ auto [Min, Max] = R;
888871 ConstantRange Range (APInt (32 , Min), APInt (32 , Max + 1 ));
889- intersectKnown (Range);
872+ IntegerRangeState RangeState (Range);
873+ clampStateAndIndicateChange (this ->getState (), RangeState);
874+ indicateOptimisticFixpoint ();
875+ };
876+
877+ // If the attribute exists, simply honor it.
878+ if (auto Attr = InfoCache.getWavesPerEUAttr (*F)) {
879+ TakeRange (*Attr);
880+ return ;
890881 }
891882
892- if (AMDGPU::isEntryFunctionCC (F->getCallingConv ()))
893- indicatePessimisticFixpoint ();
883+ // It's getting trickier here, different from AAAMDFlatWorkGroupSize. Since
884+ // the calculation of waves per EU involves flat work group size, we can't
885+ // simply use an assumed flat work group size as a starting point, because
886+ // the update of flat work group size is in an inverse direction of waves per
887+ // EU. However, we can still do something if it is an entry function. Since
888+ // an entry function is a terminal node, and flat work group size either
889+ // from attribute or default will be used anyway, we can take that value and
890+ // calculate the waves per EU based on it. This result can't be updated by
891+ // any means, but that could still allow us to propagate it.
892+ if (AMDGPU::isEntryFunctionCC (F->getCallingConv ())) {
// Start from the pair {1, getMaxWavesPerEU(F)} as the waves-per-EU input.
893+ std::pair<unsigned , unsigned > MaxWavesPerEURange{
894+ 1U , InfoCache.getMaxWavesPerEU (*F)};
// Flat work group size: the attribute value when present, else the default.
895+ std::pair<unsigned , unsigned > FlatWorkGroupSize;
896+ if (auto Attr = InfoCache.getFlatWorkGroupSizeAttr (*F))
897+ FlatWorkGroupSize = *Attr;
898+ else
899+ FlatWorkGroupSize = InfoCache.getDefaultFlatWorkGroupSize (*F);
900+ TakeRange (InfoCache.getEffectiveWavesPerEU (*F, MaxWavesPerEURange,
901+ FlatWorkGroupSize));
902+ }
894903 }
895904
896905 ChangeStatus updateImpl (Attributor &A) override {
@@ -939,8 +948,8 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
// Writes the deduced waves-per-EU range back to the IR as an attribute.
// Previous implementation ('-' lines): called emitAttributeIfNotDefault with
// the default bounds 1 and InfoCache.getMaxWavesPerEU(*F).
// New implementation ('+' lines): passes the same two bounds as a pair to
// emitAttributeIfNotDefaultAfterClamp.
// NOTE(review): emitAttributeIfNotDefaultAfterClamp is defined outside this
// view; presumably it clamps to the given range before the default check --
// confirm.
939948 ChangeStatus manifest (Attributor &A) override {
940949 Function *F = getAssociatedFunction ();
941950 auto &InfoCache = static_cast <AMDGPUInformationCache &>(A.getInfoCache ());
942- unsigned Max = InfoCache. getMaxWavesPerEU (*F);
943- return emitAttributeIfNotDefault ( A, 1 , Max );
951+ return emitAttributeIfNotDefaultAfterClamp (
952+ A, { 1U , InfoCache. getMaxWavesPerEU (*F)} );
944953 }
945954
946955 /// See AbstractAttribute::getName()
0 commit comments