Skip to content

Commit 5ab6aed

Browse files
Jose M Monsalve Diazshiltian
authored andcommitted
[OpenMP] Folding threadLimit and numThreads when single value in kernels
The device runtime contains several calls to `__kmpc_get_hardware_num_threads_in_block` and `__kmpc_get_hardware_num_blocks`. If the thread_limit and the num_teams are constant, these calls can be folded to the constant value. In this patch we use the already introduced `AAFoldRuntimeCall` and the `NumTeams` and `NumThreads` kernel attributes (to be introduced in a different patch) to fold these functions. The code checks all the kernels, and if their attributes match, the functions are folded. In the future we will explore specializing for multiple values of NumThreads and NumTeams. Depends on D106390 Reviewed By: jdoerfert, JonChesterfield Differential Revision: https://reviews.llvm.org/D106033
1 parent 4819b75 commit 5ab6aed

File tree

4 files changed

+198
-39
lines changed

4 files changed

+198
-39
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPKinds.def

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,9 @@ __OMP_RTL(__kmpc_omp_reg_task_with_affinity, false, Int32, IdentPtr, Int32,
206206
/* kmp_task_t */ VoidPtr, Int32,
207207
/* kmp_task_affinity_info_t */ VoidPtr)
208208

209+
__OMP_RTL(__kmpc_get_hardware_num_blocks, false, Int32, )
210+
__OMP_RTL(__kmpc_get_hardware_num_threads_in_block, false, Int32, )
211+
209212
__OMP_RTL(omp_get_thread_num, false, Int32, )
210213
__OMP_RTL(omp_get_num_threads, false, Int32, )
211214
__OMP_RTL(omp_get_max_threads, false, Int32, )
@@ -601,6 +604,9 @@ __OMP_RTL_ATTRS(__kmpc_omp_reg_task_with_affinity, DefaultAttrs, AttributeSet(),
601604
ParamAttrs(ReadOnlyPtrAttrs, AttributeSet(), ReadOnlyPtrAttrs,
602605
AttributeSet(), ReadOnlyPtrAttrs))
603606

607+
__OMP_RTL_ATTRS(__kmpc_get_hardware_num_blocks, GetterAttrs, AttributeSet(), ParamAttrs())
608+
__OMP_RTL_ATTRS(__kmpc_get_hardware_num_threads_in_block, GetterAttrs, AttributeSet(), ParamAttrs())
609+
604610
__OMP_RTL_ATTRS(omp_get_thread_num, GetterAttrs, AttributeSet(), ParamAttrs())
605611
__OMP_RTL_ATTRS(omp_get_num_threads, GetterAttrs, AttributeSet(), ParamAttrs())
606612
__OMP_RTL_ATTRS(omp_get_max_threads, GetterAttrs, AttributeSet(), ParamAttrs())

llvm/lib/Transforms/IPO/OpenMPOpt.cpp

Lines changed: 62 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1833,6 +1833,8 @@ struct OpenMPOpt {
18331833
return Changed == ChangeStatus::CHANGED;
18341834
}
18351835

