Skip to content

Commit 6f16f7d

Browse files
committed
Use getGlobalFromSymbol for threadprivate and groupprivate
1 parent 699220c commit 6f16f7d

File tree

2 files changed

+34
-18
lines changed

2 files changed

+34
-18
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3564,6 +3564,26 @@ convertOmpCancellationPoint(omp::CancellationPointOp op,
35643564
return success();
35653565
}
35663566

3567+
/// Resolves the LLVM::GlobalOp referenced by the operation that defines a
/// symbol address.
///
/// \param symOp Defining operation of the symbol address value. May be null
///        (Value::getDefiningOp() returns null when the value is not produced
///        by an operation, e.g. a block argument). A single
///        LLVM::AddrSpaceCastOp wrapping the address is looked through.
/// \param moduleTranslation Provides the symbol table used to look up the
///        global named by the LLVM::AddressOfOp.
/// \param opInst Optional operation on which the diagnostic is emitted; when
///        null the failure is reported only through the return value.
/// \returns The result of LLVM::AddressOfOp::getGlobal, or a null GlobalOp if
///          the address is not (a cast of) an LLVM::AddressOfOp.
static LLVM::GlobalOp
getGlobalFromSymbol(Operation *symOp,
                    LLVM::ModuleTranslation &moduleTranslation,
                    Operation *opInst) {
  // Look through a potential address space cast. The *_if_present cast is
  // used because symOp can legitimately be null here, and plain dyn_cast/isa
  // must never be handed a null pointer.
  if (auto asCast = llvm::dyn_cast_if_present<LLVM::AddrSpaceCastOp>(symOp))
    symOp = asCast.getOperand().getDefiningOp();

  // The operand of the cast may itself have no defining op, so the null-safe
  // cast is required a second time. One checked cast replaces isa + cast.
  auto addressOfOp = llvm::dyn_cast_if_present<LLVM::AddressOfOp>(symOp);
  if (!addressOfOp) {
    if (opInst)
      opInst->emitError("Addressing symbol not found");
    return nullptr;
  }

  return addressOfOp.getGlobal(moduleTranslation.symbolTable());
}
3586+
35673587
/// Converts an OpenMP Threadprivate operation into LLVM IR using
35683588
/// OpenMPIRBuilder.
35693589
static LogicalResult
@@ -3579,15 +3599,10 @@ convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
35793599
Value symAddr = threadprivateOp.getSymAddr();
35803600
auto *symOp = symAddr.getDefiningOp();
35813601

3582-
if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
3583-
symOp = asCast.getOperand().getDefiningOp();
3584-
3585-
if (!isa<LLVM::AddressOfOp>(symOp))
3586-
return opInst.emitError("Addressing symbol not found");
3587-
LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
3588-
35893602
LLVM::GlobalOp global =
3590-
addressOfOp.getGlobal(moduleTranslation.symbolTable());
3603+
getGlobalFromSymbol(symOp, moduleTranslation, &opInst);
3604+
if (!global)
3605+
return failure();
35913606
llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
35923607

35933608
if (!ompBuilder->Config.isTargetDevice()) {
@@ -6144,17 +6159,13 @@ convertOmpGroupprivate(Operation &opInst, llvm::IRBuilderBase &builder,
61446159
}
61456160

61466161
Value symAddr = groupprivateOp.getSymAddr();
6147-
auto *symOp = symAddr.getDefiningOp();
6148-
6149-
if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
6150-
symOp = asCast.getOperand().getDefiningOp();
6151-
6152-
if (!isa<LLVM::AddressOfOp>(symOp))
6153-
return opInst.emitError("Addressing symbol not found");
6154-
LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
6162+
Operation *symOp = symAddr.getDefiningOp();
61556163

61566164
LLVM::GlobalOp global =
6157-
addressOfOp.getGlobal(moduleTranslation.symbolTable());
6165+
getGlobalFromSymbol(symOp, moduleTranslation, &opInst);
6166+
if (!global)
6167+
return failure();
6168+
61586169
llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
61596170
llvm::Value *resultPtr;
61606171

@@ -6175,6 +6186,11 @@ convertOmpGroupprivate(Operation &opInst, llvm::IRBuilderBase &builder,
61756186
if (!resultPtr)
61766187
resultPtr = globalValue;
61776188
}
6189+
6190+
llvm::Type *ptrTy = builder.getPtrTy();
6191+
if (resultPtr->getType() != ptrTy)
6192+
resultPtr = builder.CreateBitCast(resultPtr, ptrTy);
6193+
61786194
moduleTranslation.mapValue(opInst.getResult(0), resultPtr);
61796195
return success();
61806196
}

mlir/test/Target/LLVMIR/omptarget-groupprivate.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ module attributes {omp.is_target_device = true, llvm.target_triple = "amdgcn-amd
3333

3434
// CHECK: define {{.*}} amdgpu_kernel void @__omp_offloading_{{.*}}_{{.*}}__QQmain_{{.*}}(ptr %{{.*}}, ptr %{{.*}}) #{{[0-9]+}} {
3535
// CHECK-LABEL: omp.target:
36-
// CHECK-NEXT : %[[LOAD:.*]] = load i32, ptr %3, align 4
36+
// CHECK-NEXT : %[[LOAD:.*]] = load i32, ptr %{{.*}}, align 4
3737
// CHECK-NEXT : %[[ALLOC_any:.*]] = call ptr @__kmpc_alloc_shared(i64 4)
3838
// CHECK-NEXT : store i32 %[[LOAD]], ptr %[[ALLOC_any]], align 4
3939
// CHECK-NEXT : store i32 %[[LOAD]], ptr @global_host, align 4

0 commit comments

Comments
 (0)