Skip to content

Conversation

@anchuraj
Copy link
Contributor

@anchuraj anchuraj commented Oct 30, 2025

Scan reductions are supported in OpenMP with the the help of scan directive. Reduction clause of the for workshare loop/simd directive takes an inscan modifier if scan reduction is specified. With an inscan modifier, the body of the directive should specify a scan directive. This PR implements the lowering logic for scan reductions in workshare loops of OpenMP. OpenMPIRBuilder support can be found in #136035. Support for nested loops/ exclusive clause is not done in this PR
Nested Loops, Linear clause, Collapse and Tiling are not enabled now for loops with scan reduction.
Following is a working example:

program inclusive_scan
 implicit none
 integer, parameter :: n = 100
 integer a(n), b(n)
 integer x, k, y, z

 ! initialization
 x = 0
 do k = 1, n
  a(k) = k
 end do

 ! a(k) is included in the computation of producing results in b(k)
 !$omp parallel do reduction(inscan, +: x)
 do k = 1, n
   x = x + a(k)
   !$omp scan inclusive(x)
   b(k) = x
 end do

 print *,'x =', x
end program

@anchuraj anchuraj changed the title [MLIR][LLVMIR] Adding scan lowering to llvm on the mlir side [MLIR][LLVMIR] Adding scan lowering to llvm from mlie Oct 30, 2025
@anchuraj anchuraj force-pushed the scan-llvm-interface branch from cb38aae to ad88725 Compare October 31, 2025 18:41
@anchuraj anchuraj requested review from kparzysz, skatrak and tblah and removed request for skatrak October 31, 2025 18:43
@anchuraj anchuraj marked this pull request as ready for review October 31, 2025 18:43
@llvmbot llvmbot added mlir:llvm mlir flang Flang issues not falling into any other category mlir:openmp flang:fir-hlfir flang:openmp labels Oct 31, 2025
@llvmbot
Copy link
Member

llvmbot commented Oct 31, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-flang-openmp

Author: Anchu Rajendran S (anchuraj)

Changes

Scan reductions are supported in OpenMP with the the help of scan directive. Reduction clause of the for workshare loop/simd directive takes an inscan modifier if scan reduction is specified. With an inscan modifier, the body of the directive should specify a scan directive. This PR implements the lowering logic for scan reductions in workshare loops of OpenMP. OpenMPIRBuilder support can be found in #136035. Support for nested loops/ exclusive clause is not done in this PR


Patch is 37.21 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/165788.diff

4 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+34-5)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+303-77)
  • (added) mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir (+130)
  • (modified) mlir/test/Target/LLVMIR/openmp-todo.mlir (+36-2)
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index f86ee01355104..5d82466889b1e 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2326,12 +2326,41 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
 
 static mlir::omp::ScanOp
 genScanOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
-          semantics::SemanticsContext &semaCtx, mlir::Location loc,
-          const ConstructQueue &queue, ConstructQueue::const_iterator item) {
+          semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+          mlir::Location loc, const ConstructQueue &queue,
+          ConstructQueue::const_iterator item) {
   mlir::omp::ScanOperands clauseOps;
   genScanClauses(converter, semaCtx, item->clauses, loc, clauseOps);
-  return mlir::omp::ScanOp::create(converter.getFirOpBuilder(),
-                                   converter.getCurrentLocation(), clauseOps);
+  mlir::omp::ScanOp scanOp = mlir::omp::ScanOp::create(
+      converter.getFirOpBuilder(), converter.getCurrentLocation(), clauseOps);
+  // If there are nested loops all indices should be loaded after
+  // the scan construct as otherwise, it would result in using the index
+  // variable across scan directive.
+  // (`Intra-iteration dependences from a statement in the structured
+  // block sequence that precede a scan directive to a statement in the
+  // structured block sequence that follows a scan directive must not exist,
+  // except for dependences for the list items specified in an inclusive or
+  // exclusive clause.`).
+  // TODO: If there are nested loops, it is not handled.
+  mlir::omp::LoopNestOp loopNestOp =
+      scanOp->getParentOfType<mlir::omp::LoopNestOp>();
+  assert(loopNestOp.getNumLoops() == 1 &&
+         "Scan directive inside nested do loops is not handled yet.");
+  mlir::Region &region = loopNestOp->getRegion(0);
+  mlir::Value indexVal = fir::getBase(region.getArgument(0));
+  lower::pft::Evaluation *doConstructEval = eval.parentConstruct;
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  lower::pft::Evaluation *doLoop = &doConstructEval->getFirstNestedEvaluation();
+  auto *doStmt = doLoop->getIf<parser::NonLabelDoStmt>();
+  assert(doStmt && "Expected do loop to be in the nested evaluation");
+  const auto &loopControl =
+      std::get<std::optional<parser::LoopControl>>(doStmt->t);
+  const parser::LoopControl::Bounds *bounds =
+      std::get_if<parser::LoopControl::Bounds>(&loopControl->u);
+  mlir::Operation *storeOp =
+      setLoopVar(converter, loc, indexVal, bounds->name.thing.symbol);
+  firOpBuilder.setInsertionPointAfter(storeOp);
+  return scanOp;
 }
 
 static mlir::omp::SectionsOp
