@@ -1833,6 +1833,8 @@ struct OpenMPOpt {
1833
1833
return Changed == ChangeStatus::CHANGED;
1834
1834
}
1835
1835
1836
+ void registerFoldRuntimeCall (RuntimeFunction RF);
1837
+
1836
1838
// / Populate the Attributor with abstract attribute opportunities in the
1837
1839
// / function.
1838
1840
void registerAAs (bool IsModulePass);
@@ -3506,6 +3508,8 @@ struct AAKernelInfoCallSite : AAKernelInfo {
3506
3508
case OMPRTL___kmpc_is_spmd_exec_mode:
3507
3509
case OMPRTL___kmpc_for_static_fini:
3508
3510
case OMPRTL___kmpc_global_thread_num:
3511
+ case OMPRTL___kmpc_get_hardware_num_threads_in_block:
3512
+ case OMPRTL___kmpc_get_hardware_num_blocks:
3509
3513
case OMPRTL___kmpc_single:
3510
3514
case OMPRTL___kmpc_end_single:
3511
3515
case OMPRTL___kmpc_master:
@@ -3710,7 +3714,6 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
3710
3714
3711
3715
ChangeStatus updateImpl (Attributor &A) override {
3712
3716
ChangeStatus Changed = ChangeStatus::UNCHANGED;
3713
-
3714
3717
switch (RFKind) {
3715
3718
case OMPRTL___kmpc_is_spmd_exec_mode:
3716
3719
Changed |= foldIsSPMDExecMode (A);
@@ -3721,6 +3724,12 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
3721
3724
case OMPRTL___kmpc_parallel_level:
3722
3725
Changed |= foldParallelLevel (A);
3723
3726
break ;
3727
+ case OMPRTL___kmpc_get_hardware_num_threads_in_block:
3728
+ Changed = Changed | foldKernelFnAttribute (A, " omp_target_thread_limit" );
3729
+ break ;
3730
+ case OMPRTL___kmpc_get_hardware_num_blocks:
3731
+ Changed = Changed | foldKernelFnAttribute (A, " omp_target_num_teams" );
3732
+ break ;
3724
3733
default :
3725
3734
llvm_unreachable (" Unhandled OpenMP runtime function!" );
3726
3735
}
@@ -3892,7 +3901,39 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
3892
3901
" Expected only non-SPMD kernels!" );
3893
3902
SimplifiedValue = ConstantInt::get (Type::getInt8Ty (Ctx), 0 );
3894
3903
}
3904
+ return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
3905
+ : ChangeStatus::CHANGED;
3906
+ }
3907
+
3908
+ ChangeStatus foldKernelFnAttribute (Attributor &A, llvm::StringRef Attr) {
3909
+ // Specialize only if all the calls agree with the attribute constant value
3910
+ int32_t CurrentAttrValue = -1 ;
3911
+ Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
3912
+
3913
+ auto &CallerKernelInfoAA = A.getAAFor <AAKernelInfo>(
3914
+ *this , IRPosition::function (*getAnchorScope ()), DepClassTy::REQUIRED);
3895
3915
3916
+ if (!CallerKernelInfoAA.ReachingKernelEntries .isValidState ())
3917
+ return indicatePessimisticFixpoint ();
3918
+
3919
+ // Iterate over the kernels that reach this function
3920
+ for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries ) {
3921
+ int32_t NextAttrVal = -1 ;
3922
+ if (K->hasFnAttribute (Attr))
3923
+ NextAttrVal =
3924
+ std::stoi (K->getFnAttribute (Attr).getValueAsString ().str ());
3925
+
3926
+ if (NextAttrVal == -1 ||
3927
+ (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
3928
+ return indicatePessimisticFixpoint ();
3929
+ CurrentAttrValue = NextAttrVal;
3930
+ }
3931
+
3932
+ if (CurrentAttrValue != -1 ) {
3933
+ auto &Ctx = getAnchorValue ().getContext ();
3934
+ SimplifiedValue =
3935
+ ConstantInt::get (Type::getInt32Ty (Ctx), CurrentAttrValue);
3936
+ }
3896
3937
return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
3897
3938
: ChangeStatus::CHANGED;
3898
3939
}
@@ -3908,6 +3949,21 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
3908
3949
3909
3950
} // namespace
3910
3951
3952
+ // / Register folding callsite
3953
+ void OpenMPOpt::registerFoldRuntimeCall (RuntimeFunction RF) {
3954
+ auto &RFI = OMPInfoCache.RFIs [RF];
3955
+ RFI.foreachUse (SCC, [&](Use &U, Function &F) {
3956
+ CallInst *CI = OpenMPOpt::getCallIfRegularCall (U, &RFI);
3957
+ if (!CI)
3958
+ return false ;
3959
+ A.getOrCreateAAFor <AAFoldRuntimeCall>(
3960
+ IRPosition::callsite_returned (*CI), /* QueryingAA */ nullptr ,
3961
+ DepClassTy::NONE, /* ForceUpdate */ false ,
3962
+ /* UpdateAfterInit */ false );
3963
+ return false ;
3964
+ });
3965
+ }
3966
+
3911
3967
void OpenMPOpt::registerAAs (bool IsModulePass) {
3912
3968
if (SCC.empty ())
3913
3969
@@ -3923,43 +3979,12 @@ void OpenMPOpt::registerAAs(bool IsModulePass) {
3923
3979
DepClassTy::NONE, /* ForceUpdate */ false ,
3924
3980
/* UpdateAfterInit */ false );
3925
3981
3926
- auto &IsMainRFI =
3927
- OMPInfoCache.RFIs [OMPRTL___kmpc_is_generic_main_thread_id];
3928
- IsMainRFI.foreachUse (SCC, [&](Use &U, Function &F) {
3929
- CallInst *CI = OpenMPOpt::getCallIfRegularCall (U, &IsMainRFI);
3930
- if (!CI)
3931
- return false ;
3932
- A.getOrCreateAAFor <AAFoldRuntimeCall>(
3933
- IRPosition::callsite_returned (*CI), /* QueryingAA */ nullptr ,
3934
- DepClassTy::NONE, /* ForceUpdate */ false ,
3935
- /* UpdateAfterInit */ false );
3936
- return false ;
3937
- });
3938
3982
3939
- auto &IsSPMDRFI = OMPInfoCache.RFIs [OMPRTL___kmpc_is_spmd_exec_mode];
3940
- IsSPMDRFI.foreachUse (SCC, [&](Use &U, Function &) {
3941
- CallInst *CI = OpenMPOpt::getCallIfRegularCall (U, &IsSPMDRFI);
3942
- if (!CI)
3943
- return false ;
3944
- A.getOrCreateAAFor <AAFoldRuntimeCall>(
3945
- IRPosition::callsite_returned (*CI), /* QueryingAA */ nullptr ,
3946
- DepClassTy::NONE, /* ForceUpdate */ false ,
3947
- /* UpdateAfterInit */ false );
3948
- return false ;
3949
- });
3950
-
3951
- auto &ParallelLevelRFI = OMPInfoCache.RFIs [OMPRTL___kmpc_parallel_level];
3952
- ParallelLevelRFI.foreachUse (SCC, [&](Use &U, Function &) {
3953
- CallInst *CI = OpenMPOpt::getCallIfRegularCall (U, &ParallelLevelRFI);
3954
- if (!CI)
3955
- return false ;
3956
- A.getOrCreateAAFor <AAFoldRuntimeCall>(
3957
- IRPosition::callsite_returned (*CI), /* QueryingAA */ nullptr ,
3958
- DepClassTy::NONE, /* ForceUpdate */ false ,
3959
- /* UpdateAfterInit */ false );
3960
-
3961
- return false ;
3962
- });
3983
+ registerFoldRuntimeCall (OMPRTL___kmpc_is_generic_main_thread_id);
3984
+ registerFoldRuntimeCall (OMPRTL___kmpc_is_spmd_exec_mode);
3985
+ registerFoldRuntimeCall (OMPRTL___kmpc_parallel_level);
3986
+ registerFoldRuntimeCall (OMPRTL___kmpc_get_hardware_num_threads_in_block);
3987
+ registerFoldRuntimeCall (OMPRTL___kmpc_get_hardware_num_blocks);
3963
3988
}
3964
3989
3965
3990
// Create CallSite AA for all Getters.
0 commit comments