Skip to content

Conversation

ergawy
Copy link
Member

@ergawy ergawy commented Sep 3, 2025

Extends support for mapping `do concurrent` on the device by adding
support for `local` specifiers. The changes in this PR map the local
variable to the `omp.target` op and uses the mapped value as the
`private` clause operand in the nested `omp.parallel` op.
@ergawy ergawy force-pushed the users/ergawy/upstream_dc_device_5_offload_tests branch from 2fd2022 to d967c72 Compare September 4, 2025 09:39
@ergawy ergawy force-pushed the users/ergawy/upstream_dc_device_6_local_specs_device branch from 78fc5ed to 78e1013 Compare September 4, 2025 09:39
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Sep 4, 2025
@llvmbot
Copy link
Member

llvmbot commented Sep 4, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: Kareem Ergawy (ergawy)

Changes

Extends support for mapping do concurrent on the device by adding support for local specifiers. The changes in this PR map the local variable to the omp.target op and uses the mapped value as the private clause operand in the nested omp.parallel op.


Full diff: https://github.com/llvm/llvm-project/pull/156589.diff

3 Files Affected:

  • (modified) flang/include/flang/Optimizer/Dialect/FIROps.td (+12)
  • (modified) flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp (+114-78)
  • (added) flang/test/Transforms/DoConcurrent/local_device.mlir (+49)
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index bc971e8fd6600..fc6eedc6ed4c6 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -3894,6 +3894,18 @@ def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop",
       return getReduceVars().size();
     }
 
+    unsigned getInductionVarsStart() {
+      return 0;
+    }
+
+    unsigned getLocalOperandsStart() {
+      return getNumInductionVars();
+    }
+
+    unsigned getReduceOperandsStart() {
+      return getLocalOperandsStart() + getNumLocalOperands();
+    }
+
     mlir::Block::BlockArgListType getInductionVars() {
       return getBody()->getArguments().slice(0, getNumInductionVars());
     }
diff --git a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
index a800a20129abe..66b778fecc208 100644
--- a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
+++ b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
@@ -137,6 +137,9 @@ void collectLoopLiveIns(fir::DoConcurrentLoopOp loop,
 
         liveIns.push_back(operand->get());
       });
+
+  for (mlir::Value local : loop.getLocalVars())
+    liveIns.push_back(local);
 }
 
 /// Collects values that are local to a loop: "loop-local values". A loop-local
@@ -251,8 +254,7 @@ class DoConcurrentConversion
               .getIsTargetDevice();
 
       mlir::omp::TargetOperands targetClauseOps;
-      genLoopNestClauseOps(doLoop.getLoc(), rewriter, loop, mapper,
-                           loopNestClauseOps,
+      genLoopNestClauseOps(doLoop.getLoc(), rewriter, loop, loopNestClauseOps,
                            isTargetDevice ? nullptr : &targetClauseOps);
 
       LiveInShapeInfoMap liveInShapeInfoMap;
@@ -274,14 +276,13 @@ class DoConcurrentConversion
     }
 
     mlir::omp::ParallelOp parallelOp =
-        genParallelOp(doLoop.getLoc(), rewriter, ivInfos, mapper);
+        genParallelOp(rewriter, loop, ivInfos, mapper);
 
     // Only set as composite when part of `distribute parallel do`.
     parallelOp.setComposite(mapToDevice);
 
     if (!mapToDevice)