@@ -3416,7 +3445,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
                                   loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_scan:
-    newOp = genScanOp(converter, symTable, semaCtx, loc, queue, item);
+    newOp = genScanOp(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_section:
     llvm_unreachable("genOMPDispatch: OMPD_section");
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 1e2099d6cc1b2..9db269c4f8756 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -37,6 +37,7 @@
 #include "llvm/TargetParser/Triple.h"
 #include "llvm/Transforms/Utils/ModuleUtils.h"
 
+#include <cassert>
 #include <cstdint>
 #include <iterator>
 #include <numeric>
@@ -77,6 +78,22 @@ class OpenMPAllocaStackFrame
   llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
 };
 
+/// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
+/// insertion points for allocas of parent of the current parallel region. The
+/// insertion point is used to allocate variables to be share by the threads
+/// executing the parallel region. Lowering of scan reduction requires declaring
+/// shared pointers to the temporary buffer to perform scan reduction.
+class OpenMPParallelAllocaStackFrame
+    : public StateStackFrameBase<OpenMPParallelAllocaStackFrame> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPParallelAllocaStackFrame)
+
+  explicit OpenMPParallelAllocaStackFrame(
+      llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
+      : allocaInsertPoint(allocaIP) {}
+  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
+};
+
 /// Stack frame to hold a \see llvm::CanonicalLoopInfo representing the
 /// collapsed canonical loop information corresponding to an \c omp.loop_nest
 /// operation.
@@ -84,7 +101,13 @@ class OpenMPLoopInfoStackFrame
     : public StateStackFrameBase<OpenMPLoopInfoStackFrame> {
 public:
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
-  llvm::CanonicalLoopInfo *loopInfo = nullptr;
+  /// For constructs like scan, one LoopInfo frame can contain multiple
+  /// Canonical Loops as a single openmpLoopNestOp will be split into input
+  /// loop and scan loop.
+  SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
+  llvm::ScanInfo *scanInfo;
+  llvm::DenseMap<llvm::Value *, llvm::Type *> *reductionVarToType =
+      new llvm::DenseMap<llvm::Value *, llvm::Type *>();
 };
 
 /// Custom error class to signal translation errors that don't need reporting,
@@ -323,6 +346,10 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     if (op.getDistScheduleChunkSize())
       result = todo("dist_schedule with chunk_size");
   };
+  auto checkExclusive = [&todo](auto op, LogicalResult &result) {
+    if (!op.getExclusiveVars().empty())
+      result = todo("exclusive");
+  };
   auto checkHint = [](auto op, LogicalResult &) {
     if (op.getHint())
       op.emitWarning("hint clause discarded");
@@ -371,9 +398,14 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       if (!op.getReductionVars().empty() || op.getReductionByref() ||
           op.getReductionSyms())
         result = todo("reduction");
-    if (op.getReductionMod() &&
-        op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
-      result = todo("reduction with modifier");
+    if (op.getReductionMod()) {
+      if (isa<omp::WsloopOp>(op)) {
+        if (op.getReductionMod().value() == omp::ReductionModifier::task)
+          result = todo("reduction with task modifier");
+      } else {
+        result = todo("reduction with modifier");
+      }
+    }
   };
   auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
     if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
@@ -397,6 +429,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
         checkOrder(op, result);
       })
       .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
