@@ -2470,6 +2470,121 @@ genSingleOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
2470
2470
queue, item, clauseOps);
2471
2471
}
2472
2472
2473
+ static mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper (
2474
+ lower::AbstractConverter &converter, mlir::Location loc,
2475
+ fir::RecordType recordType, llvm::StringRef mapperNameStr) {
2476
+ if (converter.getModuleOp ().lookupSymbol (mapperNameStr))
2477
+ return mlir::FlatSymbolRefAttr::get (&converter.getMLIRContext (),
2478
+ mapperNameStr);
2479
+
2480
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
2481
+
2482
+ // Save current insertion point before moving to the module scope to create
2483
+ // the DeclareMapperOp.
2484
+ mlir::OpBuilder::InsertionGuard guard (firOpBuilder);
2485
+
2486
+ firOpBuilder.setInsertionPointToStart (converter.getModuleOp ().getBody ());
2487
+ auto declMapperOp = firOpBuilder.create <mlir::omp::DeclareMapperOp>(
2488
+ loc, mapperNameStr, recordType);
2489
+ auto ®ion = declMapperOp.getRegion ();
2490
+ firOpBuilder.createBlock (®ion);
2491
+ auto mapperArg = region.addArgument (firOpBuilder.getRefType (recordType), loc);
2492
+
2493
+ auto declareOp =
2494
+ firOpBuilder.create <hlfir::DeclareOp>(loc, mapperArg, /* uniq_name=*/ " " );
2495
+
2496
+ const auto genBoundsOps = [&](mlir::Value mapVal,
2497
+ llvm::SmallVectorImpl<mlir::Value> &bounds) {
2498
+ fir::ExtendedValue extVal =
2499
+ hlfir::translateToExtendedValue (mapVal.getLoc (), firOpBuilder,
2500
+ hlfir::Entity{mapVal},
2501
+ /* contiguousHint=*/ true )
2502
+ .first ;
2503
+ fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr (
2504
+ firOpBuilder, mapVal, /* isOptional=*/ false , mapVal.getLoc ());
2505
+ bounds = fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
2506
+ mlir::omp::MapBoundsType>(
2507
+ firOpBuilder, info, extVal,
2508
+ /* dataExvIsAssumedSize=*/ false , mapVal.getLoc ());
2509
+ };
2510
+
2511
+ // Return a reference to the contents of a derived type with one field.
2512
+ // Also return the field type.
2513
+ const auto getFieldRef =
2514
+ [&](mlir::Value rec, llvm::StringRef fieldName, mlir::Type fieldTy,
2515
+ mlir::Type recType) -> std::tuple<mlir::Value, mlir::Type> {
2516
+ mlir::Value field = firOpBuilder.create <fir::FieldIndexOp>(
2517
+ loc, fir::FieldType::get (recType.getContext ()), fieldName, recType,
2518
+ fir::getTypeParams (rec));
2519
+ return {firOpBuilder.create <fir::CoordinateOp>(
2520
+ loc, firOpBuilder.getRefType (fieldTy), rec, field),
2521
+ fieldTy};
2522
+ };
2523
+
2524
+ mlir::omp::DeclareMapperInfoOperands clauseOps;
2525
+ llvm::SmallVector<llvm::SmallVector<int64_t >> memberPlacementIndices;
2526
+ llvm::SmallVector<mlir::Value> memberMapOps;
2527
+
2528
+ llvm::omp::OpenMPOffloadMappingFlags mapFlag =
2529
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
2530
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM |
2531
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
2532
+ mlir::omp::VariableCaptureKind captureKind =
2533
+ mlir::omp::VariableCaptureKind::ByRef;
2534
+
2535
+ // Populate the declareMapper region with the map information.
2536
+ for (const auto &entry : llvm::enumerate (recordType.getTypeList ())) {
2537
+ const auto &memberName = entry.value ().first ;
2538
+ const auto &memberType = entry.value ().second ;
2539
+ auto [ref, type] =
2540
+ getFieldRef (declareOp.getBase (), memberName, memberType, recordType);
2541
+ mlir::FlatSymbolRefAttr mapperId;
2542
+ if (auto recType = mlir::dyn_cast<fir::RecordType>(memberType)) {
2543
+ std::string mapperIdName =
2544
+ recType.getName ().str () + llvm::omp::OmpDefaultMapperName;
2545
+ if (auto *sym = converter.getCurrentScope ().FindSymbol (mapperIdName))
2546
+ mapperIdName = converter.mangleName (mapperIdName, sym->owner ());
2547
+ else if (auto *sym = converter.getCurrentScope ().FindSymbol (memberName))
2548
+ mapperIdName = converter.mangleName (mapperIdName, sym->owner ());
2549
+
2550
+ mapperId = getOrGenImplicitDefaultDeclareMapper (converter, loc, recType,
2551
+ mapperIdName);
2552
+ }
2553
+
2554
+ llvm::SmallVector<mlir::Value> bounds;
2555
+ genBoundsOps (ref, bounds);
2556
+ mlir::Value mapOp = createMapInfoOp (
2557
+ firOpBuilder, loc, ref, /* varPtrPtr=*/ mlir::Value{}, /* name=*/ " " ,
2558
+ bounds,
2559
+ /* members=*/ {},
2560
+ /* membersIndex=*/ mlir::ArrayAttr{},
2561
+ static_cast <
2562
+ std::underlying_type_t <llvm::omp::OpenMPOffloadMappingFlags>>(
2563
+ mapFlag),
2564
+ captureKind, ref.getType (), /* partialMap=*/ false , mapperId);
2565
+ memberMapOps.emplace_back (mapOp);
2566
+ memberPlacementIndices.emplace_back (
2567
+ llvm::SmallVector<int64_t >{(int64_t )entry.index ()});
2568
+ }
2569
+
2570
+ llvm::SmallVector<mlir::Value> bounds;
2571
+ genBoundsOps (declareOp.getOriginalBase (), bounds);
2572
+ mlir::omp::MapInfoOp mapOp = createMapInfoOp (
2573
+ firOpBuilder, loc, declareOp.getOriginalBase (),
2574
+ /* varPtrPtr=*/ mlir::Value (), /* name=*/ " " , bounds, memberMapOps,
2575
+ firOpBuilder.create2DI64ArrayAttr (memberPlacementIndices),
2576
+ static_cast <std::underlying_type_t <llvm::omp::OpenMPOffloadMappingFlags>>(
2577
+ mapFlag),
2578
+ captureKind, declareOp.getType (0 ),
2579
+ /* partialMap=*/ true );
2580
+
2581
+ clauseOps.mapVars .emplace_back (mapOp);
2582
+
2583
+ firOpBuilder.create <mlir::omp::DeclareMapperInfoOp>(loc, clauseOps.mapVars );
2584
+ return mlir::FlatSymbolRefAttr::get (&converter.getMLIRContext (),
2585
+ mapperNameStr);
2586
+ }
2587
+
2473
2588
static mlir::omp::TargetOp
2474
2589
genTargetOp (lower::AbstractConverter &converter, lower::SymMap &symTable,
2475
2590
lower::StatementContext &stmtCtx,
@@ -2546,15 +2661,25 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
2546
2661
name << sym.name ().ToString ();
2547
2662
2548
2663
mlir::FlatSymbolRefAttr mapperId;
2549
- if (sym.GetType ()->category () == semantics::DeclTypeSpec::TypeDerived) {
2664
+ if (sym.GetType ()->category () == semantics::DeclTypeSpec::TypeDerived &&
2665
+ defaultMaps.empty ()) {
2550
2666
auto &typeSpec = sym.GetType ()->derivedTypeSpec ();
2551
2667
std::string mapperIdName =
2552
2668
typeSpec.name ().ToString () + llvm::omp::OmpDefaultMapperName;
2553
2669
if (auto *sym = converter.getCurrentScope ().FindSymbol (mapperIdName))
2554
2670
mapperIdName = converter.mangleName (mapperIdName, sym->owner ());
2671
+ else
2672
+ mapperIdName =
2673
+ converter.mangleName (mapperIdName, *typeSpec.GetScope ());
2674
+
2555
2675
if (converter.getModuleOp ().lookupSymbol (mapperIdName))
2556
2676
mapperId = mlir::FlatSymbolRefAttr::get (&converter.getMLIRContext (),
2557
2677
mapperIdName);
2678
+ mapperId = getOrGenImplicitDefaultDeclareMapper (
2679
+ converter, loc,
2680
+ mlir::cast<fir::RecordType>(
2681
+ converter.genType (sym.GetType ()->derivedTypeSpec ())),
2682
+ mapperIdName);
2558
2683
}
2559
2684
2560
2685
fir::factory::AddrAndBoundsInfo info =
0 commit comments