Skip to content

Commit af4353e

Browse files
committed
Use getGlobalFromSymbol for threadprivate and groupprivate
1 parent 6336f7d commit af4353e

File tree

2 files changed

+34
-18
lines changed

2 files changed

+34
-18
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3564,6 +3564,26 @@ convertOmpCancellationPoint(omp::CancellationPointOp op,
35643564
return success();
35653565
}
35663566

3567+
static LLVM::GlobalOp
3568+
getGlobalFromSymbol(Operation *symOp,
3569+
LLVM::ModuleTranslation &moduleTranslation,
3570+
Operation *opInst) {
3571+
3572+
// Handle potential address space cast
3573+
if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
3574+
symOp = asCast.getOperand().getDefiningOp();
3575+
3576+
// Check if we have an AddressOfOp
3577+
if (!isa<LLVM::AddressOfOp>(symOp)) {
3578+
if (opInst)
3579+
opInst->emitError("Addressing symbol not found");
3580+
return nullptr;
3581+
}
3582+
3583+
LLVM::AddressOfOp addressOfOp = cast<LLVM::AddressOfOp>(symOp);
3584+
return addressOfOp.getGlobal(moduleTranslation.symbolTable());
3585+
}
3586+
35673587
/// Converts an OpenMP Threadprivate operation into LLVM IR using
35683588
/// OpenMPIRBuilder.
35693589
static LogicalResult
@@ -3579,15 +3599,10 @@ convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
35793599
Value symAddr = threadprivateOp.getSymAddr();
35803600
auto *symOp = symAddr.getDefiningOp();
35813601

3582-
if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
3583-
symOp = asCast.getOperand().getDefiningOp();
3584-
3585-
if (!isa<LLVM::AddressOfOp>(symOp))
3586-
return opInst.emitError("Addressing symbol not found");
3587-
LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
3588-
35893602
LLVM::GlobalOp global =
3590-
addressOfOp.getGlobal(moduleTranslation.symbolTable());
3603+
getGlobalFromSymbol(symOp, moduleTranslation, &opInst);
3604+
if (!global)
3605+
return failure();
35913606
llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
35923607

35933608
if (!ompBuilder->Config.isTargetDevice()) {
@@ -6161,17 +6176,13 @@ convertOmpGroupprivate(Operation &opInst, llvm::IRBuilderBase &builder,
61616176
}
61626177

61636178
Value symAddr = groupprivateOp.getSymAddr();
6164-
auto *symOp = symAddr.getDefiningOp();
6165-
6166-
if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
6167-
symOp = asCast.getOperand().getDefiningOp();
6168-
6169-
if (!isa<LLVM::AddressOfOp>(symOp))
6170-
return opInst.emitError("Addressing symbol not found");
6171-
LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
6179+
Operation *symOp = symAddr.getDefiningOp();
61726180

61736181
LLVM::GlobalOp global =
6174-
addressOfOp.getGlobal(moduleTranslation.symbolTable());
6182+
getGlobalFromSymbol(symOp, moduleTranslation, &opInst);
6183+
if (!global)
6184+
return failure();
6185+
61756186
llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
61766187
llvm::Value *resultPtr;
61776188

@@ -6192,6 +6203,11 @@ convertOmpGroupprivate(Operation &opInst, llvm::IRBuilderBase &builder,
61926203
if (!resultPtr)
61936204
resultPtr = globalValue;
61946205
}
6206+
6207+
llvm::Type *ptrTy = builder.getPtrTy();
6208+
if (resultPtr->getType() != ptrTy)
6209+
resultPtr = builder.CreateBitCast(resultPtr, ptrTy);
6210+
61956211
moduleTranslation.mapValue(opInst.getResult(0), resultPtr);
61966212
return success();
61976213
}

mlir/test/Target/LLVMIR/omptarget-groupprivate.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ module attributes {omp.is_target_device = true, llvm.target_triple = "amdgcn-amd
3333

3434
// CHECK: define {{.*}} amdgpu_kernel void @__omp_offloading_{{.*}}_{{.*}}__QQmain_{{.*}}(ptr %{{.*}}, ptr %{{.*}}) #{{[0-9]+}} {
3535
// CHECK-LABEL: omp.target:
36-
// CHECK-NEXT : %[[LOAD:.*]] = load i32, ptr %3, align 4
36+
// CHECK-NEXT : %[[LOAD:.*]] = load i32, ptr %{{.*}}, align 4
3737
// CHECK-NEXT : %[[ALLOC_any:.*]] = call ptr @__kmpc_alloc_shared(i64 4)
3838
// CHECK-NEXT : store i32 %[[LOAD]], ptr %[[ALLOC_any]], align 4
3939
// CHECK-NEXT : store i32 %[[LOAD]], ptr @global_host, align 4

0 commit comments

Comments
 (0)