+      .Case([&](omp::ScanOp op) { checkExclusive(op, result); })
       .Case([&](omp::SectionsOp op) {
         checkAllocate(op, result);
         checkPrivate(op, result);
@@ -531,15 +564,59 @@ findAllocaInsertPoint(llvm::IRBuilderBase &builder,
 /// Find the loop information structure for the loop nest being translated. It
 /// will return a `null` value unless called from the translation function for
 /// a loop wrapper operation after successfully translating its body.
-static llvm::CanonicalLoopInfo *
-findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation) {
-  llvm::CanonicalLoopInfo *loopInfo = nullptr;
+static SmallVector<llvm::CanonicalLoopInfo *>
+findCurrentLoopInfos(LLVM::ModuleTranslation &moduleTranslation) {
+  SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
+  moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
+      [&](OpenMPLoopInfoStackFrame &frame) {
+        loopInfos = frame.loopInfos;
+        return WalkResult::interrupt();
+      });
+  return loopInfos;
+}
+
+// LoopFrame stores the scaninfo which is used for scan reduction.
+// Upon encountering an `inscan` reduction modifier, `scanInfoInitialize`
+// initializes the ScanInfo and is used when scan directive is encountered
+// in the body of the loop nest.
+static llvm::ScanInfo *
+findScanInfo(LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::ScanInfo *scanInfo;
+  moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
+      [&](OpenMPLoopInfoStackFrame &frame) {
+        scanInfo = frame.scanInfo;
+        return WalkResult::interrupt();
+      });
+  return scanInfo;
+}
+
+// The types of reduction vars are used for lowering scan directive which
+// appears in the body of the loop. The types are stored in loop frame when
+// reduction clause is encountered and is used when scan directive is
+// encountered.
+static llvm::DenseMap<llvm::Value *, llvm::Type *> *
+findReductionVarTypes(LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::DenseMap<llvm::Value *, llvm::Type *> *reductionVarToType = nullptr;
   moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
       [&](OpenMPLoopInfoStackFrame &frame) {
-        loopInfo = frame.loopInfo;
+        reductionVarToType = frame.reductionVarToType;
         return WalkResult::interrupt();
       });
-  return loopInfo;
+  return reductionVarToType;
+}
+
+// Scan reduction requires a shared buffer to be allocated to perform reduction.
+// ParallelAllocaStackFrame holds the allocaIP where shared allocation can be
+// done.
+static llvm::OpenMPIRBuilder::InsertPointTy
+findParallelAllocaIP(LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::OpenMPIRBuilder::InsertPointTy parallelAllocaIP;
+  moduleTranslation.stackWalk<OpenMPParallelAllocaStackFrame>(
+      [&](OpenMPParallelAllocaStackFrame &frame) {
+        parallelAllocaIP = frame.allocaInsertPoint;
+        return WalkResult::interrupt();
+      });
+  return parallelAllocaIP;
 }
 
 /// Converts the given region that appears within an OpenMP dialect operation to
@@ -1254,11 +1331,17 @@ initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
   for (auto [data, addr] : deferredStores)
     builder.CreateStore(data, addr);
 
+  llvm::DenseMap<llvm::Value *, llvm::Type *> *reductionVarToType =
+      findReductionVarTypes(moduleTranslation);
   // Before the loop, store the initial values of reductions into reduction
   // variables. Although this could be done after allocas, we don't want to mess
   // up with the alloca insertion point.
   for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
     SmallVector<llvm::Value *, 1> phis;
