Skip to content

Commit acbc73a

Browse files
committed
[flang][OpenMP] Extend do concurrent mapping to device
Applies upstream PR llvm#155987 to avoid annoying merge conflicts later on.
1 parent 04ed28e commit acbc73a

File tree

2 files changed

+98
-59
lines changed

2 files changed

+98
-59
lines changed

flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp

Lines changed: 58 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -6,32 +6,24 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#include "flang/Optimizer/Builder/BoxValue.h"
109
#include "flang/Optimizer/Builder/DirectivesCommon.h"
1110
#include "flang/Optimizer/Builder/FIRBuilder.h"
1211
#include "flang/Optimizer/Builder/HLFIRTools.h"
1312
#include "flang/Optimizer/Builder/Todo.h"
1413
#include "flang/Optimizer/Dialect/FIROps.h"
15-
#include "flang/Optimizer/Dialect/FIRType.h"
16-
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
1714
#include "flang/Optimizer/HLFIR/HLFIROps.h"
1815
#include "flang/Optimizer/OpenMP/Passes.h"
1916
#include "flang/Optimizer/OpenMP/Utils.h"
2017
#include "flang/Support/OpenMP-utils.h"
2118
#include "flang/Utils/OpenMP.h"
2219
#include "mlir/Analysis/SliceAnalysis.h"
23-
#include "mlir/Dialect/Func/IR/FuncOps.h"
2420
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
25-
#include "mlir/IR/Diagnostics.h"
2621
#include "mlir/IR/IRMapping.h"
27-
#include "mlir/Pass/Pass.h"
2822
#include "mlir/Transforms/DialectConversion.h"
2923
#include "mlir/Transforms/RegionUtils.h"
24+
#include "llvm/ADT/SmallPtrSet.h"
3025
#include "llvm/Frontend/OpenMP/OMPConstants.h"
3126

32-
#include <memory>
33-
#include <utility>
34-
3527
namespace flangomp {
3628
#define GEN_PASS_DEF_DOCONCURRENTCONVERSIONPASS
3729
#include "flang/Optimizer/OpenMP/Passes.h.inc"
@@ -49,7 +41,6 @@ struct InductionVariableInfo {
4941
mlir::Value inductionVar) {
5042
populateInfo(loop, inductionVar);
5143
}
52-
5344
/// The operation allocating memory for iteration variable.
5445
mlir::Operation *iterVarMemDef;
5546
/// the operation(s) updating the iteration variable with the current
@@ -126,7 +117,7 @@ using InductionVariableInfos = llvm::SmallVector<InductionVariableInfo>;
126117
void collectLoopLiveIns(fir::DoConcurrentLoopOp loop,
127118
llvm::SmallVectorImpl<mlir::Value> &liveIns) {
128119
llvm::SmallDenseSet<mlir::Value> seenValues;
129-
llvm::SmallDenseSet<mlir::Operation *> seenOps;
120+
llvm::SmallPtrSet<mlir::Operation *, 8> seenOps;
130121

131122
for (auto [lb, ub, st] : llvm::zip_equal(
132123
loop.getLowerBound(), loop.getUpperBound(), loop.getStep())) {
@@ -210,6 +201,52 @@ static void localizeLoopLocalValue(mlir::Value local, mlir::Region &allocRegion,
210201

211202
class DoConcurrentConversion
212203
: public mlir::OpConversionPattern<fir::DoConcurrentOp> {
204+
private:
205+
struct TargetDeclareShapeCreationInfo {
206+
// Note: We use `std::vector` (rather than `llvm::SmallVector` as usual) to
207+
// interface more easily `ShapeShiftOp::getOrigins()` which returns
208+
// `std::vector`.
209+
std::vector<mlir::Value> startIndices;
210+
std::vector<mlir::Value> extents;
211+
212+
TargetDeclareShapeCreationInfo(mlir::Value liveIn) {
213+
mlir::Value shape = nullptr;
214+
mlir::Operation *liveInDefiningOp = liveIn.getDefiningOp();
215+
auto declareOp =
216+
mlir::dyn_cast_if_present<hlfir::DeclareOp>(liveInDefiningOp);
217+
218+
if (declareOp != nullptr)
219+
shape = declareOp.getShape();
220+
221+
if (!shape)
222+
return;
223+
224+
auto shapeOp =
225+
mlir::dyn_cast_if_present<fir::ShapeOp>(shape.getDefiningOp());
226+
auto shapeShiftOp =
227+
mlir::dyn_cast_if_present<fir::ShapeShiftOp>(shape.getDefiningOp());
228+
229+
if (!shapeOp && !shapeShiftOp)
230+
TODO(liveIn.getLoc(),
231+
"Shapes not defined by `fir.shape` or `fir.shape_shift` op's are "
232+
"not supported yet.");
233+
234+
if (shapeShiftOp != nullptr)
235+
startIndices = shapeShiftOp.getOrigins();
236+
237+
extents = shapeOp != nullptr
238+
? std::vector<mlir::Value>(shapeOp.getExtents().begin(),
239+
shapeOp.getExtents().end())
240+
: shapeShiftOp.getExtents();
241+
}
242+
243+
bool isShapedValue() const { return !extents.empty(); }
244+
bool isShapeShiftedValue() const { return !startIndices.empty(); }
245+
};
246+
247+
using LiveInShapeInfoMap =
248+
llvm::DenseMap<mlir::Value, TargetDeclareShapeCreationInfo>;
249+
213250
public:
214251
using mlir::OpConversionPattern<fir::DoConcurrentOp>::OpConversionPattern;
215252

@@ -285,6 +322,7 @@ class DoConcurrentConversion
285322

286323
mlir::omp::ParallelOp parallelOp =
287324
genParallelOp(doLoop.getLoc(), rewriter, ivInfos, mapper);
325+
288326
// Only set as composite when part of `distribute parallel do`.
289327
parallelOp.setComposite(mapToDevice);
290328

@@ -333,51 +371,6 @@ class DoConcurrentConversion
333371
}
334372

335373
private:
336-
struct TargetDeclareShapeCreationInfo {
337-
// Note: We use `std::vector` (rather than `llvm::SmallVector` as usual) to
338-
// interface more easily `ShapeShiftOp::getOrigins()` which returns
339-
// `std::vector`.
340-
std::vector<mlir::Value> startIndices{};
341-
std::vector<mlir::Value> extents{};
342-
343-
TargetDeclareShapeCreationInfo(mlir::Value liveIn) {
344-
mlir::Value shape = nullptr;
345-
mlir::Operation *liveInDefiningOp = liveIn.getDefiningOp();
346-
auto declareOp =
347-
mlir::dyn_cast_if_present<hlfir::DeclareOp>(liveInDefiningOp);
348-
349-
if (declareOp != nullptr)
350-
shape = declareOp.getShape();
351-
352-
if (shape == nullptr)
353-
return;
354-
355-
auto shapeOp =
356-
mlir::dyn_cast_if_present<fir::ShapeOp>(shape.getDefiningOp());
357-
auto shapeShiftOp =
358-
mlir::dyn_cast_if_present<fir::ShapeShiftOp>(shape.getDefiningOp());
359-
360-
if (shapeOp == nullptr && shapeShiftOp == nullptr)
361-
TODO(liveIn.getLoc(),
362-
"Shapes not defined by `fir.shape` or `fir.shape_shift` op's are"
363-
"not supported yet.");
364-
365-
if (shapeShiftOp != nullptr)
366-
startIndices = shapeShiftOp.getOrigins();
367-
368-
extents = shapeOp != nullptr
369-
? std::vector<mlir::Value>(shapeOp.getExtents().begin(),
370-
shapeOp.getExtents().end())
371-
: shapeShiftOp.getExtents();
372-
}
373-
374-
bool isShapedValue() const { return !extents.empty(); }
375-
bool isShapeShiftedValue() const { return !startIndices.empty(); }
376-
};
377-
378-
using LiveInShapeInfoMap =
379-
llvm::DenseMap<mlir::Value, TargetDeclareShapeCreationInfo>;
380-
381374
mlir::omp::ParallelOp
382375
genParallelOp(mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
383376
looputils::InductionVariableInfos &ivInfos,
@@ -435,6 +428,8 @@ class DoConcurrentConversion
435428
llvm::SmallVectorImpl<mlir::Value> &bounds) {
436429
populateBounds(var, bounds);
437430

431+
// Ensure that loop-nest bounds are evaluated in the host and forwarded to
432+
// the nested omp constructs when we map to the device.
438433
if (targetClauseOps)
439434
targetClauseOps->hostEvalVars.push_back(var);
440435
};
@@ -616,6 +611,7 @@ class DoConcurrentConversion
616611
}
617612

618613
if (!llvm::isa<mlir::omp::PointerLikeType>(rawAddr.getType())) {
614+
mlir::OpBuilder::InsertionGuard guard(builder);
619615
builder.setInsertionPointAfter(liveInDefiningOp);
620616
auto copyVal = builder.createTemporary(liveIn.getLoc(), liveIn.getType());
621617
builder.createStoreWithConvert(copyVal.getLoc(), liveIn, copyVal);
@@ -678,8 +674,8 @@ class DoConcurrentConversion
678674
rewriter,
679675
fir::getKindMapping(targetOp->getParentOfType<mlir::ModuleOp>()));
680676

681-
// Within the loop, it possible that we discover other values that need to
682-
// mapped to the target region (the shape info values for arrays, for
677+
// Within the loop, it is possible that we discover other values that need
678+
// to be mapped to the target region (the shape info values for arrays, for
683679
// example). Therefore, the map block args might be extended and resized.
684680
// Hence, we invoke `argIface.getMapBlockArgs()` every iteration to make
685681
// sure we access the proper vector of data.
@@ -692,10 +688,13 @@ class DoConcurrentConversion
692688
miOp, liveInShapeInfoMap.at(mappedVar));
693689
++idx;
694690

695-
// TODO If `mappedVar.getDefiningOp()` is a `fir::BoxAddrOp`, we probably
691+
// If `mappedVar.getDefiningOp()` is a `fir::BoxAddrOp`, we probably
696692
// need to "unpack" the box by getting the defining op of its value.
697693
// However, we did not hit this case in reality yet so leaving it as a
698694
// todo for now.
695+
if (mlir::isa<fir::BoxAddrOp>(mappedVar.getDefiningOp()))
696+
TODO(mappedVar.getLoc(),
697+
"Mapped variables defined by `BoxAddrOp` are not supported yet");
699698

700699
auto mapHostValueToDevice = [&](mlir::Value hostValue,
701700
mlir::Value deviceValue) {
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
! Tests that when a loop bound is used in the body, the mapped version of
2+
! the loop bound (rather than the host-eval one) is the one used inside the loop.
3+
4+
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-to-openmp=device %s -o - \
5+
! RUN: | FileCheck %s
6+
! RUN: bbc -emit-hlfir -fopenmp -fdo-concurrent-to-openmp=device %s -o - \
7+
! RUN: | FileCheck %s
8+
9+
subroutine foo(a, n)
10+
implicit none
11+
integer :: i, n
12+
real, dimension(n) :: a
13+
14+
do concurrent (i=1:n)
15+
a(i) = n
16+
end do
17+
end subroutine
18+
19+
! CHECK-LABEL: func.func @_QPfoo
20+
! CHECK: omp.target
21+
! CHECK-SAME: host_eval(%{{.*}} -> %{{.*}}, %{{.*}} -> %[[N_HOST_EVAL:.*]], %{{.*}} -> %{{.*}} : index, index, index)
22+
! CHECK-SAME: map_entries({{[^[:space:]]*}} -> {{[^[:space:]]*}},
23+
! CHECK-SAME: {{[^[:space:]]*}} -> {{[^[:space:]]*}}, {{[^[:space:]]*}} -> {{[^[:space:]]*}},
24+
! CHECK-SAME: {{[^[:space:]]*}} -> {{[^[:space:]]*}}, {{[^[:space:]]*}} -> %[[N_MAP_ARG:[^[:space:]]*]], {{.*}}) {
25+
! CHECK: %[[N_MAPPED:.*]]:2 = hlfir.declare %[[N_MAP_ARG]] {uniq_name = "_QFfooEn"}
26+
! CHECK: omp.teams {
27+
! CHECK: omp.parallel {
28+
! CHECK: omp.distribute {
29+
! CHECK: omp.wsloop {
30+
! CHECK: omp.loop_nest (%{{.*}}) : index = (%{{.*}}) to (%[[N_HOST_EVAL]]) inclusive step (%{{.*}}) {
31+
! CHECK: %[[N_VAL:.*]] = fir.load %[[N_MAPPED]]#0 : !fir.ref<i32>
32+
! CHECK: %[[N_VAL_CVT:.*]] = fir.convert %[[N_VAL]] : (i32) -> f32
33+
! CHECK: hlfir.assign %[[N_VAL_CVT]] to {{.*}}
34+
! CHECK-NEXT: omp.yield
35+
! CHECK: }
36+
! CHECK: }
37+
! CHECK: }
38+
! CHECK: }
39+
! CHECK: }
40+
! CHECK: }

0 commit comments

Comments
 (0)