Skip to content

Commit cbe939d

Browse files
[OpenMP][offload] Register Vtables runtime support for indirect calls
- Modify PluginInterface to register Vtables to indirect call table
1 parent c2aea1e commit cbe939d

File tree

4 files changed

+164
-15
lines changed

4 files changed

+164
-15
lines changed

offload/include/omptarget.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ enum OpenMPOffloadingDeclareTargetFlags {
9494
OMP_DECLARE_TARGET_INDIRECT = 0x08,
9595
/// This is an entry corresponding to a requirement to be registered.
9696
OMP_REGISTER_REQUIRES = 0x10,
97+
/// Mark the entry global as being an indirect vtable.
98+
OMP_DECLARE_TARGET_INDIRECT_VTABLE = 0x20,
9799
};
98100

99101
enum TargetAllocTy : int32_t {

offload/libomptarget/PluginManager.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,8 @@ static int loadImagesOntoDevice(DeviceTy &Device) {
434434

435435
llvm::offloading::EntryTy DeviceEntry = Entry;
436436
if (Entry.Size) {
437-
if (Device.RTL->get_global(Binary, Entry.Size, Entry.SymbolName,
437+
if (!(Entry.Flags & OMP_DECLARE_TARGET_INDIRECT_VTABLE) &&
438+
Device.RTL->get_global(Binary, Entry.Size, Entry.SymbolName,
438439
&DeviceEntry.Address) != OFFLOAD_SUCCESS)
439440
REPORT("Failed to load symbol %s\n", Entry.SymbolName);
440441

@@ -443,7 +444,9 @@ static int loadImagesOntoDevice(DeviceTy &Device) {
443444
// the device to point to the memory on the host.
444445
if ((PM->getRequirements() & OMP_REQ_UNIFIED_SHARED_MEMORY) ||
445446
(PM->getRequirements() & OMPX_REQ_AUTO_ZERO_COPY)) {
446-
if (Device.RTL->data_submit(DeviceId, DeviceEntry.Address,
447+
if (!(Entry.Flags & OMP_DECLARE_TARGET_INDIRECT_VTABLE) &&
448+
!(Entry.Flags & OMP_DECLARE_TARGET_INDIRECT) &&
449+
Device.RTL->data_submit(DeviceId, DeviceEntry.Address,
447450
Entry.Address,
448451
Entry.Size) != OFFLOAD_SUCCESS)
449452
REPORT("Failed to write symbol for USM %s\n", Entry.SymbolName);

offload/libomptarget/device.cpp

Lines changed: 50 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -112,21 +112,58 @@ setupIndirectCallTable(DeviceTy &Device, __tgt_device_image *Image,
112112
llvm::SmallVector<std::pair<void *, void *>> IndirectCallTable;
113113
for (const auto &Entry : Entries) {
114114
if (Entry.Kind != llvm::object::OffloadKind::OFK_OpenMP ||
115-
Entry.Size == 0 || !(Entry.Flags & OMP_DECLARE_TARGET_INDIRECT))
115+
Entry.Size == 0 ||
116+
(!(Entry.Flags & OMP_DECLARE_TARGET_INDIRECT) &&
117+
!(Entry.Flags & OMP_DECLARE_TARGET_INDIRECT_VTABLE)))
116118
continue;
117119

118-
assert(Entry.Size == sizeof(void *) && "Global not a function pointer?");
119-
auto &[HstPtr, DevPtr] = IndirectCallTable.emplace_back();
120-
121-
void *Ptr;
122-
if (Device.RTL->get_global(Binary, Entry.Size, Entry.SymbolName, &Ptr))
123-
return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
124-
"failed to load %s", Entry.SymbolName);
125-
126-
HstPtr = Entry.Address;
127-
if (Device.retrieveData(&DevPtr, Ptr, Entry.Size, AsyncInfo))
128-
return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
129-
"failed to load %s", Entry.SymbolName);
120+
size_t PtrSize = sizeof(void *);
121+
if (Entry.Flags & OMP_DECLARE_TARGET_INDIRECT_VTABLE) {
122+
// This is a VTable entry, the current entry is the first index of the
123+
// VTable and Entry.Size is the total size of the VTable. Unlike the
124+
// indirect function case below, the Global is not of size Entry.Size and
125+
// is instead of size PtrSize (sizeof(void*)).
126+
void *Vtable;
127+
void *res;
128+
if (Device.RTL->get_global(Binary, PtrSize, Entry.SymbolName, &Vtable))
129+
return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
130+
"failed to load %s", Entry.SymbolName);
131+
132+
// HstPtr = Entry.Address;
133+
if (Device.retrieveData(&res, Vtable, PtrSize, AsyncInfo))
134+
return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
135+
"failed to load %s", Entry.SymbolName);
136+
if (Device.synchronize(AsyncInfo))
137+
return error::createOffloadError(
138+
error::ErrorCode::INVALID_BINARY,
139+
"failed to synchronize after retrieving %s", Entry.SymbolName);
140+
// Calculate and emplace entire Vtable from first Vtable byte
141+
for (uint64_t i = 0; i < Entry.Size / PtrSize; ++i) {
142+
auto &[HstPtr, DevPtr] = IndirectCallTable.emplace_back();
143+
HstPtr = reinterpret_cast<void *>(
144+
reinterpret_cast<uintptr_t>(Entry.Address) + i * PtrSize);
145+
DevPtr = reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(res) +
146+
i * PtrSize);
147+
}
148+
} else {
149+
// Indirect function case: Entry.Size should equal PtrSize since we're
150+
// dealing with a single function pointer (not a VTable)
151+
assert(Entry.Size == PtrSize && "Global not a function pointer?");
152+
auto &[HstPtr, DevPtr] = IndirectCallTable.emplace_back();
153+
void *Ptr;
154+
if (Device.RTL->get_global(Binary, Entry.Size, Entry.SymbolName, &Ptr))
155+
return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
156+
"failed to load %s", Entry.SymbolName);
157+
158+
HstPtr = Entry.Address;
159+
if (Device.retrieveData(&DevPtr, Ptr, Entry.Size, AsyncInfo))
160+
return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
161+
"failed to load %s", Entry.SymbolName);
162+
}
163+
if (Device.synchronize(AsyncInfo))
164+
return error::createOffloadError(
165+
error::ErrorCode::INVALID_BINARY,
166+
"failed to synchronize after retrieving %s", Entry.SymbolName);
130167
}
131168

132169
// If we do not have any indirect globals we exit early.
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
// RUN: %libomptarget-compile-run-and-check-generic
2+
#include <assert.h>
3+
#include <omp.h>
4+
#include <stdio.h>
5+
6+
// ---------------------------------------------------------------------------
7+
// Various definitions copied from OpenMP RTL
8+
9+
typedef struct {
10+
uint64_t Reserved;
11+
uint16_t Version;
12+
uint16_t Kind; // OpenMP==1
13+
uint32_t Flags;
14+
void *Address;
15+
char *SymbolName;
16+
uint64_t Size;
17+
uint64_t Data;
18+
void *AuxAddr;
19+
} __tgt_offload_entry;
20+
21+
enum OpenMPOffloadingDeclareTargetFlags {
22+
/// Mark the entry global as having a 'link' attribute.
23+
OMP_DECLARE_TARGET_LINK = 0x01,
24+
/// Mark the entry global as being an indirectly callable function.
25+
OMP_DECLARE_TARGET_INDIRECT = 0x08,
26+
/// This is an entry corresponding to a requirement to be registered.
27+
OMP_REGISTER_REQUIRES = 0x10,
28+
/// Mark the entry global as being an indirect vtable.
29+
OMP_DECLARE_TARGET_INDIRECT_VTABLE = 0x20,
30+
};
31+
32+
#pragma omp begin declare variant match(device = {kind(gpu)})
33+
// Provided by the runtime.
34+
void *__llvm_omp_indirect_call_lookup(void *host_ptr);
35+
#pragma omp declare target to(__llvm_omp_indirect_call_lookup) \
36+
device_type(nohost)
37+
#pragma omp end declare variant
38+
39+
#pragma omp begin declare variant match(device = {kind(cpu)})
40+
// We assume unified addressing on the CPU target.
41+
void *__llvm_omp_indirect_call_lookup(void *host_ptr) { return host_ptr; }
42+
#pragma omp end declare variant
43+
44+
#pragma omp begin declare target
45+
void foo(int *i) { *i += 1; }
46+
void bar(int *i) { *i += 10; }
47+
void baz(int *i) { *i += 100; }
48+
#pragma omp end declare target
49+
50+
typedef void (*fptr_t)(int *i);
51+
52+
// Dispatch Table - declare separately on host and device to avoid
53+
// registering with the library; this also allows us to use separate
54+
// names, which is convenient for debugging. This dispatchTable is
55+
// intended to mimic what Clang emits for C++ vtables.
56+
fptr_t dispatchTable[] = {foo, bar, baz};
57+
#pragma omp begin declare target device_type(nohost)
58+
fptr_t GPUdispatchTable[] = {foo, bar, baz};
59+
fptr_t *GPUdispatchTablePtr = GPUdispatchTable;
60+
#pragma omp end declare target
61+
62+
// Define "manual" OpenMP offload entries, where we emit Clang
63+
// offloading entry structure definitions in the appropriate ELF
64+
// section. This allows us to emulate the offloading entries that Clang would
65+
// normally emit for us
66+
67+
__attribute__((weak, section("llvm_offload_entries"), aligned(8)))
68+
const __tgt_offload_entry __offloading_entry[] = {{
69+
0ULL, // Reserved
70+
1, // Version
71+
1, // Kind
72+
OMP_DECLARE_TARGET_INDIRECT_VTABLE, // Flags
73+
&dispatchTable, // Address
74+
"GPUdispatchTablePtr", // SymbolName
75+
(size_t)(sizeof(dispatchTable)), // Size
76+
0ULL, // Data
77+
NULL // AuxAddr
78+
}};
79+
80+
// Mimic how Clang emits vtable pointers for C++ classes
81+
typedef struct {
82+
fptr_t *dispatchPtr;
83+
} myClass;
84+
85+
// ---------------------------------------------------------------------------
86+
int main() {
87+
myClass obj_foo = {dispatchTable + 0};
88+
myClass obj_bar = {dispatchTable + 1};
89+
myClass obj_baz = {dispatchTable + 2};
90+
int aaa = 0;
91+
92+
#pragma omp target map(aaa) map(to : obj_foo, obj_bar, obj_baz)
93+
{
94+
// Lookup
95+
fptr_t *foo_ptr = __llvm_omp_indirect_call_lookup(obj_foo.dispatchPtr);
96+
fptr_t *bar_ptr = __llvm_omp_indirect_call_lookup(obj_bar.dispatchPtr);
97+
fptr_t *baz_ptr = __llvm_omp_indirect_call_lookup(obj_baz.dispatchPtr);
98+
foo_ptr[0](&aaa);
99+
bar_ptr[0](&aaa);
100+
baz_ptr[0](&aaa);
101+
}
102+
103+
assert(aaa == 111);
104+
// CHECK: PASS
105+
printf("PASS\n");
106+
return 0;
107+
}

0 commit comments

Comments
 (0)