@@ -82,10 +82,7 @@ class OpenMPLoopInfoStackFrame
8282 : public LLVM::ModuleTranslation::StackFrameBase<OpenMPLoopInfoStackFrame> {
8383public:
8484 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID (OpenMPLoopInfoStackFrame)
85-
86- explicit OpenMPLoopInfoStackFrame (llvm::CanonicalLoopInfo *loopInfo)
87- : loopInfo(loopInfo) {}
88- llvm::CanonicalLoopInfo *loopInfo;
85+ llvm::CanonicalLoopInfo *loopInfo = nullptr ;
8986};
9087
9188// / Custom error class to signal translation errors that don't need reporting,
@@ -348,13 +345,13 @@ static LogicalResult handleError(llvm::Expected<T> &result, Operation &op) {
348345// / normal operations in the builder.
349346static llvm::OpenMPIRBuilder::InsertPointTy
350347findAllocaInsertPoint (llvm::IRBuilderBase &builder,
351- const LLVM::ModuleTranslation &moduleTranslation) {
348+ LLVM::ModuleTranslation &moduleTranslation) {
352349 // If there is an alloca insertion point on stack, i.e. we are in a nested
353350 // operation and a specific point was provided by some surrounding operation,
354351 // use it.
355352 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
356353 WalkResult walkResult = moduleTranslation.stackWalk <OpenMPAllocaStackFrame>(
357- [&](const OpenMPAllocaStackFrame &frame) {
354+ [&](OpenMPAllocaStackFrame &frame) {
358355 allocaInsertPoint = frame.allocaInsertPoint ;
359356 return WalkResult::interrupt ();
360357 });
@@ -386,13 +383,13 @@ findAllocaInsertPoint(llvm::IRBuilderBase &builder,
386383}
387384
388385// / Find the loop information structure for the loop nest being translated. It
389- // / will not return a value unless called from the translation function for
386+ // / will return a `null` value unless called from the translation function for
390387// / a loop wrapper operation after successfully translating its body.
391- static std::optional< llvm::CanonicalLoopInfo *>
388+ static llvm::CanonicalLoopInfo *
392389findCurrentLoopInfo (LLVM::ModuleTranslation &moduleTranslation) {
393- std::optional< llvm::CanonicalLoopInfo *> loopInfo;
390+ llvm::CanonicalLoopInfo *loopInfo = nullptr ;
394391 moduleTranslation.stackWalk <OpenMPLoopInfoStackFrame>(
395- [&](const OpenMPLoopInfoStackFrame &frame) {
392+ [&](OpenMPLoopInfoStackFrame &frame) {
396393 loopInfo = frame.loopInfo ;
397394 return WalkResult::interrupt ();
398395 });
@@ -1987,7 +1984,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
19871984 return failure ();
19881985
19891986 builder.SetInsertPoint (*regionBlock, (*regionBlock)->begin ());
1990- llvm::CanonicalLoopInfo *loopInfo = * findCurrentLoopInfo (moduleTranslation);
1987+ llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo (moduleTranslation);
19911988
19921989 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
19931990 ompBuilder->applyWorkshareLoop (
@@ -2270,16 +2267,16 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
22702267 llvm::Value *alignment = nullptr ;
22712268 llvm::Value *llvmVal = moduleTranslation.lookupValue (operands[i]);
22722269 llvm::Type *ty = llvmVal->getType ();
2273- if ( auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignmentValues)[i])) {
2274- alignment = builder. getInt64 (intAttr. getInt () );
2275- assert (ty-> isPointerTy () && " Invalid type for aligned variable " );
2276- assert (alignment && " Invalid alignment value " );
2277- auto curInsert = builder. saveIP ( );
2278- builder.SetInsertPoint (sourceBlock );
2279- llvmVal = builder.CreateLoad (ty, llvmVal );
2280- builder.restoreIP (curInsert );
2281- alignedVars[llvmVal] = alignment ;
2282- }
2270+
2271+ auto intAttr = cast<IntegerAttr>((*alignmentValues)[i] );
2272+ alignment = builder. getInt64 (intAttr. getInt () );
2273+ assert (ty-> isPointerTy () && " Invalid type for aligned variable " );
2274+ assert (alignment && " Invalid alignment value " );
2275+ auto curInsert = builder.saveIP ( );
2276+ builder.SetInsertPoint (sourceBlock );
2277+ llvmVal = builder.CreateLoad (ty, llvmVal );
2278+ builder. restoreIP (curInsert) ;
2279+ alignedVars[llvmVal] = alignment;
22832280 }
22842281
22852282 llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions (
@@ -2289,7 +2286,7 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
22892286 return failure ();
22902287
22912288 builder.SetInsertPoint (*regionBlock, (*regionBlock)->begin ());
2292- llvm::CanonicalLoopInfo *loopInfo = * findCurrentLoopInfo (moduleTranslation);
2289+ llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo (moduleTranslation);
22932290 ompBuilder->applySimd (loopInfo, alignedVars,
22942291 simdOp.getIfExpr ()
22952292 ? moduleTranslation.lookupValue (simdOp.getIfExpr ())
@@ -2377,11 +2374,13 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
23772374 llvm::OpenMPIRBuilder::InsertPointTy afterIP =
23782375 loopInfos.front ()->getAfterIP ();
23792376
2380- // Add a stack frame holding information about the resulting loop after
2381- // applying transformations, to be further transformed by parent loop
2382- // wrappers.
2383- moduleTranslation.stackPush <OpenMPLoopInfoStackFrame>(
2384- ompBuilder->collapseLoops (ompLoc.DL , loopInfos, {}));
2377+ // Update the stack frame created for this loop to point to the resulting loop
2378+ // after applying transformations.
2379+ moduleTranslation.stackWalk <OpenMPLoopInfoStackFrame>(
2380+ [&](OpenMPLoopInfoStackFrame &frame) {
2381+ frame.loopInfo = ompBuilder->collapseLoops (ompLoc.DL , loopInfos, {});
2382+ return WalkResult::interrupt ();
2383+ });
23852384
23862385 // Continue building IR after the loop. Note that the LoopInfo returned by
23872386 // `collapseLoops` points inside the outermost loop and is intended for
@@ -4576,6 +4575,19 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
45764575 LLVM::ModuleTranslation &moduleTranslation) {
45774576 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
45784577
4578+ // For each loop, introduce one stack frame to hold loop information. Ensure
4579+ // this is only done for the outermost loop wrapper to prevent introducing
4580+ // multiple stack frames for a single loop. Initially set to null, the loop
4581+ // information structure is initialized during translation of the nested
4582+ // omp.loop_nest operation, making it available to translation of all loop
4583+ // wrappers after their body has been successfully translated.
4584+ bool isOutermostLoopWrapper =
4585+ isa_and_present<omp::LoopWrapperInterface>(op) &&
4586+ !dyn_cast_if_present<omp::LoopWrapperInterface>(op->getParentOp ());
4587+
4588+ if (isOutermostLoopWrapper)
4589+ moduleTranslation.stackPush <OpenMPLoopInfoStackFrame>();
4590+
45794591 auto result =
45804592 llvm::TypeSwitch<Operation *, LogicalResult>(op)
45814593 .Case ([&](omp::BarrierOp op) -> LogicalResult {
@@ -4700,19 +4712,7 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
47004712 << " not yet implemented: " << inst->getName ();
47014713 });
47024714
4703- // When translating an omp.loop_nest, one stack frame was pushed to hold that
4704- // loop's information. The code below ensures that this stack frame is removed
4705- // when encountering the outermost loop wrapper associated to that loop. This
4706- // approach allows all loop wrappers have access to that loop's information
4707- // (to e.g. apply transformations to it) after their associated omp.loop_nest
4708- // operation has been translated.
4709- bool isOutermostLoopWrapper =
4710- isa_and_present<omp::LoopWrapperInterface>(op) &&
4711- !dyn_cast_if_present<omp::LoopWrapperInterface>(op->getParentOp ());
4712-
4713- // We need to check that a loop info is present as well, in case translation
4714- // of the loop failed before it was created.
4715- if (isOutermostLoopWrapper && findCurrentLoopInfo (moduleTranslation))
4715+ if (isOutermostLoopWrapper)
47164716 moduleTranslation.stackPop ();
47174717
47184718 return result;
0 commit comments