Skip to content

Commit 5d550bf

Browse files
authored
[OpenMP] Move `__omp_rtl_data_environment' handling to OpenMP (#157182)
Summary: This operation is done every time we load a binary, this behavior should be moved into OpenMP since it concerns an OpenMP specific data struct. This is a little messy, because ideally we should only be using public APIs, but more can be extracted later.
1 parent 68b98bb commit 5d550bf

File tree

5 files changed

+88
-104
lines changed

5 files changed

+88
-104
lines changed

offload/include/device.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@
3333
#include "llvm/ADT/DenseMap.h"
3434
#include "llvm/ADT/SmallVector.h"
3535

36+
#include "GlobalHandler.h"
3637
#include "PluginInterface.h"
38+
3739
using GenericPluginTy = llvm::omp::target::plugin::GenericPluginTy;
3840

3941
// Forward declarations.

offload/libomptarget/device.cpp

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
using namespace llvm::omp::target::ompt;
3838
#endif
3939

40+
using namespace llvm::omp::target::plugin;
41+
4042
int HostDataToTargetTy::addEventIfNecessary(DeviceTy &Device,
4143
AsyncInfoTy &AsyncInfo) const {
4244
// First, check if the user disabled atomic map transfer/malloc/dealloc.
@@ -97,14 +99,94 @@ llvm::Error DeviceTy::init() {
9799
return llvm::Error::success();
98100
}
99101

100-
// Load binary to device.
102+
// Extract the mapping of host function pointers to device function pointers
103+
// from the entry table. Functions marked as 'indirect' in OpenMP will have
104+
// offloading entries generated for them which map the host's function pointer
105+
// to a global containing the corresponding function pointer on the device.
106+
static llvm::Expected<std::pair<void *, uint64_t>>
107+
setupIndirectCallTable(DeviceTy &Device, __tgt_device_image *Image,
108+
__tgt_device_binary Binary) {
109+
AsyncInfoTy AsyncInfo(Device);
110+
llvm::ArrayRef<llvm::offloading::EntryTy> Entries(Image->EntriesBegin,
111+
Image->EntriesEnd);
112+
llvm::SmallVector<std::pair<void *, void *>> IndirectCallTable;
113+
for (const auto &Entry : Entries) {
114+
if (Entry.Kind != llvm::object::OffloadKind::OFK_OpenMP ||
115+
Entry.Size == 0 || !(Entry.Flags & OMP_DECLARE_TARGET_INDIRECT))
116+
continue;
117+
118+
assert(Entry.Size == sizeof(void *) && "Global not a function pointer?");
119+
auto &[HstPtr, DevPtr] = IndirectCallTable.emplace_back();
120+
121+
void *Ptr;
122+
if (Device.RTL->get_global(Binary, Entry.Size, Entry.SymbolName, &Ptr))
123+
return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
124+
"failed to load %s", Entry.SymbolName);
125+
126+
HstPtr = Entry.Address;
127+
if (Device.retrieveData(&DevPtr, Ptr, Entry.Size, AsyncInfo))
128+
return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
129+
"failed to load %s", Entry.SymbolName);
130+
}
131+
132+
// If we do not have any indirect globals we exit early.
133+
if (IndirectCallTable.empty())
134+
return std::pair{nullptr, 0};
135+
136+
// Sort the array to allow for more efficient lookup of device pointers.
137+
llvm::sort(IndirectCallTable,
138+
[](const auto &x, const auto &y) { return x.first < y.first; });
139+
140+
uint64_t TableSize =
141+
IndirectCallTable.size() * sizeof(std::pair<void *, void *>);
142+
void *DevicePtr = Device.allocData(TableSize, nullptr, TARGET_ALLOC_DEVICE);
143+
if (Device.submitData(DevicePtr, IndirectCallTable.data(), TableSize,
144+
AsyncInfo))
145+
return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
146+
"failed to copy data");
147+
return std::pair<void *, uint64_t>(DevicePtr, IndirectCallTable.size());
148+
}
149+
150+
// Load binary to device and perform global initialization if needed.
101151
llvm::Expected<__tgt_device_binary>
102152
DeviceTy::loadBinary(__tgt_device_image *Img) {
103153
__tgt_device_binary Binary;
104154

105155
if (RTL->load_binary(RTLDeviceID, Img, &Binary) != OFFLOAD_SUCCESS)
106156
return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
107157
"failed to load binary %p", Img);
158+
159+
// This symbol is optional.
160+
void *DeviceEnvironmentPtr;
161+
if (RTL->get_global(Binary, sizeof(DeviceEnvironmentTy),
162+
"__omp_rtl_device_environment", &DeviceEnvironmentPtr))
163+
return Binary;
164+
165+
// Obtain a table mapping host function pointers to device function pointers.
166+
auto CallTablePairOrErr = setupIndirectCallTable(*this, Img, Binary);
167+
if (!CallTablePairOrErr)
168+
return CallTablePairOrErr.takeError();
169+
170+
GenericDeviceTy &GenericDevice = RTL->getDevice(RTLDeviceID);
171+
DeviceEnvironmentTy DeviceEnvironment;
172+
DeviceEnvironment.DeviceDebugKind = GenericDevice.getDebugKind();
173+
DeviceEnvironment.NumDevices = RTL->getNumDevices();
174+
// TODO: The device ID used here is not the real device ID used by OpenMP.
175+
DeviceEnvironment.DeviceNum = RTLDeviceID;
176+
DeviceEnvironment.DynamicMemSize = GenericDevice.getDynamicMemorySize();
177+
DeviceEnvironment.ClockFrequency = GenericDevice.getClockFrequency();
178+
DeviceEnvironment.IndirectCallTable =
179+
reinterpret_cast<uintptr_t>(CallTablePairOrErr->first);
180+
DeviceEnvironment.IndirectCallTableSize = CallTablePairOrErr->second;
181+
DeviceEnvironment.HardwareParallelism =
182+
GenericDevice.getHardwareParallelism();
183+
184+
AsyncInfoTy AsyncInfo(*this);
185+
if (submitData(DeviceEnvironmentPtr, &DeviceEnvironment,
186+
sizeof(DeviceEnvironment), AsyncInfo))
187+
return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
188+
"failed to copy data");
189+
108190
return Binary;
109191
}
110192

offload/plugins-nextgen/common/include/PluginInterface.h

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -839,11 +839,6 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
839839
Error unloadBinary(DeviceImageTy *Image);
840840
virtual Error unloadBinaryImpl(DeviceImageTy *Image) = 0;
841841

842-
/// Setup the device environment if needed. Notice this setup may not be run
843-
/// on some plugins. By default, it will be executed, but plugins can change
844-
/// this behavior by overriding the shouldSetupDeviceEnvironment function.
845-
Error setupDeviceEnvironment(GenericPluginTy &Plugin, DeviceImageTy &Image);
846-
847842
/// Setup the global device memory pool, if the plugin requires one.
848843
Error setupDeviceMemoryPool(GenericPluginTy &Plugin, DeviceImageTy &Image,
849844
uint64_t PoolSize);
@@ -1043,6 +1038,7 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
10431038
uint32_t getDefaultNumBlocks() const {
10441039
return GridValues.GV_Default_Num_Teams;
10451040
}
1041+
uint32_t getDebugKind() const { return OMPX_DebugKind; }
10461042
uint32_t getDynamicMemorySize() const { return OMPX_SharedMemorySize; }
10471043
virtual uint64_t getClockFrequency() const { return CLOCKS_PER_SEC; }
10481044

@@ -1183,11 +1179,6 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
11831179
virtual Error getDeviceHeapSize(uint64_t &V) = 0;
11841180
virtual Error setDeviceHeapSize(uint64_t V) = 0;
11851181

1186-
/// Indicate whether the device should setup the device environment. Notice
1187-
/// that returning false in this function will change the behavior of the
1188-
/// setupDeviceEnvironment() function.
1189-
virtual bool shouldSetupDeviceEnvironment() const { return true; }
1190-
11911182
/// Indicate whether the device should setup the global device memory pool. If
11921183
/// false is return the value on the device will be uninitialized.
11931184
virtual bool shouldSetupDeviceMemoryPool() const { return true; }
@@ -1243,7 +1234,7 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
12431234
enum class PeerAccessState : uint8_t { AVAILABLE, UNAVAILABLE, PENDING };
12441235

12451236
/// Array of peer access states with the rest of devices. This means that if
1246-
/// the device I has a matrix PeerAccesses with PeerAccesses[J] == AVAILABLE,
1237+
/// the device I has a matrix PeerAccesses with PeerAccesses == AVAILABLE,
12471238
/// the device I can access device J's memory directly. However, notice this
12481239
/// does not mean that device J can access device I's memory directly.
12491240
llvm::SmallVector<PeerAccessState> PeerAccesses;

offload/plugins-nextgen/common/src/PluginInterface.cpp

Lines changed: 1 addition & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -371,54 +371,6 @@ struct RecordReplayTy {
371371
};
372372
} // namespace llvm::omp::target::plugin
373373

374-
// Extract the mapping of host function pointers to device function pointers
375-
// from the entry table. Functions marked as 'indirect' in OpenMP will have
376-
// offloading entries generated for them which map the host's function pointer
377-
// to a global containing the corresponding function pointer on the device.
378-
static Expected<std::pair<void *, uint64_t>>
379-
setupIndirectCallTable(GenericPluginTy &Plugin, GenericDeviceTy &Device,
380-
DeviceImageTy &Image) {
381-
GenericGlobalHandlerTy &Handler = Plugin.getGlobalHandler();
382-
383-
llvm::ArrayRef<llvm::offloading::EntryTy> Entries(
384-
Image.getTgtImage()->EntriesBegin, Image.getTgtImage()->EntriesEnd);
385-
llvm::SmallVector<std::pair<void *, void *>> IndirectCallTable;
386-
for (const auto &Entry : Entries) {
387-
if (Entry.Kind != object::OffloadKind::OFK_OpenMP || Entry.Size == 0 ||
388-
!(Entry.Flags & OMP_DECLARE_TARGET_INDIRECT))
389-
continue;
390-
391-
assert(Entry.Size == sizeof(void *) && "Global not a function pointer?");
392-
auto &[HstPtr, DevPtr] = IndirectCallTable.emplace_back();
393-
394-
GlobalTy DeviceGlobal(Entry.SymbolName, Entry.Size);
395-
if (auto Err =
396-
Handler.getGlobalMetadataFromDevice(Device, Image, DeviceGlobal))
397-
return std::move(Err);
398-
399-
HstPtr = Entry.Address;
400-
if (auto Err = Device.dataRetrieve(&DevPtr, DeviceGlobal.getPtr(),
401-
Entry.Size, nullptr))
402-
return std::move(Err);
403-
}
404-
405-
// If we do not have any indirect globals we exit early.
406-
if (IndirectCallTable.empty())
407-
return std::pair{nullptr, 0};
408-
409-
// Sort the array to allow for more efficient lookup of device pointers.
410-
llvm::sort(IndirectCallTable,
411-
[](const auto &x, const auto &y) { return x.first < y.first; });
412-
413-
uint64_t TableSize =
414-
IndirectCallTable.size() * sizeof(std::pair<void *, void *>);
415-
void *DevicePtr = Device.allocate(TableSize, nullptr, TARGET_ALLOC_DEVICE);
416-
if (auto Err = Device.dataSubmit(DevicePtr, IndirectCallTable.data(),
417-
TableSize, nullptr))
418-
return std::move(Err);
419-
return std::pair<void *, uint64_t>(DevicePtr, IndirectCallTable.size());
420-
}
421-
422374
AsyncInfoWrapperTy::AsyncInfoWrapperTy(GenericDeviceTy &Device,
423375
__tgt_async_info *AsyncInfoPtr)
424376
: Device(Device),
@@ -943,10 +895,6 @@ GenericDeviceTy::loadBinary(GenericPluginTy &Plugin,
943895
// Add the image to list.
944896
LoadedImages.push_back(Image);
945897

946-
// Setup the device environment if needed.
947-
if (auto Err = setupDeviceEnvironment(Plugin, *Image))
948-
return std::move(Err);
949-
950898
// Setup the global device memory pool if needed.
951899
if (!Plugin.getRecordReplay().isReplaying() &&
952900
shouldSetupDeviceMemoryPool()) {
@@ -982,43 +930,6 @@ GenericDeviceTy::loadBinary(GenericPluginTy &Plugin,
982930
return Image;
983931
}
984932

985-
Error GenericDeviceTy::setupDeviceEnvironment(GenericPluginTy &Plugin,
986-
DeviceImageTy &Image) {
987-
// There are some plugins that do not need this step.
988-
if (!shouldSetupDeviceEnvironment())
989-
return Plugin::success();
990-
991-
// Obtain a table mapping host function pointers to device function pointers.
992-
auto CallTablePairOrErr = setupIndirectCallTable(Plugin, *this, Image);
993-
if (!CallTablePairOrErr)
994-
return CallTablePairOrErr.takeError();
995-
996-
DeviceEnvironmentTy DeviceEnvironment;
997-
DeviceEnvironment.DeviceDebugKind = OMPX_DebugKind;
998-
DeviceEnvironment.NumDevices = Plugin.getNumDevices();
999-
// TODO: The device ID used here is not the real device ID used by OpenMP.
1000-
DeviceEnvironment.DeviceNum = DeviceId;
1001-
DeviceEnvironment.DynamicMemSize = OMPX_SharedMemorySize;
1002-
DeviceEnvironment.ClockFrequency = getClockFrequency();
1003-
DeviceEnvironment.IndirectCallTable =
1004-
reinterpret_cast<uintptr_t>(CallTablePairOrErr->first);
1005-
DeviceEnvironment.IndirectCallTableSize = CallTablePairOrErr->second;
1006-
DeviceEnvironment.HardwareParallelism = getHardwareParallelism();
1007-
1008-
// Create the metainfo of the device environment global.
1009-
GlobalTy DevEnvGlobal("__omp_rtl_device_environment",
1010-
sizeof(DeviceEnvironmentTy), &DeviceEnvironment);
1011-
1012-
// Write device environment values to the device.
1013-
GenericGlobalHandlerTy &GHandler = Plugin.getGlobalHandler();
1014-
if (auto Err = GHandler.writeGlobalToDevice(*this, Image, DevEnvGlobal)) {
1015-
DP("Missing symbol %s, continue execution anyway.\n",
1016-
DevEnvGlobal.getName().data());
1017-
consumeError(std::move(Err));
1018-
}
1019-
return Plugin::success();
1020-
}
1021-
1022933
Error GenericDeviceTy::setupDeviceMemoryPool(GenericPluginTy &Plugin,
1023934
DeviceImageTy &Image,
1024935
uint64_t PoolSize) {
@@ -2259,8 +2170,7 @@ int32_t GenericPluginTy::get_global(__tgt_device_binary Binary, uint64_t Size,
22592170
GenericGlobalHandlerTy &GHandler = getGlobalHandler();
22602171
if (auto Err =
22612172
GHandler.getGlobalMetadataFromDevice(Device, Image, DeviceGlobal)) {
2262-
REPORT("Failure to look up global address: %s\n",
2263-
toString(std::move(Err)).data());
2173+
consumeError(std::move(Err));
22642174
return OFFLOAD_FAIL;
22652175
}
22662176

offload/plugins-nextgen/host/src/rtl.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,6 @@ struct GenELF64DeviceTy : public GenericDeviceTy {
387387
}
388388

389389
/// This plugin should not setup the device environment or memory pool.
390-
virtual bool shouldSetupDeviceEnvironment() const override { return false; };
391390
virtual bool shouldSetupDeviceMemoryPool() const override { return false; };
392391

393392
/// Getters and setters for stack size and heap size not relevant.

0 commit comments

Comments
 (0)