1836+
void registerFoldRuntimeCall(RuntimeFunction RF);
1837+
18361838
/// Populate the Attributor with abstract attribute opportunities in the
18371839
/// function.
18381840
void registerAAs(bool IsModulePass);
@@ -3506,6 +3508,8 @@ struct AAKernelInfoCallSite : AAKernelInfo {
35063508
case OMPRTL___kmpc_is_spmd_exec_mode:
35073509
case OMPRTL___kmpc_for_static_fini:
35083510
case OMPRTL___kmpc_global_thread_num:
3511+
case OMPRTL___kmpc_get_hardware_num_threads_in_block:
3512+
case OMPRTL___kmpc_get_hardware_num_blocks:
35093513
case OMPRTL___kmpc_single:
35103514
case OMPRTL___kmpc_end_single:
35113515
case OMPRTL___kmpc_master:
@@ -3710,7 +3714,6 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
37103714

37113715
ChangeStatus updateImpl(Attributor &A) override {
37123716
ChangeStatus Changed = ChangeStatus::UNCHANGED;
3713-
37143717
switch (RFKind) {
37153718
case OMPRTL___kmpc_is_spmd_exec_mode:
37163719
Changed |= foldIsSPMDExecMode(A);
@@ -3721,6 +3724,12 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
37213724
case OMPRTL___kmpc_parallel_level:
37223725
Changed |= foldParallelLevel(A);
37233726
break;
3727+
case OMPRTL___kmpc_get_hardware_num_threads_in_block:
3728+
Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit");
3729+
break;
3730+
case OMPRTL___kmpc_get_hardware_num_blocks:
3731+
Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams");
3732+
break;
37243733
default:
37253734
llvm_unreachable("Unhandled OpenMP runtime function!");
37263735
}
@@ -3892,7 +3901,39 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
38923901
"Expected only non-SPMD kernels!");
38933902
SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
38943903
}
3904+
return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
3905+
: ChangeStatus::CHANGED;
3906+
}
3907+
3908+
ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
3909+
// Specialize only if all the calls agree with the attribute constant value
3910+
int32_t CurrentAttrValue = -1;
3911+
Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
3912+
3913+
auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
3914+
*this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
38953915

3916+
if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
3917+
return indicatePessimisticFixpoint();
3918+
3919+
// Iterate over the kernels that reach this function
3920+
for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
3921+
int32_t NextAttrVal = -1;
3922+
if (K->hasFnAttribute(Attr))
3923+
NextAttrVal =
3924+
std::stoi(K->getFnAttribute(Attr).getValueAsString().str());
3925+
3926+
if (NextAttrVal == -1 ||
3927+
(CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
3928+
return indicatePessimisticFixpoint();
3929+
CurrentAttrValue = NextAttrVal;
3930+
}
3931+
3932+
if (CurrentAttrValue != -1) {
3933+
auto &Ctx = getAnchorValue().getContext();
3934+
SimplifiedValue =
3935+
ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
3936+
}
38963937
return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
38973938
: ChangeStatus::CHANGED;
38983939
}
@@ -3908,6 +3949,21 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
39083949

39093950
} // namespace
39103951

3952+
/// Register folding callsite
3953+
void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
3954+
auto &RFI = OMPInfoCache.RFIs[RF];
3955+
RFI.foreachUse(SCC, [&](Use &U, Function &F) {
3956+
CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
3957+
if (!CI)
3958+
return false;
3959+
A.getOrCreateAAFor<AAFoldRuntimeCall>(
3960+
IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
3961+
DepClassTy::NONE, /* ForceUpdate */ false,
3962+
/* UpdateAfterInit */ false);
3963+
return false;
3964+
});
3965+
}
3966+
39113967
void OpenMPOpt::registerAAs(bool IsModulePass) {
39123968
if (SCC.empty())
39133969

@@ -3923,43 +3979,12 @@ void OpenMPOpt::registerAAs(bool IsModulePass) {
39233979
DepClassTy::NONE, /* ForceUpdate */ false,
39243980
/* UpdateAfterInit */ false);
39253981

3926-
auto &IsMainRFI =
3927-
OMPInfoCache.RFIs[OMPRTL___kmpc_is_generic_main_thread_id];
3928-
IsMainRFI.foreachUse(SCC, [&](Use &U, Function &F) {
3929-
CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &IsMainRFI);
3930-
if (!CI)
3931-
return false;
3932-
A.getOrCreateAAFor<AAFoldRuntimeCall>(
3933-
IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
3934-
DepClassTy::NONE, /* ForceUpdate */ false,
3935-
/* UpdateAfterInit */ false);
3936-
return false;
3937-
});
39383982