-      genLoopNestClauseOps(doLoop.getLoc(), rewriter, loop, mapper,
-                           loopNestClauseOps);
+      genLoopNestClauseOps(doLoop.getLoc(), rewriter, loop, loopNestClauseOps);
 
     for (mlir::Value local : locals)
       looputils::localizeLoopLocalValue(local, parallelOp.getRegion(),
@@ -290,10 +291,38 @@ class DoConcurrentConversion
     if (mapToDevice)
       genDistributeOp(doLoop.getLoc(), rewriter).setComposite(/*val=*/true);
 
-    mlir::omp::LoopNestOp ompLoopNest =
+    auto [loopNestOp, wsLoopOp] =
         genWsLoopOp(rewriter, loop, mapper, loopNestClauseOps,
                     /*isComposite=*/mapToDevice);
 
+    // `local` region arguments are transferred/cloned from the `do concurrent`
+    // loop to the loopnest op when the region is cloned above. Instead, these
+    // region arguments should be on the workshare loop's region.
+    if (mapToDevice) {
+      for (auto [parallelArg, loopNestArg] : llvm::zip_equal(
+               parallelOp.getRegion().getArguments(),
+               loopNestOp.getRegion().getArguments().slice(
+                   loop.getLocalOperandsStart(), loop.getNumLocalOperands())))
+        rewriter.replaceAllUsesWith(loopNestArg, parallelArg);
+
+      for (auto [wsloopArg, loopNestArg] : llvm::zip_equal(
+               wsLoopOp.getRegion().getArguments(),
+               loopNestOp.getRegion().getArguments().slice(
+                   loop.getReduceOperandsStart(), loop.getNumReduceOperands())))
+        rewriter.replaceAllUsesWith(loopNestArg, wsloopArg);
+    } else {
+      for (auto [wsloopArg, loopNestArg] :
+           llvm::zip_equal(wsLoopOp.getRegion().getArguments(),
+                           loopNestOp.getRegion().getArguments().drop_front(
+                               loopNestClauseOps.loopLowerBounds.size())))
+        rewriter.replaceAllUsesWith(loopNestArg, wsloopArg);
+    }
+
+    for (unsigned i = 0;
+         i < loop.getLocalVars().size() + loop.getReduceVars().size(); ++i)
+      loopNestOp.getRegion().eraseArgument(
+          loopNestClauseOps.loopLowerBounds.size());
+
     rewriter.setInsertionPoint(doLoop);
     fir::FirOpBuilder builder(
         rewriter,
@@ -314,7 +343,7 @@ class DoConcurrentConversion
     // Mark `unordered` loops that are not perfectly nested to be skipped from
     // the legality check of the `ConversionTarget` since we are not interested
     // in mapping them to OpenMP.
-    ompLoopNest->walk([&](fir::DoConcurrentOp doLoop) {
+    loopNestOp->walk([&](fir::DoConcurrentOp doLoop) {
       concurrentLoopsToSkip.insert(doLoop);
     });
 
@@ -370,11 +399,21 @@ class DoConcurrentConversion
       llvm::DenseMap<mlir::Value, TargetDeclareShapeCreationInfo>;
 
   mlir::omp::ParallelOp
-  genParallelOp(mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
+  genParallelOp(mlir::ConversionPatternRewriter &rewriter,
+                fir::DoConcurrentLoopOp loop,
                 looputils::InductionVariableInfos &ivInfos,
                 mlir::IRMapping &mapper) const {
-    auto parallelOp = mlir::omp::ParallelOp::create(rewriter, loc);
-    rewriter.createBlock(&parallelOp.getRegion());
+    mlir::omp::ParallelOperands parallelOps;
+
+    if (mapToDevice)
+      genPrivatizers(rewriter, mapper, loop, parallelOps);
+
+    mlir::Location loc = loop.getLoc();
+    auto parallelOp = mlir::omp::ParallelOp::create(rewriter, loc, parallelOps);
+    Fortran::common::openmp::EntryBlockArgs parallelArgs;
+    parallelArgs.priv.vars = parallelOps.privateVars;
+    Fortran::common::openmp::genEntryBlock(rewriter, parallelArgs,
+                                           parallelOp.getRegion());
     rewriter.setInsertionPoint(mlir::omp::TerminatorOp::create(rewriter, loc));
 
     genLoopNestIndVarAllocs(rewriter, ivInfos, mapper);
@@ -411,7 +450,7 @@ class DoConcurrentConversion
 
   void genLoopNestClauseOps(
       mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
-      fir::DoConcurrentLoopOp loop, mlir::IRMapping &mapper,
+      fir::DoConcurrentLoopOp loop,
       mlir::omp::LoopNestOperands &loopNestClauseOps,
       mlir::omp::TargetOperands *targetClauseOps = nullptr) const {
     assert(loopNestClauseOps.loopLowerBounds.empty() &&
@@ -440,59 +479,14 @@ class DoConcurrentConversion
     loopNestClauseOps.loopInclusive = rewriter.getUnitAttr();
   }
 
-  mlir::omp::LoopNestOp
+  std::pair<mlir::omp::LoopNestOp, mlir::omp::WsloopOp>
   genWsLoopOp(mlir::ConversionPatternRewriter &rewriter,
               fir::DoConcurrentLoopOp loop, mlir::IRMapping &mapper,
               const mlir::omp::LoopNestOperands &clauseOps,
               bool isComposite) const {
     mlir::omp::WsloopOperands wsloopClauseOps;
-
-    auto cloneFIRRegionToOMP = [&rewriter](mlir::Region &firRegion,
-                                           mlir::Region &ompRegion) {
-      if (!firRegion.empty()) {
-        rewriter.cloneRegionBefore(firRegion, ompRegion, ompRegion.begin());
-        auto firYield =
-            mlir::cast<fir::YieldOp>(ompRegion.back().getTerminator());
-        rewriter.setInsertionPoint(firYield);
-        mlir::omp::YieldOp::create(rewriter, firYield.getLoc(),
-                                   firYield.getOperands());
-        rewriter.eraseOp(firYield);
-      }
-    };
-
-    // For `local` (and `local_init`) opernads, emit corresponding `private`
-    // clauses and attach these clauses to the workshare loop.
-    if (!loop.getLocalVars().empty())
-      for (auto [op, sym, arg] : llvm::zip_equal(
-               loop.getLocalVars(),
-               loop.getLocalSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
-               loop.getRegionLocalArgs())) {
-        auto localizer = moduleSymbolTable.lookup<fir::LocalitySpecifierOp>(
-            sym.getLeafReference());
-        if (localizer.getLocalitySpecifierType() ==
-            fir::LocalitySpecifierType::LocalInit)
-          TODO(localizer.getLoc(),
-               "local_init conversion is not supported yet");
-
-        mlir::OpBuilder::InsertionGuard guard(rewriter);
-        rewriter.setInsertionPointAfter(localizer);
-
-        auto privatizer = mlir::omp::PrivateClauseOp::create(
-            rewriter, localizer.getLoc(), sym.getLeafReference().str() + ".omp",
-            localizer.getTypeAttr().getValue(),
-            mlir::omp::DataSharingClauseType::Private);
-
-        cloneFIRRegionToOMP(localizer.getInitRegion(),
-                            privatizer.getInitRegion());
-        cloneFIRRegionToOMP(localizer.getDeallocRegion(),
-                            privatizer.getDeallocRegion());
-
-        moduleSymbolTable.insert(privatizer);
-
-        wsloopClauseOps.privateVars.push_back(op);
-        wsloopClauseOps.privateSyms.push_back(
-            mlir::SymbolRefAttr::get(privatizer));
-      }
+    if (!mapToDevice)
+      genPrivatizers(rewriter, mapper, loop, wsloopClauseOps);
 
     if (!loop.getReduceVars().empty()) {
       for (auto [op, byRef, sym, arg] : llvm::zip_equal(
@@ -515,15 +509,15 @@ class DoConcurrentConversion
               rewriter, firReducer.getLoc(), ompReducerName,
               firReducer.getTypeAttr().getValue());
 
-          cloneFIRRegionToOMP(firReducer.getAllocRegion(),
+          cloneFIRRegionToOMP(rewriter, firReducer.getAllocRegion(),
                               ompReducer.getAllocRegion());
-          cloneFIRRegionToOMP(firReducer.getInitializerRegion(),
+          cloneFIRRegionToOMP(rewriter, firReducer.getInitializerRegion(),
                               ompReducer.getInitializerRegion());
-          cloneFIRRegionToOMP(firReducer.getReductionRegion(),
+          cloneFIRRegionToOMP(rewriter, firReducer.getReductionRegion(),
                               ompReducer.getReductionRegion());
-          cloneFIRRegionToOMP(firReducer.getAtomicReductionRegion(),
+          cloneFIRRegionToOMP(rewriter, firReducer.getAtomicReductionRegion(),
                               ompReducer.getAtomicReductionRegion());
-          cloneFIRRegionToOMP(firReducer.getCleanupRegion(),
+          cloneFIRRegionToOMP(rewriter, firReducer.getCleanupRegion(),
                               ompReducer.getCleanupRegion());
           moduleSymbolTable.insert(ompReducer);
         }
@@ -555,21 +549,10 @@ class DoConcurrentConversion
 
     rewriter.setInsertionPointToEnd(&loopNestOp.getRegion().back());
     mlir::omp::YieldOp::create(rewriter, loop->getLoc());
+    loop->getParentOfType<mlir::ModuleOp>().print(
+        llvm::errs(), mlir::OpPrintingFlags().assumeVerified());
 
-    // `local` region arguments are transferred/cloned from the `do concurrent`
-    // loop to the loopnest op when the region is cloned above. Instead, these
-    // region arguments should be on the workshare loop's region.
-    for (auto [wsloopArg, loopNestArg] :
-         llvm::zip_equal(wsloopOp.getRegion().getArguments(),
-                         loopNestOp.getRegion().getArguments().drop_front(
-                             clauseOps.loopLowerBounds.size())))
-      rewriter.replaceAllUsesWith(loopNestArg, wsloopArg);
-
-    for (unsigned i = 0;
-         i < loop.getLocalVars().size() + loop.getReduceVars().size(); ++i)
-      loopNestOp.getRegion().eraseArgument(clauseOps.loopLowerBounds.size());
-
-    return loopNestOp;
+    return {loopNestOp, wsloopOp};
   }
 
   void genBoundsOps(fir::FirOpBuilder &builder, mlir::Value liveIn,
@@ -810,6 +793,59 @@ class DoConcurrentConversion
     return distOp;
   }
 
+  void cloneFIRRegionToOMP(mlir::ConversionPatternRewriter &rewriter,
+                           mlir::Region &firRegion,
+                           mlir::Region &ompRegion) const {
+    if (!firRegion.empty()) {
+      rewriter.cloneRegionBefore(firRegion, ompRegion, ompRegion.begin());
+      auto firYield =
+          mlir::cast<fir::YieldOp>(ompRegion.back().getTerminator());
+      rewriter.setInsertionPoint(firYield);
+      mlir::omp::YieldOp::create(rewriter, firYield.getLoc(),
+                                 firYield.getOperands());
+      rewriter.eraseOp(firYield);
+    }
+  }
+
+  void genPrivatizers(mlir::ConversionPatternRewriter &rewriter,
+                      mlir::IRMapping &mapper, fir::DoConcurrentLoopOp loop,
+                      mlir::omp::PrivateClauseOps &privateClauseOps) const {
+    // For `local` (and `local_init`) operands, emit corresponding `private`
+    // clauses and attach these clauses to the workshare loop.
+    if (!loop.getLocalVars().empty())
+      for (auto [var, sym, arg] : llvm::zip_equal(
+               loop.getLocalVars(),
+               loop.getLocalSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
+               loop.getRegionLocalArgs())) {
+        auto localizer = moduleSymbolTable.lookup<fir::LocalitySpecifierOp>(
+            sym.getLeafReference());
+        if (localizer.getLocalitySpecifierType() ==
+            fir::LocalitySpecifierType::LocalInit)
+          TODO(localizer.getLoc(),
+               "local_init conversion is not supported yet");
+
+        mlir::OpBuilder::InsertionGuard guard(rewriter);
+        rewriter.setInsertionPointAfter(localizer);
+
+        auto privatizer = mlir::omp::PrivateClauseOp::create(
+            rewriter, localizer.getLoc(), sym.getLeafReference().str() + ".omp",
+            localizer.getTypeAttr().getValue(),
+            mlir::omp::DataSharingClauseType::Private);
+
+        cloneFIRRegionToOMP(rewriter, localizer.getInitRegion(),
+                            privatizer.getInitRegion());
+        cloneFIRRegionToOMP(rewriter, localizer.getDeallocRegion(),
+                            privatizer.getDeallocRegion());
+
+        moduleSymbolTable.insert(privatizer);
+
+        privateClauseOps.privateVars.push_back(mapToDevice ? mapper.lookup(var)
+                                                           : var);
+        privateClauseOps.privateSyms.push_back(
+            mlir::SymbolRefAttr::get(privatizer));
+      }
+  }
+
   bool mapToDevice;
   llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip;
   mlir::SymbolTable &moduleSymbolTable;
diff --git a/flang/test/Transforms/DoConcurrent/local_device.mlir b/flang/test/Transforms/DoConcurrent/local_device.mlir
new file mode 100644
index 0000000000000..e54bb1aeb414e
--- /dev/null
+++ b/flang/test/Transforms/DoConcurrent/local_device.mlir
@@ -0,0 +1,49 @@
+// RUN: fir-opt --omp-do-concurrent-conversion="map-to=device" %s -o - | FileCheck %s
+
+fir.local {type = local} @_QFfooEmy_local_private_f32 : f32
+
+func.func @_QPfoo() {
+  %0 = fir.dummy_scope : !fir.dscope
+  %3 = fir.alloca f32 {bindc_name = "my_local", uniq_name = "_QFfooEmy_local"}
+  %4:2 = hlfir.declare %3 {uniq_name = "_QFfooEmy_local"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+
+  fir.do_concurrent {
+    %7 = fir.alloca i32 {bindc_name = "i"}
+    %8:2 = hlfir.declare %7 {uniq_name = "_QFfooEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+
+    fir.do_concurrent.loop (%arg0) = (%c1) to (%c10) step (%c1) local(@_QFfooEmy_local_private_f32 %4#0 -> %arg1 : !fir.ref<f32>) {
+      %9 = fir.convert %arg0 : (index) -> i32
+      fir.store %9 to %8#0 : !fir.ref<i32>
+      %10:2 = hlfir.declare %arg1 {uniq_name = "_QFfooEmy_local"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+      %cst = arith.constant 4.200000e+01 : f32
+      hlfir.assign %cst to %10#0 : f32, !fir.ref<f32>
+    }
+  }
+  return
+}
+
+// CHECK: omp.private {type = private} @[[OMP_PRIVATIZER:.*.omp]] : f32
+
+// CHECK: %[[LOCAL_DECL:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "{{.*}}my_local"}
+// CHECK: %[[LOCAL_MAP:.*]] = omp.map.info var_ptr(%[[LOCAL_DECL]]#1 : {{.*}})
+
+// CHECK: omp.target host_eval({{.*}}) map_entries({{.*}}, %[[LOCAL_MAP]] -> %[[LOCAL_MAP_ARG:.*]] : {{.*}}) {
+// CHECK:   %[[LOCAL_DEV_DECL:.*]]:2 = hlfir.declare %[[LOCAL_MAP_ARG]] {uniq_name = "_QFfooEmy_local"}
+
+// CHECK:   omp.teams {
+// CHECK:     omp.parallel private(@[[OMP_PRIVATIZER]] %[[LOCAL_DEV_DECL]]#0 -> %[[LOCAL_PRIV_ARG:.*]] : {{.*}}) {
+// CHECK:       omp.distribute {
+// CHECK:         omp.wsloop {
+// CHECK:           omp.loop_nest {{.*}} {
+// CHECK:             %[[LOCAL_LOOP_DECL:.*]]:2 = hlfir.declare %[[LOCAL_PRIV_ARG]] {uniq_name = "_QFfooEmy_local"}
+// CHECK:             hlfir.assign %{{.*}} to %[[LOCAL_LOOP_DECL]]#0
+// CHECK:             omp.yield
+// CHECK:           }
+// CHECK:         }
+// CHECK:       }
+// CHECK:     }
+// CHECK:   }
+// CHECK: }

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants