Skip to content

Commit 1d22d1d

Browse files
Add MapsForPrivatizedSymbolsPass
1 parent 221f15f commit 1d22d1d

File tree

5 files changed

+135
-1
lines changed

5 files changed

+135
-1
lines changed

flang/include/flang/Optimizer/OpenMP/Passes.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,19 @@ def MapInfoFinalizationPass
2222
let dependentDialects = ["mlir::omp::OpenMPDialect"];
2323
}
2424

25+
def MapsForPrivatizedSymbolsPass
26+
: Pass<"omp-map-for-privatized-symbols", "mlir::func::FuncOp"> {
27+
let summary = "Creates MapInfoOp instances for privatized symbols when needed";
28+
let description = [{
29+
Adds omp.map.info operations for privatized symbols on omp.target ops
30+
In certain situations, such as when an allocatable is privatized, its
31+
descriptor is needed in the alloc region of the privatizer. This results
32+
in the use of the descriptor inside the target region. As such, the
33+
descriptor then needs to be mapped. This pass adds such MapInfoOp operations.
34+
}];
35+
let dependentDialects = ["mlir::omp::OpenMPDialect"];
36+
}
37+
2538
def MarkDeclareTargetPass
2639
: Pass<"omp-mark-declare-target", "mlir::ModuleOp"> {
2740
let summary = "Marks all functions called by an OpenMP declare target function as declare target";

flang/include/flang/Tools/CLOptions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ inline void createHLFIRToFIRPassPipeline(
368368
inline void createOpenMPFIRPassPipeline(
369369
mlir::PassManager &pm, bool isTargetDevice) {
370370
pm.addPass(flangomp::createMapInfoFinalizationPass());
371+
pm.addPass(flangomp::createMapsForPrivatizedSymbolsPass());
371372
pm.addPass(flangomp::createMarkDeclareTargetPass());
372373
if (isTargetDevice)
373374
pm.addPass(flangomp::createFunctionFilteringPass());

flang/lib/Optimizer/OpenMP/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
22

33
add_flang_library(FlangOpenMPTransforms
44
FunctionFiltering.cpp
5+
MapsForPrivatizedSymbols.cpp
56
MapInfoFinalization.cpp
67
MarkDeclareTarget.cpp
78

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
//===- MapsForPrivatizedSymbols.cpp
2+
//-----------------------------------------===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
#include "flang/Optimizer/Dialect/FIRType.h"
10+
#include "flang/Optimizer/OpenMP/Passes.h"
11+
#include "mlir/Dialect/Func/IR/FuncOps.h"
12+
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
13+
#include "mlir/IR/BuiltinAttributes.h"
14+
#include "mlir/IR/SymbolTable.h"
15+
#include "mlir/Pass/Pass.h"
16+
#include "llvm/Frontend/OpenMP/OMPConstants.h"
17+
#include <type_traits>
18+
19+
namespace flangomp {
20+
#define GEN_PASS_DEF_MAPSFORPRIVATIZEDSYMBOLSPASS
21+
#include "flang/Optimizer/OpenMP/Passes.h.inc"
22+
} // namespace flangomp
23+
using namespace mlir;
24+
namespace {
25+
class MapsForPrivatizedSymbolsPass
26+
: public flangomp::impl::MapsForPrivatizedSymbolsPassBase<
27+
MapsForPrivatizedSymbolsPass> {
28+
29+
bool privatizerNeedsMap(omp::PrivateClauseOp &privatizer) {
30+
Region &allocRegion = privatizer.getAllocRegion();
31+
Value blockArg0 = allocRegion.getArgument(0);
32+
if (blockArg0.use_empty())
33+
return false;
34+
return true;
35+
}
36+
void dumpPrivatizerInfo(omp::PrivateClauseOp &privatizer,
37+
mlir::Value privVar) {
38+
llvm::errs() << "Found a privatizer:\n";
39+
privatizer.dump();
40+
llvm::errs() << "\n";
41+
42+
llvm::errs() << "$type = " << privatizer.getType() << "\n";
43+
llvm::errs() << "privVar = ";
44+
privVar.dump();
45+
llvm::errs() << "\n";
46+
47+
llvm::errs() << "privVar.getDefiningOp() = ";
48+
privVar.getDefiningOp()->dump();
49+
llvm::errs() << "\n";
50+
llvm::errs() << "\n";
51+
}
52+
omp::MapInfoOp createMapInfo(mlir::Location loc, mlir::Value var,
53+
OpBuilder &builder) {
54+
// llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
55+
// llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
56+
uint64_t mapTypeTo = static_cast<
57+
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
58+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
59+
return builder.create<omp::MapInfoOp>(
60+
loc, var.getType(), var,
61+
mlir::TypeAttr::get(fir::unwrapRefType(var.getType())),
62+
/*varPtrPtr=*/mlir::Value{},
63+
/*members=*/mlir::SmallVector<mlir::Value>{},
64+
/*member_index=*/mlir::DenseIntElementsAttr{},
65+
/*bounds=*/mlir::ValueRange{},
66+
builder.getIntegerAttr(builder.getIntegerType(64, /*isSigned=*/false),
67+
mapTypeTo),
68+
builder.getAttr<omp::VariableCaptureKindAttr>(
69+
omp::VariableCaptureKind::ByRef),
70+
mlir::StringAttr(), builder.getBoolAttr(false));
71+
}
72+
void addMapInfoOp(omp::TargetOp targetOp, omp::MapInfoOp mapInfoOp) {
73+
mlir::Location loc = targetOp.getLoc();
74+
targetOp.getMapVarsMutable().append(mlir::ValueRange{mapInfoOp});
75+
size_t numMapVars = targetOp.getMapVars().size();
76+
targetOp.getRegion().insertArgument(numMapVars - 1, mapInfoOp.getType(),
77+
loc);
78+
}
79+
void runOnOperation() override {
80+
MLIRContext *context = &getContext();
81+
OpBuilder builder(context);
82+
getOperation()->walk([&](omp::TargetOp targetOp) {
83+
llvm::errs() << "MapsForPrivatizedSymbolsPass::TargetOp is \n";
84+
targetOp.dump();
85+
llvm::errs() << "\n";
86+
87+
if (targetOp.getPrivateVars().empty())
88+
return;
89+
90+
OperandRange privVars = targetOp.getPrivateVars();
91+
std::optional<ArrayAttr> privSyms = targetOp.getPrivateSyms();
92+
93+
for (auto [privVar, privSym] : llvm::zip_equal(privVars, *privSyms)) {
94+
95+
SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
96+
omp::PrivateClauseOp privatizer =
97+
SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
98+
targetOp, privatizerName);
99+
100+
assert(mlir::isa<fir::ReferenceType>(privVar.getType()) &&
101+
"Privatized variable should be a reference.");
102+
if (!privatizerNeedsMap(privatizer)) {
103+
return;
104+
}
105+
llvm::errs() << "Privatizer NEEDS a map\n";
106+
builder.setInsertionPoint(targetOp);
107+
dumpPrivatizerInfo(privatizer, privVar);
108+
109+
mlir::Location loc = targetOp.getLoc();
110+
omp::MapInfoOp mapInfoOp = createMapInfo(loc, privVar, builder);
111+
addMapInfoOp(targetOp, mapInfoOp);
112+
llvm::errs() << __FUNCTION__ << "MapInfoOp is \n";
113+
mapInfoOp.dump();
114+
llvm::errs() << "\n";
115+
}
116+
});
117+
}
118+
};
119+
} // namespace

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -880,7 +880,7 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
880880
objects (e.g. derived types or classes), indicates the bounds to be copied
881881
of the variable. When it's an array slice it is in rank order where rank 0
882882
is the inner-most dimension.
883-
- 'map_clauses': OpenMP map type for this map capture, for example: from, to and
883+
- 'map_type': OpenMP map type for this map capture, for example: from, to and
884884
always. It's a bitfield composed of the OpenMP runtime flags stored in
885885
OpenMPOffloadMappingFlags.
886886
- 'map_capture_type': Capture type for the variable e.g. this, byref, byvalue, byvla

0 commit comments

Comments
 (0)