3939-
auto &IsSPMDRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_is_spmd_exec_mode];
3940-
IsSPMDRFI.foreachUse(SCC, [&](Use &U, Function &) {
3941-
CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &IsSPMDRFI);
3942-
if (!CI)
3943-
return false;
3944-
A.getOrCreateAAFor<AAFoldRuntimeCall>(
3945-
IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
3946-
DepClassTy::NONE, /* ForceUpdate */ false,
3947-
/* UpdateAfterInit */ false);
3948-
return false;
3949-
});
3950-
3951-
auto &ParallelLevelRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_level];
3952-
ParallelLevelRFI.foreachUse(SCC, [&](Use &U, Function &) {
3953-
CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &ParallelLevelRFI);
3954-
if (!CI)
3955-
return false;
3956-
A.getOrCreateAAFor<AAFoldRuntimeCall>(
3957-
IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
3958-
DepClassTy::NONE, /* ForceUpdate */ false,
3959-
/* UpdateAfterInit */ false);
3960-
3961-
return false;
3962-
});
3983+
registerFoldRuntimeCall(OMPRTL___kmpc_is_generic_main_thread_id);
3984+
registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
3985+
registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
3986+
registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
3987+
registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
39633988
}
39643989

39653990
// Create CallSite AA for all Getters.
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --function-signature --check-globals
2+
; RUN: opt -S -passes=openmp-opt < %s | FileCheck %s
3+
target triple = "nvptx64"
4+
5+
%struct.ident_t = type { i32, i32, i32, i32, i8* }
6+
7+
@kernel0_exec_mode = weak constant i8 1
8+
9+
@G = external global i32
10+
;.
11+
; CHECK: @[[G:[a-zA-Z0-9_$"\\.-]+]] = external global i32
12+
;.
13+
define weak void @kernel0() #0 {
14+
; CHECK-LABEL: define {{[^@]+}}@kernel0()
15+
; CHECK: #[[ATTR0:[0-9]+]] {
16+
; CHECK-NEXT: [[I:%.*]] = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 true, i1 false, i1 false)
17+
; CHECK-NEXT: call void @helper0()
18+
; CHECK-NEXT: call void @helper1()
19+
; CHECK-NEXT: call void @helper2()
20+
; CHECK-NEXT: call void @__kmpc_target_deinit(%struct.ident_t* null, i1 true, i1 false)
21+
; CHECK-NEXT: ret void
22+
;
23+
%i = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 true, i1 false, i1 false)
24+
call void @helper0()
25+
call void @helper1()
26+
call void @helper2()
27+
call void @__kmpc_target_deinit(%struct.ident_t* null, i1 true, i1 false)
28+
ret void
29+
}
30+
31+
@kernel1_exec_mode = weak constant i8 1
32+
33+
define weak void @kernel1() #0 {
34+
; CHECK-LABEL: define {{[^@]+}}@kernel1()
35+
; CHECK: #[[ATTR0]] {
36+
; CHECK-NEXT: [[I:%.*]] = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 true, i1 false, i1 false)
37+
; CHECK-NEXT: call void @helper1()
38+
; CHECK-NEXT: call void @__kmpc_target_deinit(%struct.ident_t* null, i1 false, i1 false)
39+
; CHECK-NEXT: ret void
40+
;
41+
%i = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 true, i1 false, i1 false)
42+
call void @helper1()
43+
call void @__kmpc_target_deinit(%struct.ident_t* null, i1 false, i1 false)
44+
ret void
45+
}
46+
47+
@kernel2_exec_mode = weak constant i8 1
48+
49+
define weak void @kernel2() #0 {
50+
; CHECK-LABEL: define {{[^@]+}}@kernel2()
51+
; CHECK: #[[ATTR0]] {
52+
; CHECK-NEXT: [[I:%.*]] = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 false, i1 false, i1 false)
53+
; CHECK-NEXT: call void @helper0()
54+
; CHECK-NEXT: call void @helper1()
55+
; CHECK-NEXT: call void @helper2()
56+
; CHECK-NEXT: call void @__kmpc_target_deinit(%struct.ident_t* null, i1 false, i1 false)
57+
; CHECK-NEXT: ret void
58+
;
59+
%i = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 false, i1 false, i1 false)
60+
call void @helper0()
61+
call void @helper1()
62+
call void @helper2()
63+
call void @__kmpc_target_deinit(%struct.ident_t* null, i1 false, i1 false)
64+
ret void
65+
}
66+
67+
define internal void @helper0() {
68+
; CHECK-LABEL: define {{[^@]+}}@helper0() {{#[0-9]+}} {
69+
; CHECK-NEXT: store i32 666, i32* @G, align 4
70+
; CHECK-NEXT: ret void
71+
;
72+
%threadLimit = call i32 @__kmpc_get_hardware_num_threads_in_block()
73+
store i32 %threadLimit, i32* @G
74+
ret void
75+
}
76+
77+
define internal void @helper1() {
78+
; CHECK-LABEL: define {{[^@]+}}@helper1() {{#[0-9]+}} {
79+
; CHECK-NEXT: br label [[F:%.*]]
80+
; CHECK: t:
81+
; CHECK-NEXT: unreachable
82+
; CHECK: f:
83+
; CHECK-NEXT: ret void
84+
;
85+
%threadLimit = call i32 @__kmpc_get_hardware_num_threads_in_block()
86+
%c = icmp eq i32 %threadLimit, 666
87+
br i1 %c, label %f, label %t
88+
t:
89+
call void @helper0()
90+
ret void
91+
f:
92+
ret void
93+
}
94+
95+
define internal void @helper2() {
96+
; CHECK-LABEL: define {{[^@]+}}@helper2() {{#[0-9]+}} {
97+
; CHECK-NEXT: store i32 666, i32* @G
98+
; CHECK-NEXT: ret void
99+
;
100+
%threadLimit = call i32 @__kmpc_get_hardware_num_threads_in_block()
101+
store i32 %threadLimit, i32* @G
102+
ret void
103+
}
104+
105+
declare i32 @__kmpc_get_hardware_num_threads_in_block()
106+
declare i32 @__kmpc_target_init(%struct.ident_t*, i1 zeroext, i1 zeroext, i1 zeroext) #1
107+
declare void @__kmpc_target_deinit(%struct.ident_t* nocapture readnone, i1 zeroext, i1 zeroext) #1
108+
109+
110+
!llvm.module.flags = !{!0, !1}
111+
!nvvm.annotations = !{!2, !3, !4}
112+
113+
attributes #0 = { "omp_target_thread_limit"="666" "omp_target_num_teams"="777"}
114+
115+
!0 = !{i32 7, !"openmp", i32 50}
116+
!1 = !{i32 7, !"openmp-device", i32 50}
117+
!2 = !{void ()* @kernel0, !"kernel", i32 1}
118+
!3 = !{void ()* @kernel1, !"kernel", i32 1}
119+
!4 = !{void ()* @kernel2, !"kernel", i32 1}
120+
;.
121+
; CHECK: attributes #[[ATTR0]] = { "omp_target_num_teams"="777" "omp_target_thread_limit"="666" }
122+
;
123+
; CHECK: [[META0:![0-9]+]] = !{i32 7, !"openmp", i32 50}
124+
; CHECK: [[META1:![0-9]+]] = !{i32 7, !"openmp-device", i32 50}
125+
; CHECK: [[META2:![0-9]+]] = !{void ()* @kernel0, !"kernel", i32 1}
126+
; CHECK: [[META3:![0-9]+]] = !{void ()* @kernel1, !"kernel", i32 1}
127+
; CHECK: [[META4:![0-9]+]] = !{void ()* @kernel2, !"kernel", i32 1}
128+
;.

openmp/libomptarget/deviceRTLs/target_interface.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
// Calls to the NVPTX layer (assuming 1D layout)
1919
EXTERN int __kmpc_get_hardware_thread_id_in_block();
2020
EXTERN int GetBlockIdInKernel();
21-
EXTERN int __kmpc_get_hardware_num_blocks();
22-
EXTERN int __kmpc_get_hardware_num_threads_in_block();
21+
EXTERN NOINLINE int __kmpc_get_hardware_num_blocks();
22+
EXTERN NOINLINE int __kmpc_get_hardware_num_threads_in_block();
2323
EXTERN unsigned GetWarpId();
2424
EXTERN unsigned GetWarpSize();
2525
EXTERN unsigned GetLaneId();

0 commit comments

Comments
 (0)