Skip to content

Commit f8144e6

Browse files
ftynseLeporacanthicus
authored andcommitted
[mlir] support collapsed loops in OpenMP-to-LLVM translation
Reviewed By: Meinersbur Differential Revision: https://reviews.llvm.org/D105706
1 parent 3a164ae commit f8144e6

File tree

2 files changed

+115
-42
lines changed

2 files changed

+115
-42
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 57 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -213,25 +213,12 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
213213
if (loop.lowerBound().empty())
214214
return failure();
215215

216-
if (loop.getNumLoops() != 1)
217-
return opInst.emitOpError("collapsed loops not yet supported");
218-
219216
// Static is the default.
220217
omp::ClauseScheduleKind schedule = omp::ClauseScheduleKind::Static;
221218
if (loop.schedule_val().hasValue())
222219
schedule =
223220
*omp::symbolizeClauseScheduleKind(loop.schedule_val().getValue());
224221

225-
// Find the loop configuration.
226-
llvm::Value *lowerBound = moduleTranslation.lookupValue(loop.lowerBound()[0]);
227-
llvm::Value *upperBound = moduleTranslation.lookupValue(loop.upperBound()[0]);
228-
llvm::Value *step = moduleTranslation.lookupValue(loop.step()[0]);
229-
llvm::Type *ivType = step->getType();
230-
llvm::Value *chunk =
231-
loop.schedule_chunk_var()
232-
? moduleTranslation.lookupValue(loop.schedule_chunk_var())
233-
: llvm::ConstantInt::get(ivType, 1);
234-
235222
// Set up the source location value for OpenMP runtime.
236223
llvm::DISubprogram *subprogram =
237224
builder.GetInsertBlock()->getParent()->getSubprogram();
@@ -240,22 +227,29 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
240227
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder.saveIP(),
241228
llvm::DebugLoc(diLoc));
242229

243-
// Generator of the canonical loop body. Produces an SESE region of basic
244-
// blocks.
230+
// Generator of the canonical loop body.
245231
// TODO: support error propagation in OpenMPIRBuilder and use it instead of
246232
// relying on captured variables.
233+
SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
234+
SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
247235
LogicalResult bodyGenStatus = success();
248236
auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
249-
llvm::IRBuilder<>::InsertPointGuard guard(builder);
250-
251237
// Make sure further conversions know about the induction variable.
252-
moduleTranslation.mapValue(loop.getRegion().front().getArgument(0), iv);
238+
moduleTranslation.mapValue(
239+
loop.getRegion().front().getArgument(loopInfos.size()), iv);
240+
241+
// Capture the body insertion point for use in nested loops. BodyIP of the
242+
// CanonicalLoopInfo always points to the beginning of the entry block of
243+
// the body.
244+
bodyInsertPoints.push_back(ip);
245+
246+
if (loopInfos.size() != loop.getNumLoops() - 1)
247+
return;
253248

