@@ -179,6 +179,11 @@ class AMDGPUInformationCache : public InformationCache {
179179 return {ST.getMinFlatWorkGroupSize (), ST.getMaxFlatWorkGroupSize ()};
180180 }
181181
182+ SmallVector<unsigned > getMaxNumWorkGroups (const Function &F) {
183+ const GCNSubtarget &ST = TM.getSubtarget <GCNSubtarget>(F);
184+ return ST.getMaxNumWorkGroups (F);
185+ }
186+
182187 // / Get code object version.
183188 unsigned getCodeObjectVersion () const {
184189 return CodeObjectVersion;
@@ -821,6 +826,150 @@ AAAMDFlatWorkGroupSize::createForPosition(const IRPosition &IRP,
821826 " AAAMDFlatWorkGroupSize is only valid for function position" );
822827}
823828
829+ struct TupleDecIntegerRangeState : public AbstractState {
830+ DecIntegerState<uint32_t > X, Y, Z;
831+
832+ bool isValidState () const override {
833+ return X.isValidState () && Y.isValidState () && Z.isValidState ();
834+ }
835+
836+ bool isAtFixpoint () const override {
837+ return X.isAtFixpoint () && Y.isAtFixpoint () && Z.isAtFixpoint ();
838+ }
839+
840+ ChangeStatus indicateOptimisticFixpoint () override {
841+ return X.indicateOptimisticFixpoint () | Y.indicateOptimisticFixpoint () |
842+ Z.indicateOptimisticFixpoint ();
843+ }
844+
845+ ChangeStatus indicatePessimisticFixpoint () override {
846+ return X.indicatePessimisticFixpoint () | Y.indicatePessimisticFixpoint () |
847+ Z.indicatePessimisticFixpoint ();
848+ }
849+
850+ TupleDecIntegerRangeState operator ^=(const TupleDecIntegerRangeState &Other) {
851+ X ^= Other.X ;
852+ Y ^= Other.Y ;
853+ Z ^= Other.Z ;
854+ return *this ;
855+ }
856+
857+ bool operator ==(const TupleDecIntegerRangeState &Other) const {
858+ return X == Other.X && Y == Other.Y && Z == Other.Z ;
859+ }
860+
861+ TupleDecIntegerRangeState &getAssumed () { return *this ; }
862+ const TupleDecIntegerRangeState &getAssumed () const { return *this ; }
863+ };
864+
865+ using AAAMDMaxNumWorkgroupsState =
866+ StateWrapper<TupleDecIntegerRangeState, AbstractAttribute, uint32_t >;
867+
868+ // / Propagate amdgpu-max-num-workgroups attribute.
869+ struct AAAMDMaxNumWorkgroups
870+ : public StateWrapper<TupleDecIntegerRangeState, AbstractAttribute> {
871+ using Base = StateWrapper<TupleDecIntegerRangeState, AbstractAttribute>;
872+
873+ AAAMDMaxNumWorkgroups (const IRPosition &IRP, Attributor &A) : Base(IRP) {}
874+
875+ void initialize (Attributor &A) override {
876+ Function *F = getAssociatedFunction ();
877+ auto &InfoCache = static_cast <AMDGPUInformationCache &>(A.getInfoCache ());
878+
879+ SmallVector<unsigned > MaxNumWorkgroups = InfoCache.getMaxNumWorkGroups (*F);
880+
881+ // FIXME: What is the interpretation of 0?
882+ for (unsigned &Entry : MaxNumWorkgroups) {
883+ if (Entry == 0 )
884+ Entry = std::numeric_limits<uint32_t >::max ();
885+ }
886+
887+ X.takeKnownMinimum (MaxNumWorkgroups[0 ]);
888+ Y.takeKnownMinimum (MaxNumWorkgroups[1 ]);
889+ Z.takeKnownMinimum (MaxNumWorkgroups[2 ]);
890+
891+ if (AMDGPU::isEntryFunctionCC (F->getCallingConv ()))
892+ indicatePessimisticFixpoint ();
893+ }
894+
895+ ChangeStatus updateImpl (Attributor &A) override {
896+ ChangeStatus Change = ChangeStatus::UNCHANGED;
897+
898+ auto CheckCallSite = [&](AbstractCallSite CS) {
899+ Function *Caller = CS.getInstruction ()->getFunction ();
900+ LLVM_DEBUG (dbgs () << " [AAAMDMaxNumWorkgroups] Call " << Caller->getName ()
901+ << " ->" << getAssociatedFunction ()->getName () << ' \n ' );
902+
903+ const auto *CallerInfo = A.getAAFor <AAAMDMaxNumWorkgroups>(
904+ *this , IRPosition::function (*Caller), DepClassTy::REQUIRED);
905+ if (!CallerInfo)
906+ return false ;
907+
908+ Change |=
909+ clampStateAndIndicateChange (this ->getState (), CallerInfo->getState ());
910+ return true ;
911+ };
912+
913+ bool AllCallSitesKnown = true ;
914+ if (!A.checkForAllCallSites (CheckCallSite, *this , true , AllCallSitesKnown))
915+ return indicatePessimisticFixpoint ();
916+
917+ return Change;
918+ }
919+
920+ // / Create an abstract attribute view for the position \p IRP.
921+ static AAAMDMaxNumWorkgroups &createForPosition (const IRPosition &IRP,
922+ Attributor &A);
923+
924+ ChangeStatus manifest (Attributor &A) override {
925+ Function *F = getAssociatedFunction ();
926+ // TODO: Skip adding if worst case?
927+ LLVMContext &Ctx = F->getContext ();
928+ SmallString<32 > Buffer;
929+ raw_svector_ostream OS (Buffer);
930+ OS << X.getAssumed () << ' ,' << Y.getAssumed () << ' ,' << Z.getAssumed ();
931+
932+ // TODO: Should annotate loads of the group size for this to do anything
933+ // useful.
934+ return A.manifestAttrs (
935+ getIRPosition (),
936+ {Attribute::get (Ctx, " amdgpu-max-num-workgroups" , OS.str ())},
937+ /* ForceReplace= */ true );
938+ }
939+
940+ const std::string getName () const override { return " AAAMDMaxNumWorkgroups" ; }
941+
942+ const std::string getAsStr (Attributor *) const override {
943+ std::string Buffer = " AAAMDMaxNumWorkgroupsState[" ;
944+ raw_string_ostream OS (Buffer);
945+ OS << X.getAssumed () << ' ,' << Y.getAssumed () << ' ,' << Z.getAssumed ()
946+ << ' ]' ;
947+ return OS.str ();
948+ }
949+
950+ const char *getIdAddr () const override { return &ID; }
951+
952+ // / This function should return true if the type of the \p AA is
953+ // / AAAMDMaxNumWorkgroups
954+ static bool classof (const AbstractAttribute *AA) {
955+ return (AA->getIdAddr () == &ID);
956+ }
957+
958+ void trackStatistics () const override {}
959+
960+ // / Unique ID (due to the unique address)
961+ static const char ID;
962+ };
963+
964+ const char AAAMDMaxNumWorkgroups::ID = 0 ;
965+
966+ AAAMDMaxNumWorkgroups &
967+ AAAMDMaxNumWorkgroups::createForPosition (const IRPosition &IRP, Attributor &A) {
968+ if (IRP.getPositionKind () == IRPosition::IRP_FUNCTION)
969+ return *new (A.Allocator ) AAAMDMaxNumWorkgroups (IRP, A);
970+ llvm_unreachable (" AAAMDMaxNumWorkgroups is only valid for function position" );
971+ }
972+
824973// / Propagate amdgpu-waves-per-eu attribute.
825974struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
826975 AAAMDWavesPerEU (const IRPosition &IRP, Attributor &A)
@@ -1043,8 +1192,8 @@ static bool runImpl(Module &M, AnalysisGetter &AG, TargetMachine &TM,
10431192 DenseSet<const char *> Allowed (
10441193 {&AAAMDAttributes::ID, &AAUniformWorkGroupSize::ID,
10451194 &AAPotentialValues::ID, &AAAMDFlatWorkGroupSize::ID,
1046- &AAAMDWavesPerEU ::ID, &AAAMDGPUNoAGPR ::ID, &AACallEdges ::ID,
1047- &AAPointerInfo::ID, &AAPotentialConstantValues::ID,
1195+ &AAAMDMaxNumWorkgroups ::ID, &AAAMDWavesPerEU ::ID, &AAAMDGPUNoAGPR ::ID,
1196+ &AACallEdges::ID, & AAPointerInfo::ID, &AAPotentialConstantValues::ID,
10481197 &AAUnderlyingObjects::ID, &AAAddressSpace::ID, &AAIndirectCallInfo::ID,
10491198 &AAInstanceInfo::ID});
10501199
@@ -1068,6 +1217,7 @@ static bool runImpl(Module &M, AnalysisGetter &AG, TargetMachine &TM,
10681217 for (auto *F : Functions) {
10691218 A.getOrCreateAAFor <AAAMDAttributes>(IRPosition::function (*F));
10701219 A.getOrCreateAAFor <AAUniformWorkGroupSize>(IRPosition::function (*F));
1220+ A.getOrCreateAAFor <AAAMDMaxNumWorkgroups>(IRPosition::function (*F));
10711221 A.getOrCreateAAFor <AAAMDGPUNoAGPR>(IRPosition::function (*F));
10721222 CallingConv::ID CC = F->getCallingConv ();
10731223 if (!AMDGPU::isEntryFunctionCC (CC)) {
0 commit comments