Skip to content

Commit 33483d5

Browse files
committed
add device_type attr to groupprivate
1 parent b2772c4 commit 33483d5

File tree

3 files changed

+69
-28
lines changed

3 files changed

+69
-28
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2244,10 +2244,13 @@ def GroupprivateOp : OpenMP_Op<"groupprivate",
22442244
the original variable.
22452245
}];
22462246

2247-
let arguments = (ins OpenMP_PointerLikeType:$sym_addr);
2247+
let arguments = (ins
2248+
OpenMP_PointerLikeType:$sym_addr,
2249+
OptionalAttr<DeclareTargetDeviceTypeAttr>:$device_type
2250+
);
22482251
let results = (outs OpenMP_PointerLikeType:$gp_addr);
22492252
let assemblyFormat = [{
2250-
$sym_addr `:` type($sym_addr) `->` type($gp_addr) attr-dict
2253+
$sym_addr `:` type($sym_addr) ( `,` `device_type` $device_type^ )? `->` type($gp_addr) attr-dict
22512254
}];
22522255
}
22532256

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6137,6 +6137,28 @@ convertOmpGroupprivate(Operation &opInst, llvm::IRBuilderBase &builder,
61376137

61386138
if (failed(checkImplementationStatus(opInst)))
61396139
return failure();
6140+
6141+
bool isTargetDevice = ompBuilder->Config.isTargetDevice();
6142+
auto deviceType = groupprivateOp.getDeviceType();
6143+
6144+
// skip allocation based on device_type
6145+
bool shouldAllocate = true;
6146+
if (deviceType.has_value()) {
6147+
switch (*deviceType) {
6148+
case mlir::omp::DeclareTargetDeviceType::host:
6149+
// Only allocate on host
6150+
shouldAllocate = !isTargetDevice;
6151+
break;
6152+
case mlir::omp::DeclareTargetDeviceType::nohost:
6153+
// Only allocate on device
6154+
shouldAllocate = isTargetDevice;
6155+
break;
6156+
case mlir::omp::DeclareTargetDeviceType::any:
6157+
// Allocate on both
6158+
shouldAllocate = true;
6159+
break;
6160+
}
6161+
}
61406162

61416163
Value symAddr = groupprivateOp.getSymAddr();
61426164
auto *symOp = symAddr.getDefiningOp();
@@ -6151,21 +6173,32 @@ convertOmpGroupprivate(Operation &opInst, llvm::IRBuilderBase &builder,
61516173
LLVM::GlobalOp global =
61526174
addressOfOp.getGlobal(moduleTranslation.symbolTable());
61536175
llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
6176+
llvm::Value *resultPtr;
61546177

6155-
// Get the size of the variable
6156-
llvm::Type *varType = globalValue->getValueType();
6157-
llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
6158-
llvm::DataLayout DL = llvmModule->getDataLayout();
6159-
uint64_t typeSize = DL.getTypeAllocSize(varType);
6160-
// Call omp_alloc_shared to allocate memory for groupprivate variable.
6161-
llvm::FunctionCallee allocSharedFn = ompBuilder->getOrCreateRuntimeFunction(
6162-
*llvmModule, llvm::omp::OMPRTL___kmpc_alloc_shared);
6163-
// Call runtime to allocate shared memory for this group
6164-
llvm::Value *groupPrivatePtr =
6165-
builder.CreateCall(allocSharedFn, {builder.getInt64(typeSize)});
6166-
groupPrivatePtr =
6167-
builder.CreateBitCast(groupPrivatePtr, globalValue->getType());
6168-
moduleTranslation.mapValue(opInst.getResult(0), groupPrivatePtr);
6178+
if (shouldAllocate) {
6179+
// Get the size of the variable
6180+
llvm::Type *varType = globalValue->getValueType();
6181+
llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
6182+
llvm::DataLayout DL = llvmModule->getDataLayout();
6183+
uint64_t typeSize = DL.getTypeAllocSize(varType);
6184+
// Call omp_alloc_shared to allocate memory for groupprivate variable.
6185+
llvm::FunctionCallee allocSharedFn = ompBuilder->getOrCreateRuntimeFunction(
6186+
*llvmModule, llvm::omp::OMPRTL___kmpc_alloc_shared);
6187+
// Call runtime to allocate shared memory for this group
6188+
llvm::Value *groupPrivatePtr =
6189+
builder.CreateCall(allocSharedFn, {builder.getInt64(typeSize)});
6190+
resultPtr =
6191+
builder.CreateBitCast(groupPrivatePtr, globalValue->getType());
6192+
}
6193+
else {
6194+
// Use original global address when not allocating group-private storage
6195+
resultPtr = moduleTranslation.lookupValue(symAddr);
6196+
if (!resultPtr) {
6197+
// Fallback: create address-of for the global
6198+
resultPtr = builder.CreateBitCast(globalValue, globalValue->getType());
6199+
}
6200+
}
6201+
moduleTranslation.mapValue(opInst.getResult(0), resultPtr);
61696202
return success();
61706203
}
61716204

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3368,22 +3368,27 @@ func.func @omp_target_map_clause_type_test(%arg0 : memref<?xi32>) -> () {
33683368
return
33693369
}
33703370

3371-
3372-
// CHECK-LABEL: func.func @omp_groupprivate
3373-
llvm.mlir.global internal @_QFgpEx() : i32
3374-
func.func @omp_groupprivate() {
3371+
// CHECK-LABEL: func.func @omp_groupprivate_device_type
3372+
func.func @omp_groupprivate_device_type() {
33753373
%0 = arith.constant 1 : i32
33763374
%1 = arith.constant 2 : i32
33773375
// CHECK: [[ARG0:%.*]] = llvm.mlir.addressof @_QFgpEx : !llvm.ptr
33783376
%global_addr = llvm.mlir.addressof @_QFgpEx : !llvm.ptr
3379-
omp.teams {
3380-
// CHECK: {{.*}} = omp.groupprivate [[ARG0]] : !llvm.ptr -> !llvm.ptr
3381-
%group_private_addr_in_teams = omp.groupprivate %global_addr : !llvm.ptr -> !llvm.ptr
3382-
llvm.store %0, %group_private_addr_in_teams : i32, !llvm.ptr
3383-
omp.terminator
3384-
}
3377+
33853378
// CHECK: {{.*}} = omp.groupprivate [[ARG0]] : !llvm.ptr -> !llvm.ptr
3386-
%group_private_addr_after_teams = omp.groupprivate %global_addr : !llvm.ptr -> !llvm.ptr
3387-
llvm.store %1, %group_private_addr_after_teams : i32, !llvm.ptr
3379+
%group_private_addr = omp.groupprivate %global_addr : !llvm.ptr -> !llvm.ptr
3380+
3381+
// CHECK: {{.*}} = omp.groupprivate [[ARG0]] : !llvm.ptr, device_type (any) -> !llvm.ptr
3382+
%group_private_any = omp.groupprivate %global_addr : !llvm.ptr, device_type(any) -> !llvm.ptr
3383+
llvm.store %1, %group_private_any : i32, !llvm.ptr
3384+
3385+
// CHECK: {{.*}} = omp.groupprivate [[ARG0]] : !llvm.ptr, device_type (host) -> !llvm.ptr
3386+
%group_private_host = omp.groupprivate %global_addr : !llvm.ptr, device_type(host) -> !llvm.ptr
3387+
llvm.store %1, %group_private_host : i32, !llvm.ptr
3388+
3389+
// CHECK: {{.*}} = omp.groupprivate [[ARG0]] : !llvm.ptr, device_type (nohost) -> !llvm.ptr
3390+
%group_private_nohost = omp.groupprivate %global_addr : !llvm.ptr, device_type(nohost) -> !llvm.ptr
3391+
llvm.store %1, %group_private_nohost : i32, !llvm.ptr
3392+
33883393
return
33893394
}

0 commit comments

Comments
 (0)