Skip to content

Commit 06af400

Browse files
fix failures seen in target-private-multiple-variables.f90
1 parent 2b42abb commit 06af400

File tree

1 file changed

+54
-10
lines changed

1 file changed

+54
-10
lines changed

flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp

Lines changed: 54 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//
88
//===----------------------------------------------------------------------===//
99
#include "flang/Optimizer/Dialect/FIRType.h"
10+
#include "flang/Optimizer/HLFIR/HLFIROps.h"
1011
#include "flang/Optimizer/OpenMP/Passes.h"
1112
#include "mlir/Dialect/Func/IR/FuncOps.h"
1213
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
@@ -41,10 +42,27 @@ class MapsForPrivatizedSymbolsPass
4142
uint64_t mapTypeTo = static_cast<
4243
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
4344
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
44-
45+
Operation *definingOp = var.getDefiningOp();
46+
auto declOp = llvm::dyn_cast_or_null<hlfir::DeclareOp>(definingOp);
47+
assert(declOp &&
48+
"Expected defining Op of privatized var to be hlfir.declare");
49+
Value varPtr = declOp.getOriginalBase();
50+
Value varBase = declOp.getBase();
51+
llvm::errs() << "varPtr = ";
52+
varPtr.dump();
53+
llvm::errs() << " type -> ";
54+
varPtr.getType().dump();
55+
llvm::errs() << "\n";
56+
llvm::errs() << "varBase = ";
57+
varBase.dump();
58+
llvm::errs() << " type -> ";
59+
varBase.getType().dump();
60+
llvm::errs() << "\n";
4561
return builder.create<omp::MapInfoOp>(
46-
loc, var.getType(), var,
47-
mlir::TypeAttr::get(fir::unwrapRefType(var.getType())),
62+
loc, varPtr.getType(), varPtr,
63+
mlir::TypeAttr::get(
64+
llvm::cast<mlir::omp::PointerLikeType>(varPtr.getType())
65+
.getElementType()),
4866
/*varPtrPtr=*/mlir::Value{},
4967
/*members=*/mlir::SmallVector<mlir::Value>{},
5068
/*member_index=*/mlir::DenseIntElementsAttr{},
@@ -57,41 +75,67 @@ class MapsForPrivatizedSymbolsPass
5775
}
5876
void addMapInfoOp(omp::TargetOp targetOp, omp::MapInfoOp mapInfoOp) {
5977
mlir::Location loc = targetOp.getLoc();
78+
llvm::errs() << "Adding mapInfoOp -> ";
79+
mapInfoOp.dump();
80+
llvm::errs() << "\n";
81+
6082
targetOp.getMapVarsMutable().append(mlir::ValueRange{mapInfoOp});
6183
size_t numMapVars = targetOp.getMapVars().size();
6284
targetOp.getRegion().insertArgument(numMapVars - 1, mapInfoOp.getType(),
6385
loc);
6486
}
87+
void addMapInfoOps(omp::TargetOp targetOp,
88+
llvm::SmallVectorImpl<omp::MapInfoOp> &mapInfoOps) {
89+
for (auto mapInfoOp : mapInfoOps)
90+
addMapInfoOp(targetOp, mapInfoOp);
91+
}
6592
void runOnOperation() override {
6693
MLIRContext *context = &getContext();
6794
OpBuilder builder(context);
95+
llvm::DenseMap<Operation *, llvm::SmallVector<omp::MapInfoOp, 4>>
96+
mapInfoOpsForTarget;
6897
getOperation()->walk([&](omp::TargetOp targetOp) {
6998
if (targetOp.getPrivateVars().empty())
7099
return;
71-
100+
llvm::errs() << "Func is \n";
101+
targetOp.getOperation()->getParentOp()->getParentOp()->dump();
102+
llvm::errs() << "\n";
72103
OperandRange privVars = targetOp.getPrivateVars();
73104
std::optional<ArrayAttr> privSyms = targetOp.getPrivateSyms();
74-
105+
SmallVector<omp::MapInfoOp, 4> mapInfoOps;
75106
for (auto [privVar, privSym] : llvm::zip_equal(privVars, *privSyms)) {
76107

77108
SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
78109
omp::PrivateClauseOp privatizer =
79110
SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
80111
targetOp, privatizerName);
81-
82-
assert(mlir::isa<fir::ReferenceType>(privVar.getType()) &&
83-
"Privatized variable should be a reference.");
112+
llvm::errs() << "privVar = ";
113+
privVar.dump();
114+
llvm::errs() << "\n";
115+
llvm::errs() << "privVar.getType() = ";
116+
privVar.getType().dump();
117+
llvm::errs() << "\n";
118+
// assert(mlir::isa<fir::ReferenceType>(privVar.getType()) &&
119+
// "Privatized variable should be a reference.");
84120
if (!privatizerNeedsMap(privatizer)) {
85-
return;
121+
continue;
86122
}
87123
builder.setInsertionPoint(targetOp);
88124
mlir::Location loc = targetOp.getLoc();
89125
omp::MapInfoOp mapInfoOp = createMapInfo(loc, privVar, builder);
90-
addMapInfoOp(targetOp, mapInfoOp);
126+
mapInfoOps.push_back(mapInfoOp);
91127
LLVM_DEBUG(llvm::dbgs() << "MapsForPrivatizedSymbolsPass created ->\n");
92128
LLVM_DEBUG(mapInfoOp.dump());
93129
}
130+
if (!mapInfoOps.empty()) {
131+
mapInfoOpsForTarget.insert({targetOp.getOperation(), mapInfoOps});
132+
}
94133
});
134+
if (!mapInfoOpsForTarget.empty()) {
135+
for (auto &[targetOp, mapInfoOps] : mapInfoOpsForTarget) {
136+
addMapInfoOps(static_cast<omp::TargetOp>(targetOp), mapInfoOps);
137+
}
138+
}
95139
}
96140
};
97141
} // namespace

0 commit comments

Comments
 (0)