|
40 | 40 | #include "mlir/IR/SymbolTable.h" |
41 | 41 | #include "mlir/Pass/Pass.h" |
42 | 42 | #include "mlir/Support/LLVM.h" |
| 43 | +#include "llvm/ADT/BitmaskEnum.h" |
43 | 44 | #include "llvm/ADT/SmallPtrSet.h" |
44 | 45 | #include "llvm/ADT/StringSet.h" |
45 | 46 | #include "llvm/Frontend/OpenMP/OMPConstants.h" |
@@ -128,6 +129,17 @@ class MapInfoFinalizationPass |
128 | 129 | } |
129 | 130 | } |
130 | 131 |
|
| 132 | + /// Return true if the module has an OpenMP requires clause that includes |
| 133 | + /// unified_shared_memory. |
| 134 | + static bool moduleRequiresUSM(mlir::ModuleOp module) { |
| 135 | + assert(module && "invalid module"); |
| 136 | + if (auto req = module->getAttrOfType<mlir::omp::ClauseRequiresAttr>( |
| 137 | + "omp.requires")) |
| 138 | + return mlir::omp::bitEnumContainsAll( |
| 139 | + req.getValue(), mlir::omp::ClauseRequires::unified_shared_memory); |
| 140 | + return false; |
| 141 | + } |
| 142 | + |
131 | 143 | /// Create the member map for coordRef and append it (and its index |
132 | 144 | /// path) to the provided new* vectors, if it is not already present. |
133 | 145 | void appendMemberMapIfNew( |
@@ -425,8 +437,12 @@ class MapInfoFinalizationPass |
425 | 437 |
|
426 | 438 | mapFlags flags = mapFlags::OMP_MAP_TO | |
427 | 439 | (mapFlags(mapTypeFlag) & |
428 | | - (mapFlags::OMP_MAP_IMPLICIT | mapFlags::OMP_MAP_CLOSE | |
429 | | - mapFlags::OMP_MAP_ALWAYS)); |
| 440 | + (mapFlags::OMP_MAP_IMPLICIT | mapFlags::OMP_MAP_ALWAYS)); |
| 441 | + // For unified_shared_memory, we additionally add `CLOSE` on the descriptor |
| 442 | + // to ensure device-local placement where required by tests relying on USM + |
| 443 | + // close semantics. |
| 444 | + if (moduleRequiresUSM(target->getParentOfType<mlir::ModuleOp>())) |
| 445 | + flags |= mapFlags::OMP_MAP_CLOSE; |
430 | 446 | return llvm::to_underlying(flags); |
431 | 447 | } |
432 | 448 |
|
@@ -518,6 +534,75 @@ class MapInfoFinalizationPass |
518 | 534 | return newMapInfoOp; |
519 | 535 | } |
520 | 536 |
|
| 537 | + // Expand mappings of type(C_PTR) to map their `__address` field explicitly |
| 538 | + // as a single pointer-sized member (USM-gated at callsite). This helps in |
| 539 | + // USM scenarios to ensure the pointer-sized mapping is used. |
| 540 | + mlir::omp::MapInfoOp genCptrMemberMap(mlir::omp::MapInfoOp op, |
| 541 | + fir::FirOpBuilder &builder) { |
| 542 | + if (!op.getMembers().empty()) |
| 543 | + return op; |
| 544 | + |
| 545 | + mlir::Type varTy = fir::unwrapRefType(op.getVarPtr().getType()); |
| 546 | + if (!mlir::isa<fir::RecordType>(varTy)) |
| 547 | + return op; |
| 548 | + auto recTy = mlir::cast<fir::RecordType>(varTy); |
| 549 | + // If not a builtin C_PTR record, skip. |
| 550 | + if (!recTy.getName().ends_with("__builtin_c_ptr")) |
| 551 | + return op; |
| 552 | + |
| 553 | + // Find the index of the c_ptr address component named "__address". |
| 554 | + int32_t fieldIdx = recTy.getFieldIndex("__address"); |
| 555 | + if (fieldIdx < 0) |
| 556 | + return op; |
| 557 | + |
| 558 | + mlir::Location loc = op.getVarPtr().getLoc(); |
| 559 | + mlir::Type memTy = recTy.getType(fieldIdx); |
| 560 | + fir::IntOrValue idxConst = |
| 561 | + mlir::IntegerAttr::get(builder.getI32Type(), fieldIdx); |
| 562 | + mlir::Value coord = fir::CoordinateOp::create( |
| 563 | + builder, loc, builder.getRefType(memTy), op.getVarPtr(), |
| 564 | + llvm::SmallVector<fir::IntOrValue, 1>{idxConst}); |
| 565 | + |
| 566 | + // Child for the `__address` member. |
| 567 | + llvm::SmallVector<llvm::SmallVector<int64_t>> memberIdx = {{0}}; |
| 568 | + mlir::ArrayAttr newMembersAttr = builder.create2DI64ArrayAttr(memberIdx); |
| 569 | + // Force CLOSE in USM paths so the pointer gets device-local placement |
| 570 | + // when required by tests relying on USM + close semantics. |
| 571 | + uint64_t mapTypeVal = |
| 572 | + op.getMapType() | |
| 573 | + llvm::to_underlying( |
| 574 | + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE); |
| 575 | + mlir::IntegerAttr mapTypeAttr = builder.getIntegerAttr( |
| 576 | + builder.getIntegerType(64, /*isSigned=*/false), mapTypeVal); |
| 577 | + |
| 578 | + mlir::omp::MapInfoOp memberMap = mlir::omp::MapInfoOp::create( |
| 579 | + builder, loc, coord.getType(), coord, |
| 580 | + mlir::TypeAttr::get(fir::unwrapRefType(coord.getType())), mapTypeAttr, |
| 581 | + builder.getAttr<mlir::omp::VariableCaptureKindAttr>( |
| 582 | + mlir::omp::VariableCaptureKind::ByRef), |
| 583 | + /*varPtrPtr=*/mlir::Value{}, |
| 584 | + /*members=*/llvm::SmallVector<mlir::Value>{}, |
| 585 | + /*member_index=*/mlir::ArrayAttr{}, |
| 586 | + /*bounds=*/op.getBounds(), |
| 587 | + /*mapperId=*/mlir::FlatSymbolRefAttr(), |
| 588 | + /*name=*/op.getNameAttr(), |
| 589 | + /*partial_map=*/builder.getBoolAttr(false)); |
| 590 | + |
| 591 | + // Rebuild the parent as a container with the `__address` member. |
| 592 | + mlir::omp::MapInfoOp newParent = mlir::omp::MapInfoOp::create( |
| 593 | + builder, op.getLoc(), op.getResult().getType(), op.getVarPtr(), |
| 594 | + op.getVarTypeAttr(), mapTypeAttr, op.getMapCaptureTypeAttr(), |
| 595 | + /*varPtrPtr=*/mlir::Value{}, |
| 596 | + /*members=*/llvm::SmallVector<mlir::Value>{memberMap}, |
| 597 | + /*member_index=*/newMembersAttr, |
| 598 | + /*bounds=*/llvm::SmallVector<mlir::Value>{}, |
| 599 | + /*mapperId=*/mlir::FlatSymbolRefAttr(), op.getNameAttr(), |
| 600 | + /*partial_map=*/builder.getBoolAttr(false)); |
| 601 | + op.replaceAllUsesWith(newParent.getResult()); |
| 602 | + op->erase(); |
| 603 | + return newParent; |
| 604 | + } |
| 605 | + |
521 | 606 | mlir::omp::MapInfoOp genDescriptorMemberMaps(mlir::omp::MapInfoOp op, |
522 | 607 | fir::FirOpBuilder &builder, |
523 | 608 | mlir::Operation *target) { |
@@ -1169,6 +1254,17 @@ class MapInfoFinalizationPass |
1169 | 1254 | genBoxcharMemberMap(op, builder); |
1170 | 1255 | }); |
1171 | 1256 |
|
| 1257 | + // Expand type(C_PTR) only when unified_shared_memory is required, |
| 1258 | + // to ensure device-visible pointer size/behavior in USM scenarios |
| 1259 | + // without changing default expectations elsewhere. |
| 1260 | + func->walk([&](mlir::omp::MapInfoOp op) { |
| 1261 | + // Only expand C_PTR members when unified_shared_memory is required. |
| 1262 | + if (!moduleRequiresUSM(func->getParentOfType<mlir::ModuleOp>())) |
| 1263 | + return; |
| 1264 | + builder.setInsertionPoint(op); |
| 1265 | + genCptrMemberMap(op, builder); |
| 1266 | + }); |
| 1267 | + |
1172 | 1268 | func->walk([&](mlir::omp::MapInfoOp op) { |
1173 | 1269 | // TODO: Currently only supports a single user for the MapInfoOp. This |
1174 | 1270 | // is fine for the moment, as the Fortran frontend will generate a |
|
0 commit comments