Skip to content

Commit 7be0443

Browse files
committed
[OpenMP][MLIR] Create LLVM IR lifetime markers for OpenMP loop-related allocations
This patch introduces `llvm.lifetime.start` and `llvm.lifetime.end` markers around the LLVM basic blocks containing the translated body of `omp.wsloop` and `omp.simdloop` operations, for all `alloca` instructions that are defined outside of that block but only ever used inside of it. This is achieved by analyzing the MLIR regions associated to the aforementioned OpenMP dialect loop operations during translation to LLVM IR. The purpose of this addition is to enable sinking these allocations into the region if it gets outlined into a separate function, avoiding the need to pass the pointer as an argument. It is a less intrusive alternative to #67010 that addresses the same problem on the interaction between redundant allocations for OpenMP loop indices and loop body outlining for target offload using new DeviceRTL functions.
1 parent d5e2cbd commit 7be0443

File tree

2 files changed

+216
-4
lines changed

2 files changed

+216
-4
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 92 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,92 @@ static llvm::BasicBlock *convertOmpOpRegions(
244244
return continuationBlock;
245245
}
246246

247+
/// Finds the set of \c llvm.alloca instructions associated to \c LLVM::AllocaOp
248+
/// MLIR operations for primitive types that are defined outside of the given
249+
/// \p region but only used inside of it.
250+
static void
251+
gatherSinkableAllocas(const LLVM::ModuleTranslation &moduleTranslation,
252+
Region &region,
253+
SetVector<llvm::AllocaInst *> &allocasToSink) {
254+
Operation *op = region.getParentOp();
255+
256+
auto processLoadStore = [&](auto loadStoreOp) {
257+
Value addr = loadStoreOp.getAddr();
258+
Operation *addrOp = addr.getDefiningOp();
259+
260+
// The destination address is already defined in this region or it is not an
261+
// llvm.alloca operation, so skip it.
262+
if (!isa_and_present<LLVM::AllocaOp>(addrOp) || op->isAncestor(addrOp))
263+
return;
264+
265+
// Get LLVM value to which the address is mapped. It has to be mapped to the
266+
// allocation instruction of a scalar type to be marked as sinkable by this
267+
// function.
268+
llvm::Value *llvmAddr = moduleTranslation.lookupValue(addr);
269+
if (!isa_and_present<llvm::AllocaInst>(llvmAddr))
270+
return;
271+
272+
auto *llvmAlloca = cast<llvm::AllocaInst>(llvmAddr);
273+
if (llvmAlloca->getAllocatedType()->getPrimitiveSizeInBits() == 0)
274+
return;
275+
276+
// Check that the address is only used inside of the region.
277+
bool addressUsedOnlyInternally = true;
278+
for (auto &addrUse : addr.getUses()) {
279+
if (!op->isAncestor(addrUse.getOwner())) {
280+
addressUsedOnlyInternally = false;
281+
break;
282+
}
283+
}
284+
285+
if (!addressUsedOnlyInternally)
286+
return;
287+
288+
allocasToSink.insert(llvmAlloca);
289+
};
290+
291+
region.walk([&processLoadStore](Operation *op) {
292+
if (auto loadOp = dyn_cast<LLVM::LoadOp>(op))
293+
processLoadStore(loadOp);
294+
else if (auto storeOp = dyn_cast<LLVM::StoreOp>(op))
295+
processLoadStore(storeOp);
296+
});
297+
}
298+
299+
/// Converts the given region that appears within an OpenMP dialect operation to
300+
/// LLVM IR, according to the process described in \c convertOmpOpRegions(), and
301+
/// marks the lifetime of allocas read/written exclusively inside of the region
302+
/// but defined outside of it.
303+
///
304+
/// This information enables later compilation stages to sink these allocations
305+
/// inside of the region, such as when outlining it into a separate function.
306+
static llvm::BasicBlock *convertOmpOpRegionsWithAllocaLifetimes(
307+
Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
308+
LLVM::ModuleTranslation &moduleTranslation, LogicalResult &bodyGenStatus) {
309+
SetVector<llvm::AllocaInst *> allocasToSink;
310+
gatherSinkableAllocas(moduleTranslation, region, allocasToSink);
311+
312+
for (auto *alloca : allocasToSink) {
313+
unsigned size = alloca->getAllocatedType()->getPrimitiveSizeInBits() / 8;
314+
builder.CreateLifetimeStart(alloca, builder.getInt64(size));
315+
}
316+
317+
llvm::BasicBlock *continuationBlock = convertOmpOpRegions(
318+
region, blockName, builder, moduleTranslation, bodyGenStatus);
319+
320+
if (!allocasToSink.empty()) {
321+
llvm::IRBuilderBase::InsertPointGuard guard(builder);
322+
builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
323+
324+
for (auto *alloca : allocasToSink) {
325+
unsigned size = alloca->getAllocatedType()->getPrimitiveSizeInBits() / 8;
326+
builder.CreateLifetimeEnd(alloca, builder.getInt64(size));
327+
}
328+
}
329+
330+
return continuationBlock;
331+
}
332+
247333
/// Convert ProcBindKind from MLIR-generated enum to LLVM enum.
248334
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
249335
switch (kind) {
@@ -910,8 +996,9 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
910996

911997
// Convert the body of the loop.
912998
builder.restoreIP(ip);
913-
convertOmpOpRegions(loop.getRegion(), "omp.wsloop.region", builder,
914-
moduleTranslation, bodyGenStatus);
999+
convertOmpOpRegionsWithAllocaLifetimes(loop.getRegion(),
1000+
"omp.wsloop.region", builder,
1001+
moduleTranslation, bodyGenStatus);
9151002
};
9161003

