66// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
77//
88// ===----------------------------------------------------------------------===//
9+
10+ // ===----------------------------------------------------------------------===//
11+ // / \file
12+ // / An OpenMP dialect related pass for FIR/HLFIR which creates MapInfoOp
13+ // / instances for certain privatized symbols.
14+ // / For example, if an allocatable variable is used in a private clause attached
15+ // / to a omp.target op, then the allocatable variable's descriptor will be
16+ // / needed on the device (e.g. GPU). This descriptor needs to be separately
17+ // / mapped onto the device. This pass creates the necessary omp.map.info ops for
18+ // / this.
19+ // ===----------------------------------------------------------------------===//
20+ // TODO:
21+ // 1. Before adding omp.map.info, check if in case we already have an
22+ // omp.map.info for the variable in question.
23+ // 2. Generalize this for more than just omp.target ops.
24+ // ===----------------------------------------------------------------------===//
25+
926#include " flang/Optimizer/Dialect/FIRType.h"
1027#include " flang/Optimizer/HLFIR/HLFIROps.h"
1128#include " flang/Optimizer/OpenMP/Passes.h"
@@ -37,8 +54,7 @@ class MapsForPrivatizedSymbolsPass
3754 return false ;
3855 return true ;
3956 }
40- omp::MapInfoOp createMapInfo (mlir::Location loc, mlir::Value var,
41- OpBuilder &builder) {
57+ omp::MapInfoOp createMapInfo (Location loc, Value var, OpBuilder &builder) {
4258 uint64_t mapTypeTo = static_cast <
4359 std::underlying_type_t <llvm::omp::OpenMPOffloadMappingFlags>>(
4460 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
@@ -47,39 +63,24 @@ class MapsForPrivatizedSymbolsPass
4763 assert (declOp &&
4864 " Expected defining Op of privatized var to be hlfir.declare" );
4965 Value varPtr = declOp.getOriginalBase ();
50- Value varBase = declOp.getBase ();
51- llvm::errs () << " varPtr = " ;
52- varPtr.dump ();
53- llvm::errs () << " type -> " ;
54- varPtr.getType ().dump ();
55- llvm::errs () << " \n " ;
56- llvm::errs () << " varBase = " ;
57- varBase.dump ();
58- llvm::errs () << " type -> " ;
59- varBase.getType ().dump ();
60- llvm::errs () << " \n " ;
66+
6167 return builder.create <omp::MapInfoOp>(
6268 loc, varPtr.getType (), varPtr,
63- mlir::TypeAttr::get (
64- llvm::cast<mlir::omp::PointerLikeType>(varPtr.getType ())
65- .getElementType ()),
66- /* varPtrPtr=*/ mlir::Value{},
67- /* members=*/ mlir::SmallVector<mlir::Value>{},
68- /* member_index=*/ mlir::DenseIntElementsAttr{},
69- /* bounds=*/ mlir::ValueRange{},
69+ TypeAttr::get (llvm::cast<omp::PointerLikeType>(varPtr.getType ())
70+ .getElementType ()),
71+ /* varPtrPtr=*/ Value{},
72+ /* members=*/ SmallVector<Value>{},
73+ /* member_index=*/ DenseIntElementsAttr{},
74+ /* bounds=*/ ValueRange{},
7075 builder.getIntegerAttr (builder.getIntegerType (64 , /* isSigned=*/ false ),
7176 mapTypeTo),
7277 builder.getAttr <omp::VariableCaptureKindAttr>(
7378 omp::VariableCaptureKind::ByRef),
74- mlir:: StringAttr (), builder.getBoolAttr (false ));
79+ StringAttr (), builder.getBoolAttr (false ));
7580 }
7681 void addMapInfoOp (omp::TargetOp targetOp, omp::MapInfoOp mapInfoOp) {
77- mlir::Location loc = targetOp.getLoc ();
78- llvm::errs () << " Adding mapInfoOp -> " ;
79- mapInfoOp.dump ();
80- llvm::errs () << " \n " ;
81-
82- targetOp.getMapVarsMutable ().append (mlir::ValueRange{mapInfoOp});
82+ Location loc = targetOp.getLoc ();
83+ targetOp.getMapVarsMutable ().append (ValueRange{mapInfoOp});
8384 size_t numMapVars = targetOp.getMapVars ().size ();
8485 targetOp.getRegion ().insertArgument (numMapVars - 1 , mapInfoOp.getType (),
8586 loc);
@@ -97,9 +98,6 @@ class MapsForPrivatizedSymbolsPass
9798 getOperation ()->walk ([&](omp::TargetOp targetOp) {
9899 if (targetOp.getPrivateVars ().empty ())
99100 return ;
100- llvm::errs () << " Func is \n " ;
101- targetOp.getOperation ()->getParentOp ()->getParentOp ()->dump ();
102- llvm::errs () << " \n " ;
103101 OperandRange privVars = targetOp.getPrivateVars ();
104102 std::optional<ArrayAttr> privSyms = targetOp.getPrivateSyms ();
105103 SmallVector<omp::MapInfoOp, 4 > mapInfoOps;
@@ -109,19 +107,11 @@ class MapsForPrivatizedSymbolsPass
109107 omp::PrivateClauseOp privatizer =
110108 SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
111109 targetOp, privatizerName);
112- llvm::errs () << " privVar = " ;
113- privVar.dump ();
114- llvm::errs () << " \n " ;
115- llvm::errs () << " privVar.getType() = " ;
116- privVar.getType ().dump ();
117- llvm::errs () << " \n " ;
118- // assert(mlir::isa<fir::ReferenceType>(privVar.getType()) &&
119- // "Privatized variable should be a reference.");
120110 if (!privatizerNeedsMap (privatizer)) {
121111 continue ;
122112 }
123113 builder.setInsertionPoint (targetOp);
124- mlir:: Location loc = targetOp.getLoc ();
114+ Location loc = targetOp.getLoc ();
125115 omp::MapInfoOp mapInfoOp = createMapInfo (loc, privVar, builder);
126116 mapInfoOps.push_back (mapInfoOp);
127117 LLVM_DEBUG (llvm::dbgs () << " MapsForPrivatizedSymbolsPass created ->\n " );
0 commit comments