249+
// Convert the body of the loop.
254250
llvm::BasicBlock *entryBlock = ip.getBlock();
255251
llvm::BasicBlock *exitBlock =
256252
entryBlock->splitBasicBlock(ip.getPoint(), "omp.wsloop.exit");
257-
258-
// Convert the body of the loop.
259253
convertOmpOpRegions(loop.region(), "omp.wsloop.region", *entryBlock,
260254
*exitBlock, builder, moduleTranslation, bodyGenStatus);
261255
};
@@ -264,17 +258,46 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
264258
// TODO: this currently assumes WsLoop is semantically similar to SCF loop,
265259
// i.e. it has a positive step, uses signed integer semantics. Reconsider
266260
// this code when WsLoop clearly supports more cases.
261+
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
262+
for (unsigned i = 0, e = loop.getNumLoops(); i < e; ++i) {
263+
llvm::Value *lowerBound =
264+
moduleTranslation.lookupValue(loop.lowerBound()[i]);
265+
llvm::Value *upperBound =
266+
moduleTranslation.lookupValue(loop.upperBound()[i]);
267+
llvm::Value *step = moduleTranslation.lookupValue(loop.step()[i]);
268+
269+
// Make sure loop trip count are emitted in the preheader of the outermost
270+
// loop at the latest so that they are all available for the new collapsed
271+
// loop will be created below.
272+
llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
273+
llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
274+
if (i != 0) {
275+
loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
276+
llvm::DebugLoc(diLoc));
277+
computeIP = loopInfos.front()->getPreheaderIP();
278+
}
279+
loopInfos.push_back(ompBuilder->createCanonicalLoop(
280+
loc, bodyGen, lowerBound, upperBound, step,
281+
/*IsSigned=*/true, loop.inclusive(), computeIP));
282+
283+
if (failed(bodyGenStatus))
284+
return failure();
285+
}
286+
287+
// Collapse loops. Store the insertion point because LoopInfos may get
288+
// invalidated.
289+
llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
267290
llvm::CanonicalLoopInfo *loopInfo =
268-
moduleTranslation.getOpenMPBuilder()->createCanonicalLoop(
269-
ompLoc, bodyGen, lowerBound, upperBound, step, /*IsSigned=*/true,
270-
/*InclusiveStop=*/loop.inclusive());
271-
if (failed(bodyGenStatus))
272-
return failure();
291+
ompBuilder->collapseLoops(diLoc, loopInfos, {});
273292

293+
// Find the loop configuration.
294+
llvm::Type *ivType = loopInfo->getIndVar()->getType();
295+
llvm::Value *chunk =
296+
loop.schedule_chunk_var()
297+
? moduleTranslation.lookupValue(loop.schedule_chunk_var())
298+
: llvm::ConstantInt::get(ivType, 1);
274299
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
275300
findAllocaInsertPoint(builder, moduleTranslation);
276-
llvm::OpenMPIRBuilder::InsertPointTy afterIP;
277-
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
278301

279302
bool isSimd = false;
280303
if (auto simd = loop.simd_modifier()) {
@@ -283,9 +306,8 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
283306
}
284307

285308
if (schedule == omp::ClauseScheduleKind::Static) {
286-
loopInfo = ompBuilder->createStaticWorkshareLoop(ompLoc, loopInfo, allocaIP,
287-
!loop.nowait(), chunk);
288-
afterIP = loopInfo->getAfterIP();
309+
ompBuilder->createStaticWorkshareLoop(ompLoc, loopInfo, allocaIP,
310+
!loop.nowait(), chunk);
289311
} else {
290312
llvm::omp::OMPScheduleType schedType;
291313
switch (schedule) {
@@ -328,11 +350,14 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
328350
break;
329351
}
330352
}
331-
afterIP = ompBuilder->createDynamicWorkshareLoop(
353+
ompBuilder->createDynamicWorkshareLoop(
332354
ompLoc, loopInfo, allocaIP, schedType, !loop.nowait(), chunk);
333355
}
334356

335-
// Continue building IR after the loop.
357+
// Continue building IR after the loop. Note that the LoopInfo returned by
358+
// `collapseLoops` points inside the outermost loop and is intended for
359+
// potential further loop transformations. Use the insertion point stored
360+
// before collapsing loops instead.
336361
builder.restoreIP(afterIP);
337362
return success();
338363
}

mlir/test/Target/LLVMIR/openmp-llvm.mlir

Lines changed: 58 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -479,14 +479,62 @@ llvm.func @test_omp_wsloop_dynamic_nonmonotonic(%lb : i64, %ub : i64, %step : i6
479479
llvm.return
480480
}
481481

