@@ -125,61 +125,82 @@ class MapInfoFinalizationPass
125125 // TODO: map the addendum segment of the descriptor, similarly to the
126126 // above base address/data pointer member.
127127
128- auto addOperands = [&](mlir::OperandRange &operandsArr,
129- mlir::MutableOperandRange &mutableOpRange,
130- auto directiveOp) {
131- llvm::SmallVector<mlir::Value> newMapOps;
132- for (size_t i = 0 ; i < operandsArr.size (); ++i) {
133- if (operandsArr[i] == op) {
134- // Push new implicit maps generated for the descriptor.
135- newMapOps.push_back (baseAddr);
128+ mlir::omp::MapInfoOp newDescParentMapOp =
129+ builder.create <mlir::omp::MapInfoOp>(
130+ op->getLoc (), op.getResult ().getType (), descriptor,
131+ mlir::TypeAttr::get (fir::unwrapRefType (descriptor.getType ())),
132+ /* varPtrPtr=*/ mlir::Value{},
133+ /* members=*/ mlir::SmallVector<mlir::Value>{baseAddr},
134+ /* members_index=*/
135+ mlir::DenseIntElementsAttr::get (
136+ mlir::VectorType::get (
137+ llvm::ArrayRef<int64_t >({1 , 1 }),
138+ mlir::IntegerType::get (builder.getContext (), 32 )),
139+ llvm::ArrayRef<int32_t >({0 })),
140+ /* bounds=*/ mlir::SmallVector<mlir::Value>{},
141+ builder.getIntegerAttr (builder.getIntegerType (64 , false ),
142+ op.getMapType ().value ()),
143+ op.getMapCaptureTypeAttr (), op.getNameAttr (),
144+ op.getPartialMapAttr ());
145+ op.replaceAllUsesWith (newDescParentMapOp.getResult ());
146+ op->erase ();
136147
137- // for TargetOp's which have IsolatedFromAbove we must align the
138- // new additional map operand with an appropriate BlockArgument,
139- // as the printing and later processing currently requires a 1:1
140- // mapping of BlockArgs to MapInfoOp's at the same placement in
141- // each array (BlockArgs and MapOperands).
142- if (directiveOp) {
143- directiveOp.getRegion ().insertArgument (i, baseAddr.getType (), loc);
148+ auto addOperands = [&](mlir::OperandRange &mapVarsArr,
149+ mlir::MutableOperandRange &mutableOpRange,
150+ mlir::Operation *directiveOp,
151+ mlir::omp::MapInfoOp newDesc,
152+ unsigned blockArgInsertIndex = 0 ,
153+ bool insertBlockArgs = true ) {
154+ if (llvm::is_contained (mapVarsArr, newDesc.getResult ())) {
155+ llvm::SmallVector<mlir::Value> newMapOps{mapVarsArr};
156+ for (auto mapMember : newDesc.getMembers ()) {
157+ if (!llvm::is_contained (mapVarsArr, mapMember)) {
158+ newMapOps.push_back (mapMember);
159+ if (directiveOp && insertBlockArgs) {
160+ directiveOp->getRegion (0 ).insertArgument (
161+ blockArgInsertIndex, mapMember.getType (), mapMember.getLoc ());
162+ }
163+ blockArgInsertIndex++;
144164 }
145165 }
146- newMapOps. push_back (operandsArr[i] );
166+ mutableOpRange. assign (newMapOps );
147167 }
148- mutableOpRange.assign (newMapOps);
149168 };
169+
170+ auto argIface =
171+ llvm::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(target);
172+
150173 if (auto mapClauseOwner =
151174 llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(target)) {
152- mlir::OperandRange mapOperandsArr = mapClauseOwner.getMapVars ();
175+ mlir::OperandRange mapVarsArr = mapClauseOwner.getMapVars ();
153176 mlir::MutableOperandRange mapMutableOpRange =
154177 mapClauseOwner.getMapVarsMutable ();
155- mlir::omp::TargetOp targetOp =
156- llvm::dyn_cast<mlir::omp::TargetOp>(target);
157- addOperands (mapOperandsArr, mapMutableOpRange, targetOp);
178+ unsigned blockArgInsertIndex =
179+ argIface
180+ ? argIface.getMapBlockArgsStart () + argIface.numMapBlockArgs ()
181+ : 0 ;
182+ addOperands (mapVarsArr, mapMutableOpRange, argIface.getOperation (),
183+ newDescParentMapOp, blockArgInsertIndex,
184+ !llvm::isa<mlir::omp::TargetDataOp>(target));
158185 }
186+
159187 if (auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(target)) {
160188 mlir::OperandRange useDevAddrArr = targetDataOp.getUseDeviceAddrVars ();
161189 mlir::MutableOperandRange useDevAddrMutableOpRange =
162190 targetDataOp.getUseDeviceAddrVarsMutable ();
163- addOperands (useDevAddrArr, useDevAddrMutableOpRange, targetDataOp);
164- }
191+ addOperands (useDevAddrArr, useDevAddrMutableOpRange, target,
192+ newDescParentMapOp,
193+ argIface.getUseDeviceAddrBlockArgsStart () +
194+ argIface.numUseDeviceAddrBlockArgs ());
165195
166- mlir::Value newDescParentMapOp = builder.create <mlir::omp::MapInfoOp>(
167- op->getLoc (), op.getResult ().getType (), descriptor,
168- mlir::TypeAttr::get (fir::unwrapRefType (descriptor.getType ())),
169- /* varPtrPtr=*/ mlir::Value{},
170- /* members=*/ mlir::SmallVector<mlir::Value>{baseAddr},
171- /* members_index=*/
172- mlir::DenseIntElementsAttr::get (
173- mlir::VectorType::get (
174- llvm::ArrayRef<int64_t >({1 , 1 }),
175- mlir::IntegerType::get (builder.getContext (), 32 )),
176- llvm::ArrayRef<int32_t >({0 })),
177- /* bounds=*/ mlir::SmallVector<mlir::Value>{},
178- builder.getIntegerAttr (builder.getIntegerType (64 , false ),
179- op.getMapType ().value ()),
180- op.getMapCaptureTypeAttr (), op.getNameAttr (), op.getPartialMapAttr ());
181- op.replaceAllUsesWith (newDescParentMapOp);
182- op->erase ();
196+ mlir::OperandRange useDevPtrArr = targetDataOp.getUseDevicePtrVars ();
197+ mlir::MutableOperandRange useDevPtrMutableOpRange =
198+ targetDataOp.getUseDevicePtrVarsMutable ();
199+ addOperands (useDevPtrArr, useDevPtrMutableOpRange, target,
200+ newDescParentMapOp,
201+ argIface.getUseDevicePtrBlockArgsStart () +
202+ argIface.numUseDevicePtrBlockArgs ());
203+ }
183204 }
184205
185206 // We add all mapped record members not directly used in the target region
0 commit comments