@@ -213,25 +213,12 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
213
213
if (loop.lowerBound ().empty ())
214
214
return failure ();
215
215
216
- if (loop.getNumLoops () != 1 )
217
- return opInst.emitOpError (" collapsed loops not yet supported" );
218
-
219
216
// Static is the default.
220
217
omp::ClauseScheduleKind schedule = omp::ClauseScheduleKind::Static;
221
218
if (loop.schedule_val ().hasValue ())
222
219
schedule =
223
220
*omp::symbolizeClauseScheduleKind (loop.schedule_val ().getValue ());
224
221
225
- // Find the loop configuration.
226
- llvm::Value *lowerBound = moduleTranslation.lookupValue (loop.lowerBound ()[0 ]);
227
- llvm::Value *upperBound = moduleTranslation.lookupValue (loop.upperBound ()[0 ]);
228
- llvm::Value *step = moduleTranslation.lookupValue (loop.step ()[0 ]);
229
- llvm::Type *ivType = step->getType ();
230
- llvm::Value *chunk =
231
- loop.schedule_chunk_var ()
232
- ? moduleTranslation.lookupValue (loop.schedule_chunk_var ())
233
- : llvm::ConstantInt::get (ivType, 1 );
234
-
235
222
// Set up the source location value for OpenMP runtime.
236
223
llvm::DISubprogram *subprogram =
237
224
builder.GetInsertBlock ()->getParent ()->getSubprogram ();
@@ -240,22 +227,29 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
240
227
llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder.saveIP (),
241
228
llvm::DebugLoc (diLoc));
242
229
243
- // Generator of the canonical loop body. Produces an SESE region of basic
244
- // blocks.
230
+ // Generator of the canonical loop body.
245
231
// TODO: support error propagation in OpenMPIRBuilder and use it instead of
246
232
// relying on captured variables.
233
+ SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
234
+ SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
247
235
LogicalResult bodyGenStatus = success ();
248
236
auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
249
- llvm::IRBuilder<>::InsertPointGuard guard (builder);
250
-
251
237
// Make sure further conversions know about the induction variable.
252
- moduleTranslation.mapValue (loop.getRegion ().front ().getArgument (0 ), iv);
238
+ moduleTranslation.mapValue (
239
+ loop.getRegion ().front ().getArgument (loopInfos.size ()), iv);
240
+
241
+ // Capture the body insertion point for use in nested loops. BodyIP of the
242
+ // CanonicalLoopInfo always points to the beginning of the entry block of
243
+ // the body.
244
+ bodyInsertPoints.push_back (ip);
245
+
246
+ if (loopInfos.size () != loop.getNumLoops () - 1 )
247
+ return ;
253
248
249
+ // Convert the body of the loop.
254
250
llvm::BasicBlock *entryBlock = ip.getBlock ();
255
251
llvm::BasicBlock *exitBlock =
256
252
entryBlock->splitBasicBlock (ip.getPoint (), " omp.wsloop.exit" );
257
-
258
- // Convert the body of the loop.
259
253
convertOmpOpRegions (loop.region (), " omp.wsloop.region" , *entryBlock,
260
254
*exitBlock, builder, moduleTranslation, bodyGenStatus);
261
255
};
@@ -264,17 +258,46 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
264
258
// TODO: this currently assumes WsLoop is semantically similar to SCF loop,
265
259
// i.e. it has a positive step, uses signed integer semantics. Reconsider
266
260
// this code when WsLoop clearly supports more cases.
261
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
262
+ for (unsigned i = 0 , e = loop.getNumLoops (); i < e; ++i) {
263
+ llvm::Value *lowerBound =
264
+ moduleTranslation.lookupValue (loop.lowerBound ()[i]);
265
+ llvm::Value *upperBound =
266
+ moduleTranslation.lookupValue (loop.upperBound ()[i]);
267
+ llvm::Value *step = moduleTranslation.lookupValue (loop.step ()[i]);
268
+
269
+ // Make sure loop trip counts are emitted in the preheader of the outermost
270
+ // loop at the latest so that they are all available for the new collapsed
271
+ // loop that will be created below.
272
+ llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
273
+ llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP ;
274
+ if (i != 0 ) {
275
+ loc = llvm::OpenMPIRBuilder::LocationDescription (bodyInsertPoints.back (),
276
+ llvm::DebugLoc (diLoc));
277
+ computeIP = loopInfos.front ()->getPreheaderIP ();
278
+ }
279
+ loopInfos.push_back (ompBuilder->createCanonicalLoop (
280
+ loc, bodyGen, lowerBound, upperBound, step,
281
+ /* IsSigned=*/ true , loop.inclusive (), computeIP));
282
+
283
+ if (failed (bodyGenStatus))
284
+ return failure ();
285
+ }
286
+
287
+ // Collapse loops. Store the insertion point because LoopInfos may get
288
+ // invalidated.
289
+ llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front ()->getAfterIP ();
267
290
llvm::CanonicalLoopInfo *loopInfo =
268
- moduleTranslation.getOpenMPBuilder ()->createCanonicalLoop (
269
- ompLoc, bodyGen, lowerBound, upperBound, step, /* IsSigned=*/ true ,
270
- /* InclusiveStop=*/ loop.inclusive ());
271
- if (failed (bodyGenStatus))
272
- return failure ();
291
+ ompBuilder->collapseLoops (diLoc, loopInfos, {});
273
292
293
+ // Find the loop configuration.
294
+ llvm::Type *ivType = loopInfo->getIndVar ()->getType ();
295
+ llvm::Value *chunk =
296
+ loop.schedule_chunk_var ()
297
+ ? moduleTranslation.lookupValue (loop.schedule_chunk_var ())
298
+ : llvm::ConstantInt::get (ivType, 1 );
274
299
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
275
300
findAllocaInsertPoint (builder, moduleTranslation);
276
- llvm::OpenMPIRBuilder::InsertPointTy afterIP;
277
- llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
278
301
279
302
bool isSimd = false ;
280
303
if (auto simd = loop.simd_modifier ()) {
@@ -283,9 +306,8 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
283
306
}
284
307
285
308
if (schedule == omp::ClauseScheduleKind::Static) {
286
- loopInfo = ompBuilder->createStaticWorkshareLoop (ompLoc, loopInfo, allocaIP,
287
- !loop.nowait (), chunk);
288
- afterIP = loopInfo->getAfterIP ();
309
+ ompBuilder->createStaticWorkshareLoop (ompLoc, loopInfo, allocaIP,
310
+ !loop.nowait (), chunk);
289
311
} else {
290
312
llvm::omp::OMPScheduleType schedType;
291
313
switch (schedule) {
@@ -328,11 +350,14 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
328
350
break ;
329
351
}
330
352
}
331
- afterIP = ompBuilder->createDynamicWorkshareLoop (
353
+ ompBuilder->createDynamicWorkshareLoop (
332
354
ompLoc, loopInfo, allocaIP, schedType, !loop.nowait (), chunk);
333
355
}
334
356
335
- // Continue building IR after the loop.
357
+ // Continue building IR after the loop. Note that the LoopInfo returned by
358
+ // `collapseLoops` points inside the outermost loop and is intended for
359
+ // potential further loop transformations. Use the insertion point stored
360
+ // before collapsing loops instead.
336
361
builder.restoreIP (afterIP);
337
362
return success ();
338
363
}
0 commit comments