@@ -215,6 +215,19 @@ class AMDGPUInformationCache : public InformationCache {
215215 return ST.getWavesPerEU (F, FlatWorkGroupSize);
216216 }
217217
/// Read the waves-per-EU range requested through a function attribute on F.
///
/// Returns std::nullopt when the attribute looked up below is not a string
/// attribute. Otherwise returns whatever parseRangeAttribute() produces for
/// the attribute's string value, except that a second element equal to 0 is
/// replaced by the subtarget's getMaxWavesPerEU().
218+ std::optional<std::pair<unsigned , unsigned >>
219+ getWavesPerEUAttr (const Function &F) {
220+ Attribute Attr = F.getFnAttribute (" amdgpu-waves-per-eu" );
221+ if (!Attr.isStringAttribute ())
222+ return std::nullopt ;
223+ auto Val = parseRangeAttribute (Attr.getValueAsString ());
     // A parsed upper bound of 0 is substituted with the subtarget maximum.
     // NOTE(review): presumably 0 stands for "upper bound not specified";
     // confirm against parseRangeAttribute, which is not visible here.
224+ if (Val && Val->second == 0 ) {
225+ const GCNSubtarget &ST = TM.getSubtarget <GCNSubtarget>(F);
226+ Val->second = ST.getMaxWavesPerEU ();
227+ }
228+ return Val;
229+ }
230+
218231 std::pair<unsigned , unsigned >
219232 getEffectiveWavesPerEU (const Function &F,
220233 std::pair<unsigned , unsigned > WavesPerEU,
@@ -785,22 +798,6 @@ struct AAAMDSizeRangeAttribute
785798 /* ForceReplace=*/ true );
786799 }
787800
/// Manifest the assumed range as a string attribute named AttrName, unless
/// the assumed range is exactly the default described by Min and Max.
///
/// Min and Max are compared against getAssumed().getLower() and
/// getAssumed().getUpper() - 1 respectively, i.e. Max is treated as an
/// inclusive bound. Returns ChangeStatus::UNCHANGED when they match;
/// otherwise returns the result of A.manifestAttrs() with ForceReplace set.
788- ChangeStatus emitAttributeIfNotDefault (Attributor &A, unsigned Min,
789- unsigned Max) {
790- // Don't add the attribute if it's the implied default.
791- if (getAssumed ().getLower () == Min && getAssumed ().getUpper () - 1 == Max)
792- return ChangeStatus::UNCHANGED;
793-
794- Function *F = getAssociatedFunction ();
795- LLVMContext &Ctx = F->getContext ();
     // Build the attribute value text from the assumed lower bound and the
     // assumed upper bound minus one.
796- SmallString<10 > Buffer;
797- raw_svector_ostream OS (Buffer);
798- OS << getAssumed ().getLower () << ' ,' << getAssumed ().getUpper () - 1 ;
799- return A.manifestAttrs (getIRPosition (),
800- {Attribute::get (Ctx, AttrName, OS.str ())},
801- /* ForceReplace=*/ true );
802- }
803-
804801 const std::string getAsStr (Attributor *) const override {
805802 std::string Str;
806803 raw_string_ostream OS (Str);
@@ -885,29 +882,44 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
/// Construct the attribute for position IRP, forwarding IRP, the Attributor
/// and the attribute name string to the AAAMDSizeRangeAttribute base.
885882 AAAMDWavesPerEU (const IRPosition &IRP, Attributor &A)
886883 : AAAMDSizeRangeAttribute(IRP, A, " amdgpu-waves-per-eu" ) {}
887884
/// The state is valid only when the Assumed range is not the empty set and
/// the IntegerRangeState base also reports a valid state.
888- bool isValidState () const override {
889- return !Assumed.isEmptySet () && IntegerRangeState::isValidState ();
890- }
891-
/// Seed this attribute's range state for the associated function.
///
/// NOTE(review): this listing interleaves removed ('-') and added ('+')
/// lines; the description below covers the '+' side.
///
/// If getWavesPerEUAttr() yields a range for the function, that range is
/// taken as-is. Otherwise, only for entry-function calling conventions, the
/// range is computed by getEffectiveWavesPerEU() from {1, getMaxWavesPerEU()}
/// and the flat work group size (the attribute value when present, else the
/// default). Both paths go through TakeRange, which ends by indicating an
/// optimistic fixpoint. For any other function nothing is done here.
892885 void initialize (Attributor &A) override {
893886 Function *F = getAssociatedFunction ();
894887 auto &InfoCache = static_cast <AMDGPUInformationCache &>(A.getInfoCache ());
895888
896- if (const auto *AssumedGroupSize = A.getAAFor <AAAMDFlatWorkGroupSize>(
897- *this , IRPosition::function (*F), DepClassTy::REQUIRED);
898- AssumedGroupSize->isValidState ()) {
899-
900- unsigned Min, Max;
901- std::tie (Min, Max) = InfoCache.getWavesPerEU (
902- *F, {AssumedGroupSize->getAssumed ().getLower ().getZExtValue (),
903- AssumedGroupSize->getAssumed ().getUpper ().getZExtValue () - 1 });
904-
     // TakeRange converts an inclusive {Min, Max} pair into a 32-bit
     // ConstantRange built from Min and Max + 1, clamps this attribute's
     // state with it, and then indicates an optimistic fixpoint.
889+ auto TakeRange = [&](std::pair<unsigned , unsigned > R) {
890+ auto [Min, Max] = R;
905891 ConstantRange Range (APInt (32 , Min), APInt (32 , Max + 1 ));
906- intersectKnown (Range);
892+ IntegerRangeState RangeState (Range);
893+ clampStateAndIndicateChange (this ->getState (), RangeState);
894+ indicateOptimisticFixpoint ();
895+ };
896+
897+ // If the attribute exists, simply honor it.
898+ if (auto Attr = InfoCache.getWavesPerEUAttr (*F)) {
899+ TakeRange (*Attr);
900+ return ;
907901 }
908902
909- if (AMDGPU::isEntryFunctionCC (F->getCallingConv ()))
910- indicatePessimisticFixpoint ();
903+ // It's getting trickier here, different from AAAMDFlatWorkGroupSize. Since
904+ // the calculation of waves per EU involves flat work group size, we can't
905+ // simply use an assumed flat work group size as a starting point, because
906+ // the update of flat work group size is in an inverse direction of waves
907+ // per EU. However, we can still do something if it is an entry function.
908+ // Since an entry function is a terminal node, and flat work group size
909+ // either from attribute or default will be used anyway, we can take that
910+ // value and calculate the waves per EU based on it. This result can't be
911+ // updated by any means, but that could still allow us to propagate it.
912+ if (AMDGPU::isEntryFunctionCC (F->getCallingConv ())) {
913+ std::pair<unsigned , unsigned > MaxWavesPerEURange{
914+ 1U , InfoCache.getMaxWavesPerEU (*F)};
       // Prefer an explicit flat work group size attribute; fall back to the
       // default reported by the information cache.
915+ std::pair<unsigned , unsigned > FlatWorkGroupSize;
916+ if (auto Attr = InfoCache.getFlatWorkGroupSizeAttr (*F))
917+ FlatWorkGroupSize = *Attr;
918+ else
919+ FlatWorkGroupSize = InfoCache.getDefaultFlatWorkGroupSize (*F);
920+ TakeRange (InfoCache.getEffectiveWavesPerEU (*F, MaxWavesPerEURange,
921+ FlatWorkGroupSize));
922+ }
911923 }
912924
913925 ChangeStatus updateImpl (Attributor &A) override {
@@ -956,8 +968,8 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
/// Manifest the waves-per-EU attribute on the associated function.
///
/// NOTE(review): this listing interleaves removed ('-') and added ('+')
/// lines. On the '+' side the work is delegated to
/// emitAttributeIfNotDefaultAfterClamp() with the pair
/// {1, InfoCache.getMaxWavesPerEU(*F)}; that helper is not visible here, so
/// presumably the pair is the default/clamp range. Confirm at its definition.
956968 ChangeStatus manifest (Attributor &A) override {
957969 Function *F = getAssociatedFunction ();
958970 auto &InfoCache = static_cast <AMDGPUInformationCache &>(A.getInfoCache ());
959- unsigned Max = InfoCache. getMaxWavesPerEU (*F);
960- return emitAttributeIfNotDefault ( A, 1 , Max );
971+ return emitAttributeIfNotDefaultAfterClamp (
972+ A, { 1U , InfoCache. getMaxWavesPerEU (*F)} );
961973 }
962974
963975 /// See AbstractAttribute::getName()
0 commit comments