+    llvm::Type *reductionType =
+        moduleTranslation.convertType(reductionDecls[i].getType());
+    if (reductionVarToType != nullptr)
+      (*reductionVarToType)[privateReductionVariables[i]] = reductionType;
 
     // map block argument to initializer region
     mapInitializationArgs(op, moduleTranslation, reductionDecls,
@@ -1330,15 +1413,20 @@ static void collectReductionInfo(
 
   // Collect the reduction information.
   reductionInfos.reserve(numReductions);
+  llvm::DenseMap<llvm::Value *, llvm::Type *> *reductionVarToType =
+      findReductionVarTypes(moduleTranslation);
   for (unsigned i = 0; i < numReductions; ++i) {
     llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen = nullptr;
     if (owningAtomicReductionGens[i])
       atomicGen = owningAtomicReductionGens[i];
     llvm::Value *variable =
         moduleTranslation.lookupValue(loop.getReductionVars()[i]);
+    llvm::Type *reductionType =
+        moduleTranslation.convertType(reductionDecls[i].getType());
+    if (reductionVarToType != nullptr)
+      (*reductionVarToType)[privateReductionVariables[i]] = reductionType;
     reductionInfos.push_back(
-        {moduleTranslation.convertType(reductionDecls[i].getType()), variable,
-         privateReductionVariables[i],
+        {reductionType, variable, privateReductionVariables[i],
          /*EvaluationKind=*/llvm::OpenMPIRBuilder::EvalKind::Scalar,
          owningReductionGens[i],
          /*ReductionGenClang=*/nullptr, atomicGen});
@@ -2543,6 +2631,9 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
   bool isSimd = wsloopOp.getScheduleSimd();
   bool loopNeedsBarrier = !wsloopOp.getNowait();
+  bool isInScanRegion =
+      wsloopOp.getReductionMod() && (wsloopOp.getReductionMod().value() ==
+                                     mlir::omp::ReductionModifier::inscan);
 
   // The only legal way for the direct parent to be omp.distribute is that this
   // represents 'distribute parallel do'. Otherwise, this is a regular
@@ -2574,20 +2665,81 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   if (failed(handleError(regionBlock, opInst)))
     return failure();
 
-  llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
+  SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
+      findCurrentLoopInfos(moduleTranslation);
+
+  const auto &&wsloopCodeGen = [&](llvm::CanonicalLoopInfo *loopInfo,
+                                   bool noLoopMode, bool inputScanLoop) {
+    bool emitLinearVarInit = !isInScanRegion || inputScanLoop;
+    // Emit Initialization and Update IR for linear variables
+    if (emitLinearVarInit && !wsloopOp.getLinearVars().empty()) {
+      llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+          linearClauseProcessor.initLinearVar(builder, moduleTranslation,
+                                              loopInfo->getPreheader());
+      if (failed(handleError(afterBarrierIP, *loopOp)))
+        return failure();
+      builder.restoreIP(*afterBarrierIP);
+      linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
+                                            loopInfo->getIndVar());
+    }
+    bool emitLinearVarFinalize = !isInScanRegion || !inputScanLoop;
+    if (emitLinearVarFinalize)
+      linearClauseProcessor.outlineLinearFinalizationBB(builder,
+                                                        loopInfo->getExit());
+    builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
+    llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
+        ompBuilder->applyWorkshareLoop(
+            ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
+            convertToScheduleKind(schedule), chunk, isSimd,
+            scheduleMod == omp::ScheduleModifier::monotonic,
+            scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
+            workshareLoopType, noLoopMode);
+
+    if (failed(handleError(wsloopIP, opInst)))
+      return failure();
 
-  // Emit Initialization and Update IR for linear variables
-  if (!wsloopOp.getLinearVars().empty()) {
-    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
-        linearClauseProcessor.initLinearVar(builder, moduleTranslation,
-                                            loopInfo->getPreheader());
-    if (failed(handleError(afterBarrierIP, *loopOp)))
+    // Emit finalization and in-place rewrites for linear vars.
+    if (emitLinearVarFinalize && !wsloopOp.getLinearVars().empty()) {
+      llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
+      if (loopInfo->getLastIter())
+        return failure();
+      // assert(loopInfo->getLastIter() &&
+      //        "`lastiter` in CanonicalLoopInfo is nullptr");
+      llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+          linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
+                                                  loopInfo->getLastIter());
+      if (failed(handleError(afterBarrierIP, *loopOp)))
+        return failure();
+      for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
+        linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
+                                             index);
+      builder.restoreIP(oldIP);
+    }
+    if (!inputScanLoop || !isInScanRegion)
+      popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get());
+
+    return llvm::success();
+  };
+
+  if (isInScanRegion) {
+    auto inputLoopFinishIp = loopInfos.front()->getAfterIP();
+    builder.restoreIP(inputLoopFinishIp);
+    SmallVector<OwningReductionGen> owningReductionGens;
+    SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
+    SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
+    collectReductionInfo(wsloopOp, builder, moduleTranslation, reductionDecls,
+                         owningReductionGens, owningAtomicReductionGens,
+                         privateReductionVariables, reductionInfos);
+    llvm::BasicBlock *cont = splitBB(builder, false, "omp.scan.loop.cont");
+    llvm::ScanInfo *scanInfo = findScanInfo(moduleTranslation);
+    llvm::OpenMPIRBuilder::InsertPointOrErrorTy redIP =
+        ompBuilder->emitScanReduction(builder.saveIP(), reductionInfos,
+                                      scanInfo);
+    if (failed(handleError(redIP, opInst)))
       return failure();
-    builder.restoreIP(*afterBarrierIP);
-    linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
-                                          loopInfo->getIndVar());
-    linearClauseProcessor.outlineLinearFinalizationBB(builder,
-                                                      loopInfo->getExit());
+
+    builder.restoreIP(*redIP);
+    builder.CreateBr(cont);
   }
 
   builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
@@ -2612,42 +2764,34 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
     }
   }
 
-  llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
-      ompBuilder->applyWorkshareLoop(
-          ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
-          convertToScheduleKind(schedule), chunk, isSimd,
-          scheduleMod == omp::ScheduleModifier::monotonic,
-          scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
-          workshareLoopType, noLoopMode);
-
-  if (failed(handleError(wsloopIP, opInst)))
-    return failure();
-
-  // Emit finalization and in-place rewrites for linear vars.
-  if (!wsloopOp.getLinearVars().empty()) {
-    llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
-    assert(loopInfo->getLastIter() &&
-           "`lastiter` in CanonicalLoopInfo is nullptr");
-    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
-        linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
-                                                loopInfo->getLastIter());
-    if (failed(handleError(afterBarrierIP, *loopOp)))
+  // For Scan loops input loop need not pop cancellation CB and hence, it is set
+  // false for the first loop
+  bool inputScanLoop = isInScanRegion;
+  for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
+    if (failed(wsloopCodeGen(loopInfo, noLoopMode, inputScanLoop)))
       return failure();
-    for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
-      linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
-                                           index);
-    builder.restoreIP(oldIP);
+    inputScanLoop = false;
   }
 
-  // Set the correct branch target for task cancellation
-  popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get());
-
-  // Process the reductions if required.
-  if (failed(createReductionsAndCleanup(
-          wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
-          privateReductionVariables, isByRef, wsloopOp.getNowait(),
-          /*isTeamsReduction=*/false)))
-    return failure();
+  // todo: change builder.saveIP to wsLoopIP
+  if (isInScanRegion) {
+    SmallVector<Region *> reductionRegions;
+    llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
+                    [](omp::DeclareReductionOp reductionDecl) {
+                      return &reductionDecl.getCleanupRegion();
+                    });
+    if (failed(inlineOmpRegionCleanup(
+            reductionRegions, privateReductionVariables, moduleTranslation,
+            builder, "omp.reduction.cleanup")))
+      return failure();
+  } else {
+    // Process the reductions if required.
+    if (failed(createReductionsAndCleanup(
+            wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
+            privateReductionVariables, i...
[truncated]

@anchuraj anchuraj changed the title [MLIR][LLVMIR] Adding scan lowering to llvm from mlie [MLIR][LLVMIR] Adding scan lowering to llvm from mlir Oct 31, 2025
@joker-eph joker-eph changed the title [MLIR][LLVMIR] Adding scan lowering to llvm from mlir [MLIR][OpenMP] Add scan reduction lowering to llvm Oct 31, 2025
@anchuraj anchuraj force-pushed the scan-llvm-interface branch 5 times, most recently from 27a9e19 to ccb1229 Compare November 4, 2025 17:48
@anchuraj anchuraj marked this pull request as draft November 4, 2025 17:56
@anchuraj anchuraj force-pushed the scan-llvm-interface branch from ccb1229 to fb85ab4 Compare November 4, 2025 18:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants