77//
88// ===----------------------------------------------------------------------===//
99#include " flang/Optimizer/Dialect/FIRType.h"
10+ #include " flang/Optimizer/HLFIR/HLFIROps.h"
1011#include " flang/Optimizer/OpenMP/Passes.h"
1112#include " mlir/Dialect/Func/IR/FuncOps.h"
1213#include " mlir/Dialect/OpenMP/OpenMPDialect.h"
@@ -41,10 +42,27 @@ class MapsForPrivatizedSymbolsPass
4142 uint64_t mapTypeTo = static_cast <
4243 std::underlying_type_t <llvm::omp::OpenMPOffloadMappingFlags>>(
4344 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
44-
45+ Operation *definingOp = var.getDefiningOp ();
46+ auto declOp = llvm::dyn_cast_or_null<hlfir::DeclareOp>(definingOp);
47+ assert (declOp &&
48+ " Expected defining Op of privatized var to be hlfir.declare" );
49+ Value varPtr = declOp.getOriginalBase ();
50+ Value varBase = declOp.getBase ();
51+ llvm::errs () << " varPtr = " ;
52+ varPtr.dump ();
53+ llvm::errs () << " type -> " ;
54+ varPtr.getType ().dump ();
55+ llvm::errs () << " \n " ;
56+ llvm::errs () << " varBase = " ;
57+ varBase.dump ();
58+ llvm::errs () << " type -> " ;
59+ varBase.getType ().dump ();
60+ llvm::errs () << " \n " ;
4561 return builder.create <omp::MapInfoOp>(
46- loc, var.getType (), var,
47- mlir::TypeAttr::get (fir::unwrapRefType (var.getType ())),
62+ loc, varPtr.getType (), varPtr,
63+ mlir::TypeAttr::get (
64+ llvm::cast<mlir::omp::PointerLikeType>(varPtr.getType ())
65+ .getElementType ()),
4866 /* varPtrPtr=*/ mlir::Value{},
4967 /* members=*/ mlir::SmallVector<mlir::Value>{},
5068 /* member_index=*/ mlir::DenseIntElementsAttr{},
@@ -57,41 +75,67 @@ class MapsForPrivatizedSymbolsPass
5775 }
5876 void addMapInfoOp (omp::TargetOp targetOp, omp::MapInfoOp mapInfoOp) {
5977 mlir::Location loc = targetOp.getLoc ();
78+ llvm::errs () << " Adding mapInfoOp -> " ;
79+ mapInfoOp.dump ();
80+ llvm::errs () << " \n " ;
81+
6082 targetOp.getMapVarsMutable ().append (mlir::ValueRange{mapInfoOp});
6183 size_t numMapVars = targetOp.getMapVars ().size ();
6284 targetOp.getRegion ().insertArgument (numMapVars - 1 , mapInfoOp.getType (),
6385 loc);
6486 }
87+ void addMapInfoOps (omp::TargetOp targetOp,
88+ llvm::SmallVectorImpl<omp::MapInfoOp> &mapInfoOps) {
89+ for (auto mapInfoOp : mapInfoOps)
90+ addMapInfoOp (targetOp, mapInfoOp);
91+ }
6592 void runOnOperation () override {
6693 MLIRContext *context = &getContext ();
6794 OpBuilder builder (context);
95+ llvm::DenseMap<Operation *, llvm::SmallVector<omp::MapInfoOp, 4 >>
96+ mapInfoOpsForTarget;
6897 getOperation ()->walk ([&](omp::TargetOp targetOp) {
6998 if (targetOp.getPrivateVars ().empty ())
7099 return ;
71-
100+ llvm::errs () << " Func is \n " ;
101+ targetOp.getOperation ()->getParentOp ()->getParentOp ()->dump ();
102+ llvm::errs () << " \n " ;
72103 OperandRange privVars = targetOp.getPrivateVars ();
73104 std::optional<ArrayAttr> privSyms = targetOp.getPrivateSyms ();
74-
105+ SmallVector<omp::MapInfoOp, 4 > mapInfoOps;
75106 for (auto [privVar, privSym] : llvm::zip_equal (privVars, *privSyms)) {
76107
77108 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
78109 omp::PrivateClauseOp privatizer =
79110 SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
80111 targetOp, privatizerName);
81-
82- assert (mlir::isa<fir::ReferenceType>(privVar.getType ()) &&
83- " Privatized variable should be a reference." );
112+ llvm::errs () << " privVar = " ;
113+ privVar.dump ();
114+ llvm::errs () << " \n " ;
115+ llvm::errs () << " privVar.getType() = " ;
116+ privVar.getType ().dump ();
117+ llvm::errs () << " \n " ;
118+ // assert(mlir::isa<fir::ReferenceType>(privVar.getType()) &&
119+ // "Privatized variable should be a reference.");
84120 if (!privatizerNeedsMap (privatizer)) {
85- return ;
121+ continue ;
86122 }
87123 builder.setInsertionPoint (targetOp);
88124 mlir::Location loc = targetOp.getLoc ();
89125 omp::MapInfoOp mapInfoOp = createMapInfo (loc, privVar, builder);
90- addMapInfoOp (targetOp, mapInfoOp);
126+ mapInfoOps. push_back ( mapInfoOp);
91127 LLVM_DEBUG (llvm::dbgs () << " MapsForPrivatizedSymbolsPass created ->\n " );
92128 LLVM_DEBUG (mapInfoOp.dump ());
93129 }
130+ if (!mapInfoOps.empty ()) {
131+ mapInfoOpsForTarget.insert ({targetOp.getOperation (), mapInfoOps});
132+ }
94133 });
134+ if (!mapInfoOpsForTarget.empty ()) {
135+ for (auto &[targetOp, mapInfoOps] : mapInfoOpsForTarget) {
136+ addMapInfoOps (static_cast <omp::TargetOp>(targetOp), mapInfoOps);
137+ }
138+ }
95139 }
96140};
97141} // namespace
0 commit comments