9171004
// Delegate actual loop construction to the OpenMP IRBuilder.
@@ -1151,8 +1238,9 @@ convertOmpSimdLoop(Operation &opInst, llvm::IRBuilderBase &builder,
11511238

11521239
// Convert the body of the loop.
11531240
builder.restoreIP(ip);
1154-
convertOmpOpRegions(loop.getRegion(), "omp.simdloop.region", builder,
1155-
moduleTranslation, bodyGenStatus);
1241+
convertOmpOpRegionsWithAllocaLifetimes(loop.getRegion(),
1242+
"omp.simdloop.region", builder,
1243+
moduleTranslation, bodyGenStatus);
11561244
};
11571245

11581246
// Delegate actual loop construction to the OpenMP IRBuilder.
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
// This test checks the introduction of lifetime information for allocas defined
2+
// outside of omp.wsloop and omp.simdloop regions but only used inside of them.
3+
4+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
5+
6+
llvm.func @foo(%arg0 : i32) {
7+
llvm.return
8+
}
9+
10+
llvm.func @bar(%arg0 : i64) {
11+
llvm.return
12+
}
13+
14+
// CHECK-LABEL: define void @wsloop_i32
15+
llvm.func @wsloop_i32(%size : i64, %lb : i32, %ub : i32, %step : i32) {
16+
// CHECK-DAG: %[[LASTITER:.*]] = alloca i32
17+
// CHECK-DAG: %[[LB:.*]] = alloca i32
18+
// CHECK-DAG: %[[UB:.*]] = alloca i32
19+
// CHECK-DAG: %[[STRIDE:.*]] = alloca i32
20+
// CHECK-DAG: %[[I:.*]] = alloca i32
21+
%1 = llvm.alloca %size x i32 : (i64) -> !llvm.ptr
22+
23+
// CHECK-NOT: %[[I]]
24+
// CHECK: call void @llvm.lifetime.start.p0(i64 4, ptr %[[I]])
25+
// CHECK-NEXT: br label %[[WSLOOP_BB:.*]]
26+
// CHECK-NOT: %[[I]]
27+
// CHECK: [[WSLOOP_BB]]:
28+
// CHECK-NOT: {{^.*}}:
29+
// CHECK: br label %[[CONT_BB:.*]]
30+
// CHECK: [[CONT_BB]]:
31+
// CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4, ptr %[[I]])
32+
// CHECK-NOT: %[[I]]
33+
omp.wsloop for (%iv) : i32 = (%lb) to (%ub) step (%step) {
34+
llvm.store %iv, %1 : i32, !llvm.ptr
35+
%2 = llvm.load %1 : !llvm.ptr -> i32
36+
llvm.call @foo(%2) : (i32) -> ()
37+
omp.yield
38+
}
39+
40+
// CHECK: ret void
41+
llvm.return
42+
}
43+
44+
// CHECK-LABEL: define void @wsloop_i64
45+
llvm.func @wsloop_i64(%size : i64, %lb : i64, %ub : i64, %step : i64) {
46+
// CHECK-DAG: %[[LASTITER:.*]] = alloca i32
47+
// CHECK-DAG: %[[LB:.*]] = alloca i64
48+
// CHECK-DAG: %[[UB:.*]] = alloca i64
49+
// CHECK-DAG: %[[STRIDE:.*]] = alloca i64
50+
// CHECK-DAG: %[[I:.*]] = alloca i64
51+
%1 = llvm.alloca %size x i64 : (i64) -> !llvm.ptr
52+
53+
// CHECK-NOT: %[[I]]
54+
// CHECK: call void @llvm.lifetime.start.p0(i64 8, ptr %[[I]])
55+
// CHECK-NEXT: br label %[[WSLOOP_BB:.*]]
56+
// CHECK-NOT: %[[I]]
57+
// CHECK: [[WSLOOP_BB]]:
58+
// CHECK-NOT: {{^.*}}:
59+
// CHECK: br label %[[CONT_BB:.*]]
60+
// CHECK: [[CONT_BB]]:
61+
// CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 8, ptr %[[I]])
62+
// CHECK-NOT: %[[I]]
63+
omp.wsloop for (%iv) : i64 = (%lb) to (%ub) step (%step) {
64+
llvm.store %iv, %1 : i64, !llvm.ptr
65+
%2 = llvm.load %1 : !llvm.ptr -> i64
66+
llvm.call @bar(%2) : (i64) -> ()
67+
omp.yield
68+
}
69+
70+
// CHECK: ret void
71+
llvm.return
72+
}
73+
74+
// CHECK-LABEL: define void @simdloop_i32
75+
llvm.func @simdloop_i32(%size : i64, %lb : i32, %ub : i32, %step : i32) {
76+
// CHECK: %[[I:.*]] = alloca i32
77+
%1 = llvm.alloca %size x i32 : (i64) -> !llvm.ptr
78+
79+
// CHECK-NOT: %[[I]]
80+
// CHECK: call void @llvm.lifetime.start.p0(i64 4, ptr %[[I]])
81+
// CHECK-NEXT: br label %[[SIMDLOOP_BB:.*]]
82+
// CHECK-NOT: %[[I]]
83+
// CHECK: [[SIMDLOOP_BB]]:
84+
// CHECK-NOT: {{^.*}}:
85+
// CHECK: br label %[[CONT_BB:.*]]
86+
// CHECK: [[CONT_BB]]:
87+
// CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4, ptr %[[I]])
88+
// CHECK-NOT: %[[I]]
89+
omp.simdloop for (%iv) : i32 = (%lb) to (%ub) step (%step) {
90+
llvm.store %iv, %1 : i32, !llvm.ptr
91+
%2 = llvm.load %1 : !llvm.ptr -> i32
92+
llvm.call @foo(%2) : (i32) -> ()
93+
omp.yield
94+
}
95+
96+
// CHECK: ret void
97+
llvm.return
98+
}
99+
100+
// CHECK-LABEL: define void @simdloop_i64
101+
llvm.func @simdloop_i64(%size : i64, %lb : i64, %ub : i64, %step : i64) {
102+
// CHECK: %[[I:.*]] = alloca i64
103+
%1 = llvm.alloca %size x i64 : (i64) -> !llvm.ptr
104+
105+
// CHECK-NOT: %[[I]]
106+
// CHECK: call void @llvm.lifetime.start.p0(i64 8, ptr %[[I]])
107+
// CHECK-NEXT: br label %[[SIMDLOOP_BB:.*]]
108+
// CHECK-NOT: %[[I]]
109+
// CHECK: [[SIMDLOOP_BB]]:
110+
// CHECK-NOT: {{^.*}}:
111+
// CHECK: br label %[[CONT_BB:.*]]
112+
// CHECK: [[CONT_BB]]:
113+
// CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 8, ptr %[[I]])
114+
// CHECK-NOT: %[[I]]
115+
omp.simdloop for (%iv) : i64 = (%lb) to (%ub) step (%step) {
116+
llvm.store %iv, %1 : i64, !llvm.ptr
117+
%2 = llvm.load %1 : !llvm.ptr -> i64
118+
llvm.call @bar(%2) : (i64) -> ()
119+
omp.yield
120+
}
121+
122+
// CHECK: ret void
123+
llvm.return
124+
}

0 commit comments

Comments
 (0)