Skip to content

Commit fa56e8b

Browse files
authored
[OpenMP][MLIR] Fix threadprivate lowering when compiling for target when target operations are in use (#119310)
Currently the compiler will ICE in programs like the following on the device lowering pass: ``` program main implicit none type i1_t integer :: val(1000) end type i1_t integer :: i type(i1_t), pointer :: newi1 type(i1_t), pointer :: tab=>null() integer, dimension(:), pointer :: tabval !$omp THREADPRIVATE(tab) allocate(newi1) tab=>newi1 tab%val(:)=1 tabval=>tab%val !$omp target teams distribute parallel do do i = 1, 1000 tabval(i) = i end do !$omp end target teams distribute parallel do end program main ``` This is due to the fact that THREADPRIVATE returns a result operation, and this operation can actually be used by other LLVM dialect (or other dialect) operations. However, we currently skip the lowering of threadprivate, so we effectively never generate and bind an LLVM-IR result to the threadprivate operation result. So when we later go on to lower dependent LLVM dialect operations, we are missing the required LLVM-IR result, try to access and use it and then ICE. The fix in this particular PR is to allow compilation of threadprivate for device as well as host, and simply treat the device compilation as a no-op, binding the LLVM-IR result of threadprivate with no alterations and binding it, which will allow the rest of the compilation to proceed, where we'll eventually discard the host segment in any case. The other possible solution to this I can think of, is doing something similar to Flang's passes that occur prior to CodeGen to the LLVM dialect, where they erase/no-op certain unrequired operations or transform them to lower level series of operations. And we would erase/no-op threadprivate on device as we'd never have these in target regions. The main issues I can see with this are that we currently do not specialise this stage based on wether we're compiling for device or host, so it's setting a precedent and adding another point of having to understand the separation between target and host compilation. I am also not sure we'd necessarily want to enforce this at a dialect level incase someone else wishes to add a different lowering flow or translation flow. Another possible issue is that a target operation we have/utilise would depend on the result of threadprivate, meaning we'd not be allowed to entirely erase/no-op it, I am not sure of any situations where this may be an issue currently though.
1 parent c744ed5 commit fa56e8b

File tree

3 files changed

+94
-11
lines changed

3 files changed

+94
-11
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2588,31 +2588,39 @@ static LogicalResult
25882588
convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
25892589
LLVM::ModuleTranslation &moduleTranslation) {
25902590
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2591+
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
25912592
auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
25922593

25932594
if (failed(checkImplementationStatus(opInst)))
25942595
return failure();
25952596

25962597
Value symAddr = threadprivateOp.getSymAddr();
25972598
auto *symOp = symAddr.getDefiningOp();
2599+
2600+
if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
2601+
symOp = asCast.getOperand().getDefiningOp();
2602+
25982603
if (!isa<LLVM::AddressOfOp>(symOp))
25992604
return opInst.emitError("Addressing symbol not found");
26002605
LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
26012606

26022607
LLVM::GlobalOp global =
26032608
addressOfOp.getGlobal(moduleTranslation.symbolTable());
26042609
llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
2605-
llvm::Type *type = globalValue->getValueType();
2606-
llvm::TypeSize typeSize =
2607-
builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
2608-
type);
2609-
llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
2610-
llvm::StringRef suffix = llvm::StringRef(".cache", 6);
2611-
std::string cacheName = (Twine(global.getSymName()).concat(suffix)).str();
2612-
llvm::Value *callInst =
2613-
moduleTranslation.getOpenMPBuilder()->createCachedThreadPrivate(
2614-
ompLoc, globalValue, size, cacheName);
2615-
moduleTranslation.mapValue(opInst.getResult(0), callInst);
2610+
2611+
if (!ompBuilder->Config.isTargetDevice()) {
2612+
llvm::Type *type = globalValue->getValueType();
2613+
llvm::TypeSize typeSize =
2614+
builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
2615+
type);
2616+
llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
2617+
llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
2618+
ompLoc, globalValue, size, global.getSymName() + ".cache");
2619+
moduleTranslation.mapValue(opInst.getResult(0), callInst);
2620+
} else {
2621+
moduleTranslation.mapValue(opInst.getResult(0), globalValue);
2622+
}
2623+
26162624
return success();
26172625
}
26182626

@@ -4212,6 +4220,14 @@ static bool isTargetDeviceOp(Operation *op) {
42124220
if (op->getParentOfType<omp::TargetOp>())
42134221
return true;
42144222

4223+
// Certain operations return results, and whether utilised in host or
4224+
// target there is a chance an LLVM Dialect operation depends on it
4225+
// by taking it in as an operand, so we must always lower these in
4226+
// some manner or result in an ICE (whether they end up in a no-op
4227+
// or otherwise).
4228+
if (mlir::isa<omp::ThreadprivateOp>(op))
4229+
return true;
4230+
42154231
if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>())
42164232
if (auto declareTargetIface =
42174233
llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
// Not intended to be a functional example, the aim of this test is to verify
4+
// omp.threadprivate does not crash on lowering during the OpenMP target device
5+
// pass when used in conjunction with target code in the same module.
6+
7+
module attributes {omp.is_target_device = true } {
8+
llvm.func @func() attributes {omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)>} {
9+
%0 = llvm.mlir.addressof @_QFEpointer2 : !llvm.ptr
10+
%1 = omp.threadprivate %0 : !llvm.ptr -> !llvm.ptr
11+
%2 = omp.map.info var_ptr(%1 : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>) map_clauses(implicit, to) capture(ByRef) -> !llvm.ptr
12+
omp.target map_entries(%2 -> %arg0 : !llvm.ptr) {
13+
%3 = llvm.mlir.constant(1 : i32) : i32
14+
%4 = llvm.getelementptr %arg0[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
15+
llvm.store %3, %4 : i32, !llvm.ptr
16+
omp.terminator
17+
}
18+
llvm.return
19+
}
20+
llvm.mlir.global internal @_QFEpointer2() {addr_space = 0 : i32} : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> {
21+
%0 = llvm.mlir.undef : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
22+
llvm.return %0 : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
23+
}
24+
}
25+
26+
// CHECK: define weak_odr protected void @{{.*}}(ptr %{{.*}}, ptr %[[ARG1:.*]]) {
27+
// CHECK: %[[ALLOCA:.*]] = alloca ptr, align 8
28+
// CHECK: store ptr %[[ARG1]], ptr %[[ALLOCA]], align 8
29+
// CHECK: %[[LOAD_ALLOCA:.*]] = load ptr, ptr %[[ALLOCA]], align 8
30+
// CHECK: store i32 1, ptr %[[LOAD_ALLOCA]], align 4
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
! Basic offloading test that makes sure we can use the predominantly host
2+
! pragma threadprivate in the same program as target code
3+
! REQUIRES: flang, amdgpu
4+
5+
! RUN: %libomptarget-compile-fortran-run-and-check-generic
6+
program main
7+
implicit none
8+
9+
type dtype
10+
integer :: val(10)
11+
end type dtype
12+
13+
integer :: i
14+
type(dtype), pointer :: pointer1
15+
type(dtype), pointer :: pointer2=>null()
16+
integer, dimension(:), pointer :: data_pointer
17+
18+
!$omp threadprivate(pointer2)
19+
20+
nullify(pointer1)
21+
allocate(pointer1)
22+
23+
pointer2=>pointer1
24+
pointer2%val(:)=1
25+
data_pointer=>pointer2%val
26+
27+
!$omp target
28+
do i = 1, 10
29+
data_pointer(i) = i
30+
end do
31+
!$omp end target
32+
33+
print *, data_pointer
34+
35+
end program main
36+
37+
! CHECK: 1 2 3 4 5 6 7 8 9 10

0 commit comments

Comments
 (0)