@@ -129,6 +129,17 @@ class MapInfoFinalizationPass
129129 }
130130 }
131131
132+ // / Return true if the module has an OpenMP requires clause that includes
133+ // / unified_shared_memory.
134+ static bool moduleRequiresUSM (mlir::ModuleOp module ) {
135+ assert (module && " invalid module" );
136+ if (auto req = module ->getAttrOfType <mlir::omp::ClauseRequiresAttr>(
137+ " omp.requires" ))
138+ return mlir::omp::bitEnumContainsAll (
139+ req.getValue (), mlir::omp::ClauseRequires::unified_shared_memory);
140+ return false ;
141+ }
142+
132143 // / Create the member map for coordRef and append it (and its index
133144 // / path) to the provided new* vectors, if it is not already present.
134145 void appendMemberMapIfNew (
@@ -430,19 +441,8 @@ class MapInfoFinalizationPass
430441 // For unified_shared_memory, we additionally add `CLOSE` on the descriptor
431442 // to ensure device-local placement where required by tests relying on USM +
432443 // close semantics.
433- if (target) {
434- if (auto mod = target->getParentOfType <mlir::ModuleOp>()) {
435- if (mlir::Attribute reqAttr = mod->getAttr (" omp.requires" )) {
436- if (auto req =
437- mlir::dyn_cast<mlir::omp::ClauseRequiresAttr>(reqAttr)) {
438- if (mlir::omp::bitEnumContainsAll (
439- req.getValue (),
440- mlir::omp::ClauseRequires::unified_shared_memory))
441- flags |= mapFlags::OMP_MAP_CLOSE;
442- }
443- }
444- }
445- }
444+ if (moduleRequiresUSM (target->getParentOfType <mlir::ModuleOp>()))
445+ flags |= mapFlags::OMP_MAP_CLOSE;
446446 return llvm::to_underlying (flags);
447447 }
448448
@@ -1258,22 +1258,9 @@ class MapInfoFinalizationPass
12581258 // to ensure device-visible pointer size/behavior in USM scenarios
12591259 // without changing default expectations elsewhere.
12601260 func->walk ([&](mlir::omp::MapInfoOp op) {
1261- // Check module requires USM; otherwise, leave mappings untouched.
1262- auto mod = func->getParentOfType <mlir::ModuleOp>();
1263- bool hasUSM = false ;
1264- if (mod) {
1265- if (mlir::Attribute reqAttr = mod->getAttr (" omp.requires" )) {
1266- if (auto req =
1267- mlir::dyn_cast<mlir::omp::ClauseRequiresAttr>(reqAttr)) {
1268- hasUSM = mlir::omp::bitEnumContainsAll (
1269- req.getValue (),
1270- mlir::omp::ClauseRequires::unified_shared_memory);
1271- }
1272- }
1273- }
1274- if (!hasUSM)
1261+ // Only expand C_PTR members when unified_shared_memory is required.
1262+ if (!moduleRequiresUSM (func->getParentOfType <mlir::ModuleOp>()))
12751263 return ;
1276-
12771264 builder.setInsertionPoint (op);
12781265 genCptrMemberMap (op, builder);
12791266 });
0 commit comments