@@ -701,105 +701,175 @@ class MapInfoFinalizationPass
701701
702702 auto recordType = mlir::cast<fir::RecordType>(underlyingType);
703703 llvm::SmallVector<mlir::Value> newMapOpsForFields;
704- llvm::SmallVector<int64_t > fieldIndicies ;
704+ llvm::SmallVector<llvm::SmallVector< int64_t >> newMemberIndexPaths ;
705705
706- for (auto fieldMemTyPair : recordType.getTypeList ()) {
707- auto &field = fieldMemTyPair.first ;
708- auto memTy = fieldMemTyPair.second ;
709-
710- bool shouldMapField =
711- llvm::find_if (mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
712- if (!fir::isAllocatableType (memTy))
713- return false ;
714-
715- auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
716- if (!designateOp)
717- return false ;
718-
719- return designateOp.getComponent () &&
720- designateOp.getComponent ()->strref () == field;
721- }) != mapVarForwardSlice.end ();
722-
723- // TODO Handle recursive record types. Adapting
724- // `createParentSymAndGenIntermediateMaps` to work direclty on MLIR
725- // entities might be helpful here.
726-
727- if (!shouldMapField)
728- continue ;
729-
730- int32_t fieldIdx = recordType.getFieldIndex (field);
706+ auto appendMemberMap = [&](mlir::Location loc, mlir::Value coordRef,
707+ mlir::Type memTy,
708+ llvm::ArrayRef<int64_t > indexPath,
709+ llvm::StringRef memberName) {
710+ // Check if already mapped (index path equality).
731711 bool alreadyMapped = [&]() {
732712 if (op.getMembersIndexAttr ())
733713 for (auto indexList : op.getMembersIndexAttr ()) {
734714 auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
735- if (indexListAttr.size () == 1 &&
736- mlir::cast<mlir::IntegerAttr>(indexListAttr[0 ]).getInt () ==
737- fieldIdx)
715+ if (indexListAttr.size () != indexPath.size ())
716+ continue ;
717+ bool allEq = true ;
718+ for (auto [i, attr] : llvm::enumerate (indexListAttr)) {
719+ if (mlir::cast<mlir::IntegerAttr>(attr).getInt () !=
720+ indexPath[i]) {
721+ allEq = false ;
722+ break ;
723+ }
724+ }
725+ if (allEq)
738726 return true ;
739727 }
740728
741729 return false ;
742730 }();
743731
744732 if (alreadyMapped)
745- continue ;
733+ return ;
746734
747735 builder.setInsertionPoint (op);
748- fir::IntOrValue idxConst =
749- mlir::IntegerAttr::get (builder.getI32Type (), fieldIdx);
750- auto fieldCoord = fir::CoordinateOp::create (
751- builder, op.getLoc (), builder.getRefType (memTy), op.getVarPtr (),
752- llvm::SmallVector<fir::IntOrValue, 1 >{idxConst});
753736 fir::factory::AddrAndBoundsInfo info =
754- fir::factory::getDataOperandBaseAddr (
755- builder, fieldCoord, /* isOptional=*/ false , op. getLoc () );
737+ fir::factory::getDataOperandBaseAddr (builder, coordRef,
738+ /* isOptional=*/ false , loc );
756739 llvm::SmallVector<mlir::Value> bounds =
757740 fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
758741 mlir::omp::MapBoundsType>(
759742 builder, info,
760- hlfir::translateToExtendedValue (op. getLoc () , builder,
761- hlfir::Entity{fieldCoord })
743+ hlfir::translateToExtendedValue (loc , builder,
744+ hlfir::Entity{coordRef })
762745 .first ,
763- /* dataExvIsAssumedSize=*/ false , op. getLoc () );
746+ /* dataExvIsAssumedSize=*/ false , loc );
764747
765748 mlir::omp::MapInfoOp fieldMapOp = mlir::omp::MapInfoOp::create (
766- builder, op.getLoc (), fieldCoord.getResult ().getType (),
767- fieldCoord.getResult (),
768- mlir::TypeAttr::get (
769- fir::unwrapRefType (fieldCoord.getResult ().getType ())),
749+ builder, loc, coordRef.getType (), coordRef,
750+ mlir::TypeAttr::get (fir::unwrapRefType (coordRef.getType ())),
770751 op.getMapTypeAttr (),
771752 builder.getAttr <mlir::omp::VariableCaptureKindAttr>(
772753 mlir::omp::VariableCaptureKind::ByRef),
773754 /* varPtrPtr=*/ mlir::Value{}, /* members=*/ mlir::ValueRange{},
774755 /* members_index=*/ mlir::ArrayAttr{}, bounds,
775756 /* mapperId=*/ mlir::FlatSymbolRefAttr (),
776- builder.getStringAttr (op.getNameAttr ().strref () + " ." + field +
777- " .implicit_map" ),
757+ builder.getStringAttr (op.getNameAttr ().strref () + " ." +
758+ memberName + " .implicit_map" ),
778759 /* partial_map=*/ builder.getBoolAttr (false ));
779760 newMapOpsForFields.emplace_back (fieldMapOp);
780- fieldIndicies.emplace_back (fieldIdx);
761+ newMemberIndexPaths.emplace_back (indexPath.begin (), indexPath.end ());
762+ };
763+
764+ // 1) Handle direct top-level allocatable fields (existing behavior).
765+ for (auto fieldMemTyPair : recordType.getTypeList ()) {
766+ auto &field = fieldMemTyPair.first ;
767+ auto memTy = fieldMemTyPair.second ;
768+
769+ if (!fir::isAllocatableType (memTy))
770+ continue ;
771+
772+ bool referenced = llvm::any_of (mapVarForwardSlice, [&](auto *opv) {
773+ auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(opv);
774+ return designateOp && designateOp.getComponent () &&
775+ designateOp.getComponent ()->strref () == field;
776+ });
777+ if (!referenced)
778+ continue ;
779+
780+ int32_t fieldIdx = recordType.getFieldIndex (field);
781+ builder.setInsertionPoint (op);
782+ fir::IntOrValue idxConst =
783+ mlir::IntegerAttr::get (builder.getI32Type (), fieldIdx);
784+ auto fieldCoord = fir::CoordinateOp::create (
785+ builder, op.getLoc (), builder.getRefType (memTy), op.getVarPtr (),
786+ llvm::SmallVector<fir::IntOrValue, 1 >{idxConst});
787+ appendMemberMap (op.getLoc (), fieldCoord, memTy, {fieldIdx}, field);
788+ }
789+
790+ // Handle nested allocatable fields along any component chain
791+ // referenced in the region via HLFIR designates.
792+ for (mlir::Operation *sliceOp : mapVarForwardSlice) {
793+ auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
794+ if (!designateOp || !designateOp.getComponent ())
795+ continue ;
796+ llvm::SmallVector<llvm::StringRef> compPathReversed;
797+ compPathReversed.push_back (designateOp.getComponent ()->strref ());
798+ mlir::Value curBase = designateOp.getMemref ();
799+ bool rootedAtMapArg = false ;
800+ while (true ) {
801+ if (auto parentDes = curBase.getDefiningOp <hlfir::DesignateOp>()) {
802+ if (!parentDes.getComponent ())
803+ break ;
804+ compPathReversed.push_back (parentDes.getComponent ()->strref ());
805+ curBase = parentDes.getMemref ();
806+ continue ;
807+ }
808+ if (auto decl = curBase.getDefiningOp <hlfir::DeclareOp>()) {
809+ if (auto barg =
810+ mlir::dyn_cast<mlir::BlockArgument>(decl.getMemref ()))
811+ rootedAtMapArg = (barg == opBlockArg);
812+ } else if (auto blockArg =
813+ mlir::dyn_cast_or_null<mlir::BlockArgument>(
814+ curBase)) {
815+ rootedAtMapArg = (blockArg == opBlockArg);
816+ }
817+ break ;
818+ }
819+ if (!rootedAtMapArg || compPathReversed.size () < 2 )
820+ continue ;
821+ builder.setInsertionPoint (op);
822+ llvm::SmallVector<int64_t > indexPath;
823+ mlir::Type curTy = underlyingType;
824+ mlir::Value coordRef = op.getVarPtr ();
825+ bool validPath = true ;
826+ for (llvm::StringRef compName : llvm::reverse (compPathReversed)) {
827+ auto recTy = mlir::dyn_cast<fir::RecordType>(curTy);
828+ if (!recTy) {
829+ validPath = false ;
830+ break ;
831+ }
832+ int32_t idx = recTy.getFieldIndex (compName);
833+ if (idx < 0 ) {
834+ validPath = false ;
835+ break ;
836+ }
837+ indexPath.push_back (idx);
838+ mlir::Type memTy = recTy.getType (idx);
839+ fir::IntOrValue idxConst =
840+ mlir::IntegerAttr::get (builder.getI32Type (), idx);
841+ coordRef = fir::CoordinateOp::create (
842+ builder, op.getLoc (), builder.getRefType (memTy), coordRef,
843+ llvm::SmallVector<fir::IntOrValue, 1 >{idxConst});
844+ curTy = memTy;
845+ }
846+ if (!validPath)
847+ continue ;
848+ if (auto finalRefTy =
849+ mlir::dyn_cast<fir::ReferenceType>(coordRef.getType ())) {
850+ mlir::Type eleTy = finalRefTy.getElementType ();
851+ if (fir::isAllocatableType (eleTy))
852+ appendMemberMap (op.getLoc (), coordRef, eleTy, indexPath,
853+ compPathReversed.front ());
854+ }
781855 }
782856
783857 if (newMapOpsForFields.empty ())
784858 return mlir::WalkResult::advance ();
785859
786860 op.getMembersMutable ().append (newMapOpsForFields);
787861 llvm::SmallVector<llvm::SmallVector<int64_t >> newMemberIndices;
788- mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr ();
789-
790- if (oldMembersIdxAttr)
791- for (mlir::Attribute indexList : oldMembersIdxAttr) {
862+ if (mlir::ArrayAttr oldAttr = op.getMembersIndexAttr ())
863+ for (mlir::Attribute indexList : oldAttr) {
792864 llvm::SmallVector<int64_t > listVec;
793865
794866 for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList))
795867 listVec.push_back (mlir::cast<mlir::IntegerAttr>(index).getInt ());
796868
797869 newMemberIndices.emplace_back (std::move (listVec));
798870 }
799-
800- for (int64_t newFieldIdx : fieldIndicies)
801- newMemberIndices.emplace_back (
802- llvm::SmallVector<int64_t >(1 , newFieldIdx));
871+ for (auto &path : newMemberIndexPaths)
872+ newMemberIndices.emplace_back (path);
803873
804874 op.setMembersIndexAttr (builder.create2DI64ArrayAttr (newMemberIndices));
805875 op.setPartialMap (true );
0 commit comments