@@ -701,175 +701,105 @@ class MapInfoFinalizationPass
701701
702702 auto recordType = mlir::cast<fir::RecordType>(underlyingType);
703703 llvm::SmallVector<mlir::Value> newMapOpsForFields;
704- llvm::SmallVector<llvm::SmallVector< int64_t >> newMemberIndexPaths ;
704+ llvm::SmallVector<int64_t > fieldIndicies ;
705705
706- auto appendMemberMap = [&](mlir::Location loc, mlir::Value coordRef,
707- mlir::Type memTy,
708- llvm::ArrayRef<int64_t > indexPath,
709- llvm::StringRef memberName) {
710- // Check if already mapped (index path equality).
706+ for (auto fieldMemTyPair : recordType.getTypeList ()) {
707+ auto &field = fieldMemTyPair.first ;
708+ auto memTy = fieldMemTyPair.second ;
709+
710+ bool shouldMapField =
711+ llvm::find_if (mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
712+ if (!fir::isAllocatableType (memTy))
713+ return false ;
714+
715+ auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
716+ if (!designateOp)
717+ return false ;
718+
719+ return designateOp.getComponent () &&
720+ designateOp.getComponent ()->strref () == field;
721+ }) != mapVarForwardSlice.end ();
722+
723+ // TODO Handle recursive record types. Adapting
724+ // `createParentSymAndGenIntermediateMaps` to work direclty on MLIR
725+ // entities might be helpful here.
726+
727+ if (!shouldMapField)
728+ continue ;
729+
730+ int32_t fieldIdx = recordType.getFieldIndex (field);
711731 bool alreadyMapped = [&]() {
712732 if (op.getMembersIndexAttr ())
713733 for (auto indexList : op.getMembersIndexAttr ()) {
714734 auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
715- if (indexListAttr.size () != indexPath.size ())
716- continue ;
717- bool allEq = true ;
718- for (auto [i, attr] : llvm::enumerate (indexListAttr)) {
719- if (mlir::cast<mlir::IntegerAttr>(attr).getInt () !=
720- indexPath[i]) {
721- allEq = false ;
722- break ;
723- }
724- }
725- if (allEq)
735+ if (indexListAttr.size () == 1 &&
736+ mlir::cast<mlir::IntegerAttr>(indexListAttr[0 ]).getInt () ==
737+ fieldIdx)
726738 return true ;
727739 }
728740
729741 return false ;
730742 }();
731743
732744 if (alreadyMapped)
733- return ;
745+ continue ;
734746
735747 builder.setInsertionPoint (op);
748+ fir::IntOrValue idxConst =
749+ mlir::IntegerAttr::get (builder.getI32Type (), fieldIdx);
750+ auto fieldCoord = fir::CoordinateOp::create (
751+ builder, op.getLoc (), builder.getRefType (memTy), op.getVarPtr (),
752+ llvm::SmallVector<fir::IntOrValue, 1 >{idxConst});
736753 fir::factory::AddrAndBoundsInfo info =
737- fir::factory::getDataOperandBaseAddr (builder, coordRef,
738- /* isOptional=*/ false , loc );
754+ fir::factory::getDataOperandBaseAddr (
755+ builder, fieldCoord, /* isOptional=*/ false , op. getLoc () );
739756 llvm::SmallVector<mlir::Value> bounds =
740757 fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
741758 mlir::omp::MapBoundsType>(
742759 builder, info,
743- hlfir::translateToExtendedValue (loc , builder,
744- hlfir::Entity{coordRef })
760+ hlfir::translateToExtendedValue (op. getLoc () , builder,
761+ hlfir::Entity{fieldCoord })
745762 .first ,
746- /* dataExvIsAssumedSize=*/ false , loc );
763+ /* dataExvIsAssumedSize=*/ false , op. getLoc () );
747764
748765 mlir::omp::MapInfoOp fieldMapOp = mlir::omp::MapInfoOp::create (
749- builder, loc, coordRef.getType (), coordRef,
750- mlir::TypeAttr::get (fir::unwrapRefType (coordRef.getType ())),
766+ builder, op.getLoc (), fieldCoord.getResult ().getType (),
767+ fieldCoord.getResult (),
768+ mlir::TypeAttr::get (
769+ fir::unwrapRefType (fieldCoord.getResult ().getType ())),
751770 op.getMapTypeAttr (),
752771 builder.getAttr <mlir::omp::VariableCaptureKindAttr>(
753772 mlir::omp::VariableCaptureKind::ByRef),
754773 /* varPtrPtr=*/ mlir::Value{}, /* members=*/ mlir::ValueRange{},
755774 /* members_index=*/ mlir::ArrayAttr{}, bounds,
756775 /* mapperId=*/ mlir::FlatSymbolRefAttr (),
757- builder.getStringAttr (op.getNameAttr ().strref () + " ." +
758- memberName + " .implicit_map" ),
776+ builder.getStringAttr (op.getNameAttr ().strref () + " ." + field +
777+ " .implicit_map" ),
759778 /* partial_map=*/ builder.getBoolAttr (false ));
760779 newMapOpsForFields.emplace_back (fieldMapOp);
761- newMemberIndexPaths.emplace_back (indexPath.begin (), indexPath.end ());
762- };
763-
764- // 1) Handle direct top-level allocatable fields (existing behavior).
765- for (auto fieldMemTyPair : recordType.getTypeList ()) {
766- auto &field = fieldMemTyPair.first ;
767- auto memTy = fieldMemTyPair.second ;
768-
769- if (!fir::isAllocatableType (memTy))
770- continue ;
771-
772- bool referenced = llvm::any_of (mapVarForwardSlice, [&](auto *opv) {
773- auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(opv);
774- return designateOp && designateOp.getComponent () &&
775- designateOp.getComponent ()->strref () == field;
776- });
777- if (!referenced)
778- continue ;
779-
780- int32_t fieldIdx = recordType.getFieldIndex (field);
781- builder.setInsertionPoint (op);
782- fir::IntOrValue idxConst =
783- mlir::IntegerAttr::get (builder.getI32Type (), fieldIdx);
784- auto fieldCoord = fir::CoordinateOp::create (
785- builder, op.getLoc (), builder.getRefType (memTy), op.getVarPtr (),
786- llvm::SmallVector<fir::IntOrValue, 1 >{idxConst});
787- appendMemberMap (op.getLoc (), fieldCoord, memTy, {fieldIdx}, field);
788- }
789-
790- // Handle nested allocatable fields along any component chain
791- // referenced in the region via HLFIR designates.
792- for (mlir::Operation *sliceOp : mapVarForwardSlice) {
793- auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
794- if (!designateOp || !designateOp.getComponent ())
795- continue ;
796- llvm::SmallVector<llvm::StringRef> compPathReversed;
797- compPathReversed.push_back (designateOp.getComponent ()->strref ());
798- mlir::Value curBase = designateOp.getMemref ();
799- bool rootedAtMapArg = false ;
800- while (true ) {
801- if (auto parentDes = curBase.getDefiningOp <hlfir::DesignateOp>()) {
802- if (!parentDes.getComponent ())
803- break ;
804- compPathReversed.push_back (parentDes.getComponent ()->strref ());
805- curBase = parentDes.getMemref ();
806- continue ;
807- }
808- if (auto decl = curBase.getDefiningOp <hlfir::DeclareOp>()) {
809- if (auto barg =
810- mlir::dyn_cast<mlir::BlockArgument>(decl.getMemref ()))
811- rootedAtMapArg = (barg == opBlockArg);
812- } else if (auto blockArg =
813- mlir::dyn_cast_or_null<mlir::BlockArgument>(
814- curBase)) {
815- rootedAtMapArg = (blockArg == opBlockArg);
816- }
817- break ;
818- }
819- if (!rootedAtMapArg || compPathReversed.size () < 2 )
820- continue ;
821- builder.setInsertionPoint (op);
822- llvm::SmallVector<int64_t > indexPath;
823- mlir::Type curTy = underlyingType;
824- mlir::Value coordRef = op.getVarPtr ();
825- bool validPath = true ;
826- for (llvm::StringRef compName : llvm::reverse (compPathReversed)) {
827- auto recTy = mlir::dyn_cast<fir::RecordType>(curTy);
828- if (!recTy) {
829- validPath = false ;
830- break ;
831- }
832- int32_t idx = recTy.getFieldIndex (compName);
833- if (idx < 0 ) {
834- validPath = false ;
835- break ;
836- }
837- indexPath.push_back (idx);
838- mlir::Type memTy = recTy.getType (idx);
839- fir::IntOrValue idxConst =
840- mlir::IntegerAttr::get (builder.getI32Type (), idx);
841- coordRef = fir::CoordinateOp::create (
842- builder, op.getLoc (), builder.getRefType (memTy), coordRef,
843- llvm::SmallVector<fir::IntOrValue, 1 >{idxConst});
844- curTy = memTy;
845- }
846- if (!validPath)
847- continue ;
848- if (auto finalRefTy =
849- mlir::dyn_cast<fir::ReferenceType>(coordRef.getType ())) {
850- mlir::Type eleTy = finalRefTy.getElementType ();
851- if (fir::isAllocatableType (eleTy))
852- appendMemberMap (op.getLoc (), coordRef, eleTy, indexPath,
853- compPathReversed.front ());
854- }
780+ fieldIndicies.emplace_back (fieldIdx);
855781 }
856782
857783 if (newMapOpsForFields.empty ())
858784 return mlir::WalkResult::advance ();
859785
860786 op.getMembersMutable ().append (newMapOpsForFields);
861787 llvm::SmallVector<llvm::SmallVector<int64_t >> newMemberIndices;
862- if (mlir::ArrayAttr oldAttr = op.getMembersIndexAttr ())
863- for (mlir::Attribute indexList : oldAttr) {
788+ mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr ();
789+
790+ if (oldMembersIdxAttr)
791+ for (mlir::Attribute indexList : oldMembersIdxAttr) {
864792 llvm::SmallVector<int64_t > listVec;
865793
866794 for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList))
867795 listVec.push_back (mlir::cast<mlir::IntegerAttr>(index).getInt ());
868796
869797 newMemberIndices.emplace_back (std::move (listVec));
870798 }
871- for (auto &path : newMemberIndexPaths)
872- newMemberIndices.emplace_back (path);
799+
800+ for (int64_t newFieldIdx : fieldIndicies)
801+ newMemberIndices.emplace_back (
802+ llvm::SmallVector<int64_t >(1 , newFieldIdx));
873803
874804 op.setMembersIndexAttr (builder.create2DI64ArrayAttr (newMemberIndices));
875805 op.setPartialMap (true );
0 commit comments