Skip to content

Commit 04cb5e8

Browse files
committed
Refactor to add utility function moduleRequiresUSM.
1 parent 6f95445 commit 04cb5e8

File tree

1 file changed

+15
-28
lines changed

1 file changed

+15
-28
lines changed

flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp

Lines changed: 15 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,17 @@ class MapInfoFinalizationPass
129129
}
130130
}
131131

132+
/// Return true if the module has an OpenMP requires clause that includes
133+
/// unified_shared_memory.
134+
static bool moduleRequiresUSM(mlir::ModuleOp module) {
135+
assert(module && "invalid module");
136+
if (auto req = module->getAttrOfType<mlir::omp::ClauseRequiresAttr>(
137+
"omp.requires"))
138+
return mlir::omp::bitEnumContainsAll(
139+
req.getValue(), mlir::omp::ClauseRequires::unified_shared_memory);
140+
return false;
141+
}
142+
132143
/// Create the member map for coordRef and append it (and its index
133144
/// path) to the provided new* vectors, if it is not already present.
134145
void appendMemberMapIfNew(
@@ -430,19 +441,8 @@ class MapInfoFinalizationPass
430441
// For unified_shared_memory, we additionally add `CLOSE` on the descriptor
431442
// to ensure device-local placement where required by tests relying on USM +
432443
// close semantics.
433-
if (target) {
434-
if (auto mod = target->getParentOfType<mlir::ModuleOp>()) {
435-
if (mlir::Attribute reqAttr = mod->getAttr("omp.requires")) {
436-
if (auto req =
437-
mlir::dyn_cast<mlir::omp::ClauseRequiresAttr>(reqAttr)) {
438-
if (mlir::omp::bitEnumContainsAll(
439-
req.getValue(),
440-
mlir::omp::ClauseRequires::unified_shared_memory))
441-
flags |= mapFlags::OMP_MAP_CLOSE;
442-
}
443-
}
444-
}
445-
}
444+
if (moduleRequiresUSM(target->getParentOfType<mlir::ModuleOp>()))
445+
flags |= mapFlags::OMP_MAP_CLOSE;
446446
return llvm::to_underlying(flags);
447447
}
448448

@@ -1258,22 +1258,9 @@ class MapInfoFinalizationPass
12581258
// to ensure device-visible pointer size/behavior in USM scenarios
12591259
// without changing default expectations elsewhere.
12601260
func->walk([&](mlir::omp::MapInfoOp op) {
1261-
// Check module requires USM; otherwise, leave mappings untouched.
1262-
auto mod = func->getParentOfType<mlir::ModuleOp>();
1263-
bool hasUSM = false;
1264-
if (mod) {
1265-
if (mlir::Attribute reqAttr = mod->getAttr("omp.requires")) {
1266-
if (auto req =
1267-
mlir::dyn_cast<mlir::omp::ClauseRequiresAttr>(reqAttr)) {
1268-
hasUSM = mlir::omp::bitEnumContainsAll(
1269-
req.getValue(),
1270-
mlir::omp::ClauseRequires::unified_shared_memory);
1271-
}
1272-
}
1273-
}
1274-
if (!hasUSM)
1261+
// Only expand C_PTR members when unified_shared_memory is required.
1262+
if (!moduleRequiresUSM(func->getParentOfType<mlir::ModuleOp>()))
12751263
return;
1276-
12771264
builder.setInsertionPoint(op);
12781265
genCptrMemberMap(op, builder);
12791266
});

0 commit comments

Comments
 (0)