482-
llvm.func @test_omp_wsloop_dynamic_monotonic(%lb : i64, %ub : i64, %step : i64) -> () {
483-
omp.wsloop (%iv) : i64 = (%lb) to (%ub) step (%step) schedule(dynamic, monotonic) {
484-
// CHECK: call void @__kmpc_dispatch_init_8u(%struct.ident_t* @{{.*}}, i32 %{{.*}}, i32 536870947
485-
// CHECK: %[[continue:.*]] = call i32 @__kmpc_dispatch_next_8u
486-
// CHECK: %[[cond:.*]] = icmp ne i32 %[[continue]], 0
487-
// CHECK br i1 %[[cond]], label %omp_loop.header{{.*}}, label %omp_loop.exit{{.*}}
488-
llvm.call @body(%iv) : (i64) -> ()
489-
omp.yield
490-
}
491-
llvm.return
482+
// -----
483+
484+
// Check that the loop bounds are emitted in the correct location in case of
485+
// collapse. This only checks the overall shape of the IR, detailed checking
486+
// is done by the OpenMPIRBuilder.
487+
488+
// CHECK-LABEL: @collapse_wsloop
489+
// CHECK: i32* noalias %[[TIDADDR:[0-9A-Za-z.]*]]
490+
// CHECK: load i32, i32* %[[TIDADDR]]
491+
// CHECK: store
492+
// CHECK: load
493+
// CHECK: %[[LB0:.*]] = load i32
494+
// CHECK: %[[UB0:.*]] = load i32
495+
// CHECK: %[[STEP0:.*]] = load i32
496+
// CHECK: %[[LB1:.*]] = load i32
497+
// CHECK: %[[UB1:.*]] = load i32
498+
// CHECK: %[[STEP1:.*]] = load i32
499+
// CHECK: %[[LB2:.*]] = load i32
500+
// CHECK: %[[UB2:.*]] = load i32
501+
// CHECK: %[[STEP2:.*]] = load i32
502+
llvm.func @collapse_wsloop(
503+
%0: i32, %1: i32, %2: i32,
504+
%3: i32, %4: i32, %5: i32,
505+
%6: i32, %7: i32, %8: i32,
506+
%20: !llvm.ptr<i32>) {
507+
omp.parallel {
508+
// CHECK: icmp slt i32 %[[LB0]], 0
509+
// CHECK-COUNT-4: select
510+
// CHECK: %[[TRIPCOUNT0:.*]] = select
511+
// CHECK: br label %[[PREHEADER:.*]]
512+
//
513+
// CHECK: [[PREHEADER]]:
514+
// CHECK: icmp slt i32 %[[LB1]], 0
515+
// CHECK-COUNT-4: select
516+
// CHECK: %[[TRIPCOUNT1:.*]] = select
517+
// CHECK: icmp slt i32 %[[LB2]], 0
518+
// CHECK-COUNT-4: select
519+
// CHECK: %[[TRIPCOUNT2:.*]] = select
520+
// CHECK: %[[PROD:.*]] = mul nuw i32 %[[TRIPCOUNT0]], %[[TRIPCOUNT1]]
521+
// CHECK: %[[TOTAL:.*]] = mul nuw i32 %[[PROD]], %[[TRIPCOUNT2]]
522+
// CHECK: br label %[[COLLAPSED_PREHEADER:.*]]
523+
//
524+
// CHECK: [[COLLAPSED_PREHEADER]]:
525+
// CHECK: store i32 0, i32*
526+
// CHECK: %[[TOTAL_SUB_1:.*]] = sub i32 %[[TOTAL]], 1
527+
// CHECK: store i32 %[[TOTAL_SUB_1]], i32*
528+
// CHECK: call void @__kmpc_for_static_init_4u
529+
omp.wsloop (%arg0, %arg1, %arg2) : i32 = (%0, %1, %2) to (%3, %4, %5) step (%6, %7, %8) collapse(3) {
530+
%31 = llvm.load %20 : !llvm.ptr<i32>
531+
%32 = llvm.add %31, %arg0 : i32
532+
%33 = llvm.add %32, %arg1 : i32
533+
%34 = llvm.add %33, %arg2 : i32
534+
llvm.store %34, %20 : !llvm.ptr<i32>
535+
omp.yield
536+
}
537+
omp.terminator
538+
}
539+
llvm.return
492540
}

0 commit comments

Comments
 (0)