Skip to content

Commit 2ab02b6

Browse files
authored
[flang][OpenMP] Support multi-block reduction combiner regions on the GPU (#156837)
Fixes a bug related to insertion points when inlining multi-block reduction combiner regions. The insertion point (IP) at the end of the inlined region was not used, which resulted in emitting basic blocks (BBs) with multiple terminators. PR stack: - #155754 - #155987 - #155992 - #155993 - #157638 - #156610 - #156837 ◀️
1 parent d34f738 commit 2ab02b6

File tree

2 files changed

+90
-0
lines changed

2 files changed

+90
-0
lines changed

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3507,6 +3507,8 @@ Expected<Function *> OpenMPIRBuilder::createReductionFunction(
35073507
return AfterIP.takeError();
35083508
if (!Builder.GetInsertBlock())
35093509
return ReductionFunc;
3510+
3511+
Builder.restoreIP(*AfterIP);
35103512
Builder.CreateStore(Reduced, LHSPtr);
35113513
}
35123514
}
@@ -3751,6 +3753,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
37513753
RI.ReductionGen(Builder.saveIP(), RHSValue, LHSValue, Reduced);
37523754
if (!AfterIP)
37533755
return AfterIP.takeError();
3756+
Builder.restoreIP(*AfterIP);
37543757
Builder.CreateStore(Reduced, LHS, false);
37553758
}
37563759
}
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
// Verifies that the IR builder can handle reductions with multi-block combiner
4+
// regions on the GPU.
5+
6+
module attributes {dlti.dl_spec = #dlti.dl_spec<"dlti.alloca_memory_space" = 5 : ui64, "dlti.global_memory_space" = 1 : ui64>, llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} {
7+
llvm.func @bar() {}
8+
llvm.func @baz() {}
9+
10+
omp.declare_reduction @add_reduction_byref_box_5xf32 : !llvm.ptr alloc {
11+
%0 = llvm.mlir.constant(1 : i64) : i64
12+
%1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> : (i64) -> !llvm.ptr<5>
13+
%2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr
14+
omp.yield(%2 : !llvm.ptr)
15+
} init {
16+
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
17+
omp.yield(%arg1 : !llvm.ptr)
18+
} combiner {
19+
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
20+
llvm.call @bar() : () -> ()
21+
llvm.br ^bb3
22+
23+
^bb3: // pred: ^bb0
24+
llvm.call @baz() : () -> ()
25+
omp.yield(%arg0 : !llvm.ptr)
26+
}
27+
llvm.func @foo_() {
28+
%c1 = llvm.mlir.constant(1 : i64) : i64
29+
%10 = llvm.alloca %c1 x !llvm.array<5 x f32> {bindc_name = "x"} : (i64) -> !llvm.ptr<5>
30+
%11 = llvm.addrspacecast %10 : !llvm.ptr<5> to !llvm.ptr
31+
%74 = omp.map.info var_ptr(%11 : !llvm.ptr, !llvm.array<5 x f32>) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "x"}
32+
omp.target map_entries(%74 -> %arg0 : !llvm.ptr) {
33+
%c1_2 = llvm.mlir.constant(1 : i32) : i32
34+
%c10 = llvm.mlir.constant(10 : i32) : i32
35+
omp.teams reduction(byref @add_reduction_byref_box_5xf32 %arg0 -> %arg2 : !llvm.ptr) {
36+
omp.parallel {
37+
omp.distribute {
38+
omp.wsloop {
39+
omp.loop_nest (%arg5) : i32 = (%c1_2) to (%c10) inclusive step (%c1_2) {
40+
omp.yield
41+
}
42+
} {omp.composite}
43+
} {omp.composite}
44+
omp.terminator
45+
} {omp.composite}
46+
omp.terminator
47+
}
48+
omp.terminator
49+
}
50+
llvm.return
51+
}
52+
}
53+
54+
// CHECK: call void @__kmpc_parallel_51({{.*}}, i32 1, i32 -1, i32 -1,
55+
// CHECK-SAME: ptr @[[PAR_OUTLINED:.*]], ptr null, ptr %2, i64 1)
56+
57+
// CHECK: define internal void @[[PAR_OUTLINED]]{{.*}} {
58+
// CHECK: .omp.reduction.then:
59+
// CHECK: br label %omp.reduction.nonatomic.body
60+
61+
// CHECK: omp.reduction.nonatomic.body:
62+
// CHECK: call void @bar()
63+
// CHECK: br label %[[BODY_2ND_BB:.*]]
64+
65+
// CHECK: [[BODY_2ND_BB]]:
66+
// CHECK: call void @baz()
67+
// CHECK: br label %[[CONT_BB:.*]]
68+
69+
// CHECK: [[CONT_BB]]:
70+
// CHECK-NEXT: %[[RED_RHS:.*]] = phi ptr [ %final.rhs, %{{.*}} ]
71+
// CHECK-NEXT: store ptr %[[RED_RHS]], ptr %{{.*}}, align 8
72+
// CHECK-NEXT: br label %.omp.reduction.done
73+
// CHECK: }
74+
75+
// CHECK: define internal void @"{{.*}}$reduction$reduction_func"(ptr noundef %0, ptr noundef %1) #0 {
76+
// CHECK: br label %omp.reduction.nonatomic.body
77+
78+
// CHECK: [[BODY_2ND_BB:.*]]:
79+
// CHECK: call void @baz()
80+
// CHECK: br label %omp.region.cont
81+
82+
83+
// CHECK: omp.reduction.nonatomic.body:
84+
// CHECK: call void @bar()
85+
// CHECK: br label %[[BODY_2ND_BB]]
86+
87+
// CHECK: }

0 commit comments

Comments
 (0)