Skip to content

Commit 49715c2

Browse files
committed
[flang][OpenMP] Support multi-block reduction combiner regions on the GPU
Fixes a bug related to insertion points when inlining multi-block combiner reduction regions. The IP at the end of the inlined region was not used resulting in emitting BBs with multiple terminators.
1 parent 6987182 commit 49715c2

File tree

2 files changed

+88
-0
lines changed

2 files changed

+88
-0
lines changed

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3506,6 +3506,8 @@ Expected<Function *> OpenMPIRBuilder::createReductionFunction(
35063506
return AfterIP.takeError();
35073507
if (!Builder.GetInsertBlock())
35083508
return ReductionFunc;
3509+
3510+
Builder.SetInsertPoint(AfterIP->getBlock(), AfterIP->getPoint());
35093511
Builder.CreateStore(Reduced, LHSPtr);
35103512
}
35113513
}
@@ -3750,6 +3752,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
37503752
RI.ReductionGen(Builder.saveIP(), RHSValue, LHSValue, Reduced);
37513753
if (!AfterIP)
37523754
return AfterIP.takeError();
3755+
Builder.SetInsertPoint(AfterIP->getBlock(), AfterIP->getPoint());
37533756
Builder.CreateStore(Reduced, LHS, false);
37543757
}
37553758
}
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
// Verifies that the IR builder can handle reductions with multi-block combiner
4+
// regions on the GPU.
5+
6+
module attributes {dlti.dl_spec = #dlti.dl_spec<"dlti.alloca_memory_space" = 5 : ui64, "dlti.global_memory_space" = 1 : ui64>, llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} {
7+
llvm.func @bar() {}
8+
llvm.func @baz() {}
9+
10+
omp.declare_reduction @add_reduction_byref_box_5xf32 : !llvm.ptr alloc {
11+
%0 = llvm.mlir.constant(1 : i64) : i64
12+
%1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> : (i64) -> !llvm.ptr<5>
13+
%2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr
14+
omp.yield(%2 : !llvm.ptr)
15+
} init {
16+
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
17+
omp.yield(%arg1 : !llvm.ptr)
18+
} combiner {
19+
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
20+
llvm.call @bar() : () -> ()
21+
llvm.br ^bb3
22+
23+
^bb3: // pred: ^bb1
24+
llvm.call @baz() : () -> ()
25+
omp.yield(%arg0 : !llvm.ptr)
26+
}
27+
llvm.func @foo_() {
28+
%c1 = llvm.mlir.constant(1 : i64) : i64
29+
%10 = llvm.alloca %c1 x !llvm.array<5 x f32> {bindc_name = "x"} : (i64) -> !llvm.ptr<5>
30+
%11 = llvm.addrspacecast %10 : !llvm.ptr<5> to !llvm.ptr
31+
%74 = omp.map.info var_ptr(%11 : !llvm.ptr, !llvm.array<5 x f32>) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "x"}
32+
omp.target map_entries(%74 -> %arg0 : !llvm.ptr) {
33+
%c1_2 = llvm.mlir.constant(1 : i32) : i32
34+
%c10 = llvm.mlir.constant(10 : i32) : i32
35+
omp.teams reduction(byref @add_reduction_byref_box_5xf32 %arg0 -> %arg2 : !llvm.ptr) {
36+
omp.parallel {
37+
omp.distribute {
38+
omp.wsloop {
39+
omp.loop_nest (%arg5) : i32 = (%c1_2) to (%c10) inclusive step (%c1_2) {
40+
omp.yield
41+
}
42+
} {omp.composite}
43+
} {omp.composite}
44+
omp.terminator
45+
} {omp.composite}
46+
omp.terminator
47+
}
48+
omp.terminator
49+
}
50+
llvm.return
51+
}
52+
}
53+
54+
// CHECK: call void @__kmpc_parallel_51({{.*}}, i32 1, i32 -1, i32 -1,
55+
// CHECK-SAME: ptr @[[PAR_OUTLINED:.*]], ptr null, ptr %2, i64 1)
56+
57+
// CHECK: define internal void @[[PAR_OUTLINED]]{{.*}} {
58+
// CHECK: .omp.reduction.then:
59+
// CHECK: br label %omp.reduction.nonatomic.body
60+
61+
// CHECK: omp.reduction.nonatomic.body:
62+
// CHECK: call void @bar()
63+
// CHECK: br label %[[BODY_2ND_BB:.*]]
64+
65+
// CHECK: [[BODY_2ND_BB]]:
66+
// CHECK: call void @baz()
67+
// CHECK: br label %[[CONT_BB:.*]]
68+
69+
// CHECK: [[CONT_BB]]:
70+
// CHECK: br label %.omp.reduction.done
71+
// CHECK: }
72+
73+
// CHECK: define internal void @"{{.*}}$reduction$reduction_func"(ptr noundef %0, ptr noundef %1) #0 {
74+
// CHECK: br label %omp.reduction.nonatomic.body
75+
76+
// CHECK: [[BODY_2ND_BB:.*]]:
77+
// CHECK: call void @baz()
78+
// CHECK: br label %omp.region.cont
79+
80+
81+
// CHECK: omp.reduction.nonatomic.body:
82+
// CHECK: call void @bar()
83+
// CHECK: br label %[[BODY_2ND_BB]]
84+
85+
// CHECK: }

0 commit comments

Comments
 (0)