[MLIR][OpenMP] Add scan reduction lowering to llvm #165788
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-flang-openmp

Author: Anchu Rajendran S (anchuraj)

Changes

Scan reductions are supported in OpenMP with the help of the `scan` directive. The reduction clause of the worksharing-loop (`do`/`for`) or `simd` directive takes an `inscan` modifier when a scan reduction is specified.

Patch is 37.21 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/165788.diff

4 Files Affected:
 diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index f86ee01355104..5d82466889b1e 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2326,12 +2326,41 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
 
 static mlir::omp::ScanOp
 genScanOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
-          semantics::SemanticsContext &semaCtx, mlir::Location loc,
-          const ConstructQueue &queue, ConstructQueue::const_iterator item) {
+          semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+          mlir::Location loc, const ConstructQueue &queue,
+          ConstructQueue::const_iterator item) {
   mlir::omp::ScanOperands clauseOps;
   genScanClauses(converter, semaCtx, item->clauses, loc, clauseOps);
-  return mlir::omp::ScanOp::create(converter.getFirOpBuilder(),
-                                   converter.getCurrentLocation(), clauseOps);
+  mlir::omp::ScanOp scanOp = mlir::omp::ScanOp::create(
+      converter.getFirOpBuilder(), converter.getCurrentLocation(), clauseOps);
+  // If there are nested loops all indices should be loaded after
+  // the scan construct as otherwise, it would result in using the index
+  // variable across scan directive.
+  // (`Intra-iteration dependences from a statement in the structured
+  // block sequence that precede a scan directive to a statement in the
+  // structured block sequence that follows a scan directive must not exist,
+  // except for dependences for the list items specified in an inclusive or
+  // exclusive clause.`).
+  // TODO: If there are nested loops, it is not handled.
+  mlir::omp::LoopNestOp loopNestOp =
+      scanOp->getParentOfType<mlir::omp::LoopNestOp>();
+  assert(loopNestOp.getNumLoops() == 1 &&
+         "Scan directive inside nested do loops is not handled yet.");
+  mlir::Region &region = loopNestOp->getRegion(0);
+  mlir::Value indexVal = fir::getBase(region.getArgument(0));
+  lower::pft::Evaluation *doConstructEval = eval.parentConstruct;
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  lower::pft::Evaluation *doLoop = &doConstructEval->getFirstNestedEvaluation();
+  auto *doStmt = doLoop->getIf<parser::NonLabelDoStmt>();
+  assert(doStmt && "Expected do loop to be in the nested evaluation");
+  const auto &loopControl =
+      std::get<std::optional<parser::LoopControl>>(doStmt->t);
+  const parser::LoopControl::Bounds *bounds =
+      std::get_if<parser::LoopControl::Bounds>(&loopControl->u);
+  mlir::Operation *storeOp =
+      setLoopVar(converter, loc, indexVal, bounds->name.thing.symbol);
+  firOpBuilder.setInsertionPointAfter(storeOp);
+  return scanOp;
 }
 
 static mlir::omp::SectionsOp
@@ -3416,7 +3445,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
                                   loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_scan:
-    newOp = genScanOp(converter, symTable, semaCtx, loc, queue, item);
+    newOp = genScanOp(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_section:
     llvm_unreachable("genOMPDispatch: OMPD_section");
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 1e2099d6cc1b2..9db269c4f8756 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -37,6 +37,7 @@
 #include "llvm/TargetParser/Triple.h"
 #include "llvm/Transforms/Utils/ModuleUtils.h"
 
+#include <cassert>
 #include <cstdint>
 #include <iterator>
 #include <numeric>
@@ -77,6 +78,22 @@ class OpenMPAllocaStackFrame
   llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
 };
 
+/// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
+/// alloca insertion point of the parent of the current parallel region. The
+/// insertion point is used to allocate variables to be shared by the threads
+/// executing the parallel region. Lowering of scan reductions requires declaring
+/// shared pointers to the temporary buffer used to perform the scan reduction.
+class OpenMPParallelAllocaStackFrame
+    : public StateStackFrameBase<OpenMPParallelAllocaStackFrame> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPParallelAllocaStackFrame)
+
+  explicit OpenMPParallelAllocaStackFrame(
+      llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
+      : allocaInsertPoint(allocaIP) {}
+  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
+};
+
 /// Stack frame to hold a \see llvm::CanonicalLoopInfo representing the
 /// collapsed canonical loop information corresponding to an \c omp.loop_nest
 /// operation.
@@ -84,7 +101,13 @@ class OpenMPLoopInfoStackFrame
     : public StateStackFrameBase<OpenMPLoopInfoStackFrame> {
 public:
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
-  llvm::CanonicalLoopInfo *loopInfo = nullptr;
+  /// For constructs like scan, one LoopInfo frame can contain multiple
+  /// Canonical Loops as a single openmpLoopNestOp will be split into input
+  /// loop and scan loop.
+  SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
+  llvm::ScanInfo *scanInfo;
+  llvm::DenseMap<llvm::Value *, llvm::Type *> *reductionVarToType =
+      new llvm::DenseMap<llvm::Value *, llvm::Type *>();
 };
 
 /// Custom error class to signal translation errors that don't need reporting,
@@ -323,6 +346,10 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     if (op.getDistScheduleChunkSize())
       result = todo("dist_schedule with chunk_size");
   };
+  auto checkExclusive = [&todo](auto op, LogicalResult &result) {
+    if (!op.getExclusiveVars().empty())
+      result = todo("exclusive");
+  };
   auto checkHint = [](auto op, LogicalResult &) {
     if (op.getHint())
       op.emitWarning("hint clause discarded");
@@ -371,9 +398,14 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       if (!op.getReductionVars().empty() || op.getReductionByref() ||
           op.getReductionSyms())
         result = todo("reduction");
-    if (op.getReductionMod() &&
-        op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
-      result = todo("reduction with modifier");
+    if (op.getReductionMod()) {
+      if (isa<omp::WsloopOp>(op)) {
+        if (op.getReductionMod().value() == omp::ReductionModifier::task)
+          result = todo("reduction with task modifier");
+      } else {
+        result = todo("reduction with modifier");
+      }
+    }
   };
   auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
     if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
@@ -397,6 +429,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
         checkOrder(op, result);
       })
       .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
+      .Case([&](omp::ScanOp op) { checkExclusive(op, result); })
       .Case([&](omp::SectionsOp op) {
         checkAllocate(op, result);
         checkPrivate(op, result);
@@ -531,15 +564,59 @@ findAllocaInsertPoint(llvm::IRBuilderBase &builder,
 /// Find the loop information structure for the loop nest being translated. It
 /// will return a `null` value unless called from the translation function for
 /// a loop wrapper operation after successfully translating its body.
-static llvm::CanonicalLoopInfo *
-findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation) {
-  llvm::CanonicalLoopInfo *loopInfo = nullptr;
+static SmallVector<llvm::CanonicalLoopInfo *>
+findCurrentLoopInfos(LLVM::ModuleTranslation &moduleTranslation) {
+  SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
+  moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
+      [&](OpenMPLoopInfoStackFrame &frame) {
+        loopInfos = frame.loopInfos;
+        return WalkResult::interrupt();
+      });
+  return loopInfos;
+}
+
+// LoopFrame stores the scaninfo which is used for scan reduction.
+// Upon encountering an `inscan` reduction modifier, `scanInfoInitialize`
+// initializes the ScanInfo and is used when scan directive is encountered
+// in the body of the loop nest.
+static llvm::ScanInfo *
+findScanInfo(LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::ScanInfo *scanInfo;
+  moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
+      [&](OpenMPLoopInfoStackFrame &frame) {
+        scanInfo = frame.scanInfo;
+        return WalkResult::interrupt();
+      });
+  return scanInfo;
+}
+
+// The types of reduction vars are used for lowering scan directive which
+// appears in the body of the loop. The types are stored in loop frame when
+// reduction clause is encountered and is used when scan directive is
+// encountered.
+static llvm::DenseMap<llvm::Value *, llvm::Type *> *
+findReductionVarTypes(LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::DenseMap<llvm::Value *, llvm::Type *> *reductionVarToType = nullptr;
   moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
       [&](OpenMPLoopInfoStackFrame &frame) {
-        loopInfo = frame.loopInfo;
+        reductionVarToType = frame.reductionVarToType;
         return WalkResult::interrupt();
       });
-  return loopInfo;
+  return reductionVarToType;
+}
+
+// Scan reduction requires a shared buffer to be allocated to perform reduction.
+// ParallelAllocaStackFrame holds the allocaIP where shared allocation can be
+// done.
+static llvm::OpenMPIRBuilder::InsertPointTy
+findParallelAllocaIP(LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::OpenMPIRBuilder::InsertPointTy parallelAllocaIP;
+  moduleTranslation.stackWalk<OpenMPParallelAllocaStackFrame>(
+      [&](OpenMPParallelAllocaStackFrame &frame) {
+        parallelAllocaIP = frame.allocaInsertPoint;
+        return WalkResult::interrupt();
+      });
+  return parallelAllocaIP;
 }
 
 /// Converts the given region that appears within an OpenMP dialect operation to
@@ -1254,11 +1331,17 @@ initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
   for (auto [data, addr] : deferredStores)
     builder.CreateStore(data, addr);
 
+  llvm::DenseMap<llvm::Value *, llvm::Type *> *reductionVarToType =
+      findReductionVarTypes(moduleTranslation);
   // Before the loop, store the initial values of reductions into reduction
   // variables. Although this could be done after allocas, we don't want to mess
   // up with the alloca insertion point.
   for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
     SmallVector<llvm::Value *, 1> phis;
+    llvm::Type *reductionType =
+        moduleTranslation.convertType(reductionDecls[i].getType());
+    if (reductionVarToType != nullptr)
+      (*reductionVarToType)[privateReductionVariables[i]] = reductionType;
 
     // map block argument to initializer region
     mapInitializationArgs(op, moduleTranslation, reductionDecls,
@@ -1330,15 +1413,20 @@ static void collectReductionInfo(
 
   // Collect the reduction information.
   reductionInfos.reserve(numReductions);
+  llvm::DenseMap<llvm::Value *, llvm::Type *> *reductionVarToType =
+      findReductionVarTypes(moduleTranslation);
   for (unsigned i = 0; i < numReductions; ++i) {
     llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen = nullptr;
     if (owningAtomicReductionGens[i])
       atomicGen = owningAtomicReductionGens[i];
     llvm::Value *variable =
         moduleTranslation.lookupValue(loop.getReductionVars()[i]);
+    llvm::Type *reductionType =
+        moduleTranslation.convertType(reductionDecls[i].getType());
+    if (reductionVarToType != nullptr)
+      (*reductionVarToType)[privateReductionVariables[i]] = reductionType;
     reductionInfos.push_back(
-        {moduleTranslation.convertType(reductionDecls[i].getType()), variable,
-         privateReductionVariables[i],
+        {reductionType, variable, privateReductionVariables[i],
          /*EvaluationKind=*/llvm::OpenMPIRBuilder::EvalKind::Scalar,
          owningReductionGens[i],
          /*ReductionGenClang=*/nullptr, atomicGen});
@@ -2543,6 +2631,9 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
   bool isSimd = wsloopOp.getScheduleSimd();
   bool loopNeedsBarrier = !wsloopOp.getNowait();
+  bool isInScanRegion =
+      wsloopOp.getReductionMod() && (wsloopOp.getReductionMod().value() ==
+                                     mlir::omp::ReductionModifier::inscan);
 
   // The only legal way for the direct parent to be omp.distribute is that this
   // represents 'distribute parallel do'. Otherwise, this is a regular
@@ -2574,20 +2665,81 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   if (failed(handleError(regionBlock, opInst)))
     return failure();
 
-  llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
+  SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
+      findCurrentLoopInfos(moduleTranslation);
+
+  const auto &&wsloopCodeGen = [&](llvm::CanonicalLoopInfo *loopInfo,
+                                   bool noLoopMode, bool inputScanLoop) {
+    bool emitLinearVarInit = !isInScanRegion || inputScanLoop;
+    // Emit Initialization and Update IR for linear variables
+    if (emitLinearVarInit && !wsloopOp.getLinearVars().empty()) {
+      llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+          linearClauseProcessor.initLinearVar(builder, moduleTranslation,
+                                              loopInfo->getPreheader());
+      if (failed(handleError(afterBarrierIP, *loopOp)))
+        return failure();
+      builder.restoreIP(*afterBarrierIP);
+      linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
+                                            loopInfo->getIndVar());
+    }
+    bool emitLinearVarFinalize = !isInScanRegion || !inputScanLoop;
+    if (emitLinearVarFinalize)
+      linearClauseProcessor.outlineLinearFinalizationBB(builder,
+                                                        loopInfo->getExit());
+    builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
+    llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
+        ompBuilder->applyWorkshareLoop(
+            ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
+            convertToScheduleKind(schedule), chunk, isSimd,
+            scheduleMod == omp::ScheduleModifier::monotonic,
+            scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
+            workshareLoopType, noLoopMode);
+
+    if (failed(handleError(wsloopIP, opInst)))
+      return failure();
 
-  // Emit Initialization and Update IR for linear variables
-  if (!wsloopOp.getLinearVars().empty()) {
-    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
-        linearClauseProcessor.initLinearVar(builder, moduleTranslation,
-                                            loopInfo->getPreheader());
-    if (failed(handleError(afterBarrierIP, *loopOp)))
+    // Emit finalization and in-place rewrites for linear vars.
+    if (emitLinearVarFinalize && !wsloopOp.getLinearVars().empty()) {
+      llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
+      if (loopInfo->getLastIter())
+        return failure();
+      // assert(loopInfo->getLastIter() &&
+      //        "`lastiter` in CanonicalLoopInfo is nullptr");
+      llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+          linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
+                                                  loopInfo->getLastIter());
+      if (failed(handleError(afterBarrierIP, *loopOp)))
+        return failure();
+      for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
+        linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
+                                             index);
+      builder.restoreIP(oldIP);
+    }
+    if (!inputScanLoop || !isInScanRegion)
+      popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get());
+
+    return llvm::success();
+  };
+
+  if (isInScanRegion) {
+    auto inputLoopFinishIp = loopInfos.front()->getAfterIP();
+    builder.restoreIP(inputLoopFinishIp);
+    SmallVector<OwningReductionGen> owningReductionGens;
+    SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
+    SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
+    collectReductionInfo(wsloopOp, builder, moduleTranslation, reductionDecls,
+                         owningReductionGens, owningAtomicReductionGens,
+                         privateReductionVariables, reductionInfos);
+    llvm::BasicBlock *cont = splitBB(builder, false, "omp.scan.loop.cont");
+    llvm::ScanInfo *scanInfo = findScanInfo(moduleTranslation);
+    llvm::OpenMPIRBuilder::InsertPointOrErrorTy redIP =
+        ompBuilder->emitScanReduction(builder.saveIP(), reductionInfos,
+                                      scanInfo);
+    if (failed(handleError(redIP, opInst)))
       return failure();
-    builder.restoreIP(*afterBarrierIP);
-    linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
-                                          loopInfo->getIndVar());
-    linearClauseProcessor.outlineLinearFinalizationBB(builder,
-                                                      loopInfo->getExit());
+
+    builder.restoreIP(*redIP);
+    builder.CreateBr(cont);
   }
 
   builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
@@ -2612,42 +2764,34 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
     }
   }
 
-  llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
-      ompBuilder->applyWorkshareLoop(
-          ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
-          convertToScheduleKind(schedule), chunk, isSimd,
-          scheduleMod == omp::ScheduleModifier::monotonic,
-          scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
-          workshareLoopType, noLoopMode);
-
-  if (failed(handleError(wsloopIP, opInst)))
-    return failure();
-
-  // Emit finalization and in-place rewrites for linear vars.
-  if (!wsloopOp.getLinearVars().empty()) {
-    llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
-    assert(loopInfo->getLastIter() &&
-           "`lastiter` in CanonicalLoopInfo is nullptr");
-    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
-        linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
-                                                loopInfo->getLastIter());
-    if (failed(handleError(afterBarrierIP, *loopOp)))
+  // For Scan loops input loop need not pop cancellation CB and hence, it is set
+  // false for the first loop
+  bool inputScanLoop = isInScanRegion;
+  for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
+    if (failed(wsloopCodeGen(loopInfo, noLoopMode, inputScanLoop)))
       return failure();
-    for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
-      linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
-                                           index);
-    builder.restoreIP(oldIP);
+    inputScanLoop = false;
   }
 
-  // Set the correct branch target for task cancellation
-  popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get());
-
-  // Process the reductions if required.
-  if (failed(createReductionsAndCleanup(
-          wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
-          privateReductionVariables, isByRef, wsloopOp.getNowait(),
-          /*isTeamsReduction=*/false)))
-    return failure();
+  // todo: change builder.saveIP to wsLoopIP
+  if (isInScanRegion) {
+    SmallVector<Region *> reductionRegions;
+    llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
+                    [](omp::DeclareReductionOp reductionDecl) {
+                      return &reductionDecl.getCleanupRegion();
+                    });
+    if (failed(inlineOmpRegionCleanup(
+            reductionRegions, privateReductionVariables, moduleTranslation,
+            builder, "omp.reduction.cleanup")))
+      return failure();
+  } else {
+    // Process the reductions if required.
+    if (failed(createReductionsAndCleanup(
+            wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
+            privateReductionVariables, i...
[truncated]
    
Scan reductions are supported in OpenMP with the help of the `scan` directive. The reduction clause of the worksharing-loop/`simd` directive takes an `inscan` modifier when a scan reduction is specified. With an `inscan` modifier, the body of the directive must contain a `scan` directive. This PR implements the lowering logic for scan reductions in OpenMP worksharing loops. OpenMPIRBuilder support can be found in #136035. Support for nested loops and the `exclusive` clause is not included in this PR. Nested loops, the `linear` clause, `collapse`, and tiling are not yet enabled for loops with scan reductions.
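
For context, below is a minimal Fortran sketch of the source pattern this lowering targets (illustrative only; the subroutine and variable names are made up and not taken from this PR's tests): an `inscan` reduction modifier on a worksharing loop, with a `scan inclusive` directive splitting the loop body into an input phase and a scan phase.

```fortran
! Illustrative inclusive-scan (prefix sum) example.
! The `inscan` modifier on the reduction clause requires a `scan`
! directive inside the loop body.
subroutine prefix_sum(a, b, n)
  integer, intent(in)  :: n
  integer, intent(in)  :: a(n)
  integer, intent(out) :: b(n)
  integer :: i, x
  x = 0
  !$omp parallel do reduction(inscan, +:x)
  do i = 1, n
    x = x + a(i)            ! input phase: update the list item before the scan
    !$omp scan inclusive(x)
    b(i) = x                ! scan phase: read the scanned value after the scan
  end do
  !$omp end parallel do
end subroutine prefix_sum
```

Per the limitations above, only the `inclusive` form is handled by this patch; the `exclusive` clause, nested loops, `collapse`, tiling, and the `linear` clause remain TODOs.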