@@ -179,6 +179,11 @@ class AMDGPUInformationCache : public InformationCache {
179179 return {ST.getMinFlatWorkGroupSize (), ST.getMaxFlatWorkGroupSize ()};
180180 }
181181
182+ SmallVector<unsigned > getMaxNumWorkGroups (const Function &F) {
183+ const GCNSubtarget &ST = TM.getSubtarget <GCNSubtarget>(F);
184+ return ST.getMaxNumWorkGroups (F);
185+ }
186+
182187 // / Get code object version.
183188 unsigned getCodeObjectVersion () const { return CodeObjectVersion; }
184189
@@ -821,6 +826,145 @@ AAAMDFlatWorkGroupSize::createForPosition(const IRPosition &IRP,
821826 " AAAMDFlatWorkGroupSize is only valid for function position" );
822827}
823828
829+ struct TupleDecIntegerRangeState : public AbstractState {
830+ DecIntegerState<uint32_t > X, Y, Z;
831+
832+ bool isValidState () const override {
833+ return X.isValidState () && Y.isValidState () && Z.isValidState ();
834+ }
835+
836+ bool isAtFixpoint () const override {
837+ return X.isAtFixpoint () && Y.isAtFixpoint () && Z.isAtFixpoint ();
838+ }
839+
840+ ChangeStatus indicateOptimisticFixpoint () override {
841+ return X.indicateOptimisticFixpoint () | Y.indicateOptimisticFixpoint () |
842+ Z.indicateOptimisticFixpoint ();
843+ }
844+
845+ ChangeStatus indicatePessimisticFixpoint () override {
846+ return X.indicatePessimisticFixpoint () | Y.indicatePessimisticFixpoint () |
847+ Z.indicatePessimisticFixpoint ();
848+ }
849+
850+ TupleDecIntegerRangeState operator ^=(const TupleDecIntegerRangeState &Other) {
851+ X ^= Other.X ;
852+ Y ^= Other.Y ;
853+ Z ^= Other.Z ;
854+ return *this ;
855+ }
856+
857+ bool operator ==(const TupleDecIntegerRangeState &Other) const {
858+ return X == Other.X && Y == Other.Y && Z == Other.Z ;
859+ }
860+
861+ TupleDecIntegerRangeState &getAssumed () { return *this ; }
862+ const TupleDecIntegerRangeState &getAssumed () const { return *this ; }
863+ };
864+
865+ using AAAMDMaxNumWorkgroupsState =
866+ StateWrapper<TupleDecIntegerRangeState, AbstractAttribute, uint32_t >;
867+
868+ // / Propagate amdgpu-max-num-workgroups attribute.
869+ struct AAAMDMaxNumWorkgroups
870+ : public StateWrapper<TupleDecIntegerRangeState, AbstractAttribute> {
871+ using Base = StateWrapper<TupleDecIntegerRangeState, AbstractAttribute>;
872+
873+ AAAMDMaxNumWorkgroups (const IRPosition &IRP, Attributor &A) : Base(IRP) {}
874+
875+ void initialize (Attributor &A) override {
876+ Function *F = getAssociatedFunction ();
877+ auto &InfoCache = static_cast <AMDGPUInformationCache &>(A.getInfoCache ());
878+
879+ SmallVector<unsigned > MaxNumWorkgroups = InfoCache.getMaxNumWorkGroups (*F);
880+
881+ X.takeKnownMinimum (MaxNumWorkgroups[0 ]);
882+ Y.takeKnownMinimum (MaxNumWorkgroups[1 ]);
883+ Z.takeKnownMinimum (MaxNumWorkgroups[2 ]);
884+
885+ if (AMDGPU::isEntryFunctionCC (F->getCallingConv ()))
886+ indicatePessimisticFixpoint ();
887+ }
888+
889+ ChangeStatus updateImpl (Attributor &A) override {
890+ ChangeStatus Change = ChangeStatus::UNCHANGED;
891+
892+ auto CheckCallSite = [&](AbstractCallSite CS) {
893+ Function *Caller = CS.getInstruction ()->getFunction ();
894+ LLVM_DEBUG (dbgs () << " [AAAMDMaxNumWorkgroups] Call " << Caller->getName ()
895+ << " ->" << getAssociatedFunction ()->getName () << ' \n ' );
896+
897+ const auto *CallerInfo = A.getAAFor <AAAMDMaxNumWorkgroups>(
898+ *this , IRPosition::function (*Caller), DepClassTy::REQUIRED);
899+ if (!CallerInfo || !CallerInfo->isValidState ())
900+ return false ;
901+
902+ Change |=
903+ clampStateAndIndicateChange (this ->getState (), CallerInfo->getState ());
904+ return true ;
905+ };
906+
907+ bool AllCallSitesKnown = true ;
908+ if (!A.checkForAllCallSites (CheckCallSite, *this ,
909+ /* RequireAllCallSites=*/ true ,
910+ AllCallSitesKnown))
911+ return indicatePessimisticFixpoint ();
912+
913+ return Change;
914+ }
915+
916+ // / Create an abstract attribute view for the position \p IRP.
917+ static AAAMDMaxNumWorkgroups &createForPosition (const IRPosition &IRP,
918+ Attributor &A);
919+
920+ ChangeStatus manifest (Attributor &A) override {
921+ Function *F = getAssociatedFunction ();
922+ LLVMContext &Ctx = F->getContext ();
923+ SmallString<32 > Buffer;
924+ raw_svector_ostream OS (Buffer);
925+ OS << X.getAssumed () << ' ,' << Y.getAssumed () << ' ,' << Z.getAssumed ();
926+
927+ // TODO: Should annotate loads of the group size for this to do anything
928+ // useful.
929+ return A.manifestAttrs (
930+ getIRPosition (),
931+ {Attribute::get (Ctx, " amdgpu-max-num-workgroups" , OS.str ())},
932+ /* ForceReplace= */ true );
933+ }
934+
935+ const std::string getName () const override { return " AAAMDMaxNumWorkgroups" ; }
936+
937+ const std::string getAsStr (Attributor *) const override {
938+ std::string Buffer = " AAAMDMaxNumWorkgroupsState[" ;
939+ raw_string_ostream OS (Buffer);
940+ OS << X.getAssumed () << ' ,' << Y.getAssumed () << ' ,' << Z.getAssumed ()
941+ << ' ]' ;
942+ return OS.str ();
943+ }
944+
945+ const char *getIdAddr () const override { return &ID; }
946+
947+ // / This function should return true if the type of the \p AA is
948+ // / AAAMDMaxNumWorkgroups
949+ static bool classof (const AbstractAttribute *AA) {
950+ return (AA->getIdAddr () == &ID);
951+ }
952+
953+ void trackStatistics () const override {}
954+
955+ // / Unique ID (due to the unique address)
956+ static const char ID;
957+ };
958+
959+ const char AAAMDMaxNumWorkgroups::ID = 0 ;
960+
961+ AAAMDMaxNumWorkgroups &
962+ AAAMDMaxNumWorkgroups::createForPosition (const IRPosition &IRP, Attributor &A) {
963+ if (IRP.getPositionKind () == IRPosition::IRP_FUNCTION)
964+ return *new (A.Allocator ) AAAMDMaxNumWorkgroups (IRP, A);
965+ llvm_unreachable (" AAAMDMaxNumWorkgroups is only valid for function position" );
966+ }
967+
824968// / Propagate amdgpu-waves-per-eu attribute.
825969struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
826970 AAAMDWavesPerEU (const IRPosition &IRP, Attributor &A)
@@ -1046,8 +1190,8 @@ static bool runImpl(Module &M, AnalysisGetter &AG, TargetMachine &TM,
10461190 DenseSet<const char *> Allowed (
10471191 {&AAAMDAttributes::ID, &AAUniformWorkGroupSize::ID,
10481192 &AAPotentialValues::ID, &AAAMDFlatWorkGroupSize::ID,
1049- &AAAMDWavesPerEU ::ID, &AAAMDGPUNoAGPR ::ID, &AACallEdges ::ID,
1050- &AAPointerInfo::ID, &AAPotentialConstantValues::ID,
1193+ &AAAMDMaxNumWorkgroups ::ID, &AAAMDWavesPerEU ::ID, &AAAMDGPUNoAGPR ::ID,
1194+ &AACallEdges::ID, & AAPointerInfo::ID, &AAPotentialConstantValues::ID,
10511195 &AAUnderlyingObjects::ID, &AAAddressSpace::ID, &AAIndirectCallInfo::ID,
10521196 &AAInstanceInfo::ID});
10531197
@@ -1071,6 +1215,7 @@ static bool runImpl(Module &M, AnalysisGetter &AG, TargetMachine &TM,
10711215 for (auto *F : Functions) {
10721216 A.getOrCreateAAFor <AAAMDAttributes>(IRPosition::function (*F));
10731217 A.getOrCreateAAFor <AAUniformWorkGroupSize>(IRPosition::function (*F));
1218+ A.getOrCreateAAFor <AAAMDMaxNumWorkgroups>(IRPosition::function (*F));
10741219 A.getOrCreateAAFor <AAAMDGPUNoAGPR>(IRPosition::function (*F));
10751220 CallingConv::ID CC = F->getCallingConv ();
10761221 if (!AMDGPU::isEntryFunctionCC (CC)) {
0 commit comments