@@ -198,6 +198,17 @@ class AMDGPUInformationCache : public InformationCache {
198198 return ST.getWavesPerEU (F, FlatWorkGroupSize);
199199 }
200200
201+ std::optional<std::pair<unsigned , unsigned >>
202+ getWavesPerEUAttr (const Function &F) {
203+ auto Val = AMDGPU::getIntegerPairAttribute (F, " amdgpu-waves-per-eu" ,
204+ /* OnlyFirstRequired=*/ true );
205+ if (Val && Val->second == 0 ) {
206+ const GCNSubtarget &ST = TM.getSubtarget <GCNSubtarget>(F);
207+ Val->second = ST.getMaxWavesPerEU ();
208+ }
209+ return Val;
210+ }
211+
201212 std::pair<unsigned , unsigned >
202213 getEffectiveWavesPerEU (const Function &F,
203214 std::pair<unsigned , unsigned > WavesPerEU,
@@ -768,22 +779,6 @@ struct AAAMDSizeRangeAttribute
768779 /* ForceReplace=*/ true );
769780 }
770781
771- ChangeStatus emitAttributeIfNotDefault (Attributor &A, unsigned Min,
772- unsigned Max) {
773- // Don't add the attribute if it's the implied default.
774- if (getAssumed ().getLower () == Min && getAssumed ().getUpper () - 1 == Max)
775- return ChangeStatus::UNCHANGED;
776-
777- Function *F = getAssociatedFunction ();
778- LLVMContext &Ctx = F->getContext ();
779- SmallString<10 > Buffer;
780- raw_svector_ostream OS (Buffer);
781- OS << getAssumed ().getLower () << ' ,' << getAssumed ().getUpper () - 1 ;
782- return A.manifestAttrs (getIRPosition (),
783- {Attribute::get (Ctx, AttrName, OS.str ())},
784- /* ForceReplace=*/ true );
785- }
786-
787782 const std::string getAsStr (Attributor *) const override {
788783 std::string Str;
789784 raw_string_ostream OS (Str);
@@ -873,29 +868,47 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
873868 AAAMDWavesPerEU (const IRPosition &IRP, Attributor &A)
874869 : AAAMDSizeRangeAttribute(IRP, A, " amdgpu-waves-per-eu" ) {}
875870
876- bool isValidState () const override {
877- return !Assumed.isEmptySet () && IntegerRangeState::isValidState ();
878- }
879-
880871 void initialize (Attributor &A) override {
881872 Function *F = getAssociatedFunction ();
882873 auto &InfoCache = static_cast <AMDGPUInformationCache &>(A.getInfoCache ());
883874
884- if (const auto *AssumedGroupSize = A.getAAFor <AAAMDFlatWorkGroupSize>(
885- *this , IRPosition::function (*F), DepClassTy::REQUIRED);
886- AssumedGroupSize->isValidState ()) {
875+ auto TakeRange = [&](std::pair<unsigned , unsigned > R) {
876+ auto [Min, Max] = R;
877+ ConstantRange Range (APInt (32 , Min), APInt (32 , Max + 1 ));
878+ IntegerRangeState RangeState (Range);
879+ clampStateAndIndicateChange (this ->getState (), RangeState);
880+ indicateOptimisticFixpoint ();
881+ };
887882
888- unsigned Min, Max;
889- std::tie (Min, Max) = InfoCache.getWavesPerEU (
890- *F, {AssumedGroupSize->getAssumed ().getLower ().getZExtValue (),
891- AssumedGroupSize->getAssumed ().getUpper ().getZExtValue () - 1 });
883+ std::pair<unsigned , unsigned > MaxWavesPerEURange{
884+ 1U , InfoCache.getMaxWavesPerEU (*F)};
892885
893- ConstantRange Range (APInt (32 , Min), APInt (32 , Max + 1 ));
894- intersectKnown (Range);
886+ // If the attribute exists, we will honor it if it is not the default.
887+ if (auto Attr = InfoCache.getWavesPerEUAttr (*F)) {
888+ if (*Attr != MaxWavesPerEURange) {
889+ TakeRange (*Attr);
890+ return ;
891+ }
895892 }
896893
897- if (AMDGPU::isEntryFunctionCC (F->getCallingConv ()))
898- indicatePessimisticFixpoint ();
894+ // Unlike AAAMDFlatWorkGroupSize, it's getting trickier here. Since the
895+ // calculation of waves per EU involves flat work group size, we can't
896+ // simply use an assumed flat work group size as a start point, because the
897+ // update of flat work group size is in an inverse direction of waves per
898+ // EU. However, we can still do something if it is an entry function. Since
899+ // an entry function is a terminal node, and flat work group size either
900+ // from attribute or default will be used anyway, we can take that value and
901+ // calculate the waves per EU based on it. This result can't be updated by
902+ // no means, but that could still allow us to propagate it.
903+ if (AMDGPU::isEntryFunctionCC (F->getCallingConv ())) {
904+ std::pair<unsigned , unsigned > FlatWorkGroupSize;
905+ if (auto Attr = InfoCache.getFlatWorkGroupSizeAttr (*F))
906+ FlatWorkGroupSize = *Attr;
907+ else
908+ FlatWorkGroupSize = InfoCache.getDefaultFlatWorkGroupSize (*F);
909+ TakeRange (InfoCache.getEffectiveWavesPerEU (*F, MaxWavesPerEURange,
910+ FlatWorkGroupSize));
911+ }
899912 }
900913
901914 ChangeStatus updateImpl (Attributor &A) override {
@@ -944,8 +957,8 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
944957 ChangeStatus manifest (Attributor &A) override {
945958 Function *F = getAssociatedFunction ();
946959 auto &InfoCache = static_cast <AMDGPUInformationCache &>(A.getInfoCache ());
947- unsigned Max = InfoCache. getMaxWavesPerEU (*F);
948- return emitAttributeIfNotDefault ( A, 1 , Max );
960+ return emitAttributeIfNotDefaultAfterClamp (
961+ A, { 1U , InfoCache. getMaxWavesPerEU (*F)} );
949962 }
950963
951964 // / See AbstractAttribute::getName()
0 commit comments