Skip to content

Commit 1343e2e

Browse files
committed
[Offload] Lazily initialize platforms in the Offloading API
Summary: The Offloading library wraps around the underlying plugins. The problem is that we currently initialize all plugins we find, even if they are not needed for the program. This is very expensive for trivial uses, as fully heterogeneous usage is quite rare. In practice this means that you will always pay a 200 ms penalty for having CUDA installed. This patch changes the behavior to provide accessors into the plugins and devices that allow them to be initialized lazily. We use a once_flag; this should properly take a fast-path check while still blocking on concurrent use. Making full use of this will require a way to filter platforms more specifically. I'm still thinking about what this would look like as an API: either we add an extra iterate function that takes a callback on the platform, or we just provide a helper to find all the devices that can run a given image. Maybe both? Fixes: #159636
1 parent 8a27b48 commit 1343e2e

File tree

2 files changed

+88
-39
lines changed

2 files changed

+88
-39
lines changed

offload/liboffload/API/Common.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,10 @@ def ol_dimensions_t : Struct {
140140
}
141141

142142
def olInit : Function {
143-
let desc = "Perform initialization of the Offload library and plugins";
143+
let desc = "Perform initialization of the Offload library";
144144
let details = [
145145
"This must be the first API call made by a user of the Offload library",
146+
"The underlying platforms are lazily initialized on their first use"
146147
"Each call will increment an internal reference count that is decremented by `olShutDown`"
147148
];
148149
let params = [];

offload/liboffload/src/OffloadImpl.cpp

Lines changed: 86 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -42,17 +42,40 @@ using namespace error;
4242
struct ol_platform_impl_t {
4343
ol_platform_impl_t(std::unique_ptr<GenericPluginTy> Plugin,
4444
ol_platform_backend_t BackendType)
45-
: Plugin(std::move(Plugin)), BackendType(BackendType) {}
46-
std::unique_ptr<GenericPluginTy> Plugin;
47-
llvm::SmallVector<std::unique_ptr<ol_device_impl_t>> Devices;
45+
: BackendType(BackendType), Plugin(std::move(Plugin)) {}
4846
ol_platform_backend_t BackendType;
4947

48+
/// Get the plugin, lazily initializing it if necessary.
49+
llvm::Expected<GenericPluginTy *> getPlugin() {
50+
if (llvm::Error Err = init())
51+
return Err;
52+
return Plugin.get();
53+
}
54+
55+
/// Get the device list, lazily initializing it if necessary.
56+
llvm::Expected<llvm::SmallVector<std::unique_ptr<ol_device_impl_t>> &>
57+
getDevices() {
58+
if (llvm::Error Err = init())
59+
return Err;
60+
return Devices;
61+
}
62+
5063
/// Complete all pending work for this platform and perform any needed
5164
/// cleanup.
5265
///
5366
/// After calling this function, no liboffload functions should be called with
5467
/// this platform handle.
5568
llvm::Error destroy();
69+
70+
/// Initialize the associated plugin and devices.
71+
llvm::Error init();
72+
73+
/// Direct access to the plugin, may be uninitialized if accessed here.
74+
std::unique_ptr<GenericPluginTy> Plugin;
75+
76+
private:
77+
std::once_flag Initialized;
78+
llvm::SmallVector<std::unique_ptr<ol_device_impl_t>> Devices;
5679
};
5780

5881
// Handle type definitions. Ideally these would be 1:1 with the plugins, but
@@ -130,6 +153,39 @@ llvm::Error ol_platform_impl_t::destroy() {
130153
return Result;
131154
}
132155

156+
llvm::Error ol_platform_impl_t::init() {
157+
std::unique_ptr<llvm::Error> Storage;
158+
159+
// This can be called concurrently, make sure we only do the actual
160+
// initialization once.
161+
std::call_once(Initialized, [&]() {
162+
// FIXME: Need better handling for the host platform.
163+
if (!Plugin)
164+
return;
165+
166+
llvm::Error Err = Plugin->init();
167+
if (Err) {
168+
Storage = std::make_unique<llvm::Error>(std::move(Err));
169+
return;
170+
}
171+
172+
for (auto DevNum = 0; DevNum < Plugin->number_of_devices(); DevNum++) {
173+
if (Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) {
174+
auto Device = &Plugin->getDevice(DevNum);
175+
auto Info = Device->obtainInfoImpl();
176+
if (llvm::Error Err = Info.takeError()) {
177+
Storage = std::make_unique<llvm::Error>(std::move(Err));
178+
return;
179+
}
180+
Devices.emplace_back(std::make_unique<ol_device_impl_t>(
181+
DevNum, Device, *this, std::move(*Info)));
182+
}
183+
}
184+
});
185+
186+
return Storage ? std::move(*Storage) : llvm::Error::success();
187+
}
188+
133189
struct ol_queue_impl_t {
134190
ol_queue_impl_t(__tgt_async_info *AsyncInfo, ol_device_handle_t Device)
135191
: AsyncInfo(AsyncInfo), Device(Device), Id(IdCounter++) {}
@@ -209,13 +265,9 @@ struct OffloadContext {
209265
// key in AllocInfoMap
210266
llvm::SmallVector<void *> AllocBases{};
211267
SmallVector<std::unique_ptr<ol_platform_impl_t>, 4> Platforms{};
268+
ol_device_handle_t HostDevice;
212269
size_t RefCount;
213270

214-
ol_device_handle_t HostDevice() {
215-
// The host platform is always inserted last
216-
return Platforms.back()->Devices[0].get();
217-
}
218-
219271
static OffloadContext &get() {
220272
assert(OffloadContextVal);
221273
return *OffloadContextVal;
@@ -259,28 +311,16 @@ Error initPlugins(OffloadContext &Context) {
259311
} while (false);
260312
#include "Shared/Targets.def"
261313

262-
// Preemptively initialize all devices in the plugin
263-
for (auto &Platform : Context.Platforms) {
264-
auto Err = Platform->Plugin->init();
265-
[[maybe_unused]] std::string InfoMsg = toString(std::move(Err));
266-
for (auto DevNum = 0; DevNum < Platform->Plugin->number_of_devices();
267-
DevNum++) {
268-
if (Platform->Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) {
269-
auto Device = &Platform->Plugin->getDevice(DevNum);
270-
auto Info = Device->obtainInfoImpl();
271-
if (auto Err = Info.takeError())
272-
return Err;
273-
Platform->Devices.emplace_back(std::make_unique<ol_device_impl_t>(
274-
DevNum, Device, *Platform, std::move(*Info)));
275-
}
276-
}
277-
}
278-
279314
// Add the special host device
280315
auto &HostPlatform = Context.Platforms.emplace_back(
281316
std::make_unique<ol_platform_impl_t>(nullptr, OL_PLATFORM_BACKEND_HOST));
282-
HostPlatform->Devices.emplace_back(std::make_unique<ol_device_impl_t>(
283-
-1, nullptr, *HostPlatform, InfoTreeNode{}));
317+
auto DevicesOrErr = HostPlatform->getDevices();
318+
if (!DevicesOrErr)
319+
return DevicesOrErr.takeError();
320+
Context.HostDevice = DevicesOrErr
321+
->emplace_back(std::make_unique<ol_device_impl_t>(
322+
-1, nullptr, *HostPlatform, InfoTreeNode{}))
323+
.get();
284324

285325
Context.TracingEnabled = std::getenv("OFFLOAD_TRACE");
286326
Context.ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION");
@@ -315,12 +355,12 @@ Error olShutDown_impl() {
315355
llvm::Error Result = Error::success();
316356
auto *OldContext = OffloadContextVal.exchange(nullptr);
317357

318-
for (auto &P : OldContext->Platforms) {
358+
for (auto &Platform : OldContext->Platforms) {
319359
// Host plugin is nullptr and has no deinit
320-
if (!P->Plugin || !P->Plugin->is_initialized())
360+
if (!Platform->Plugin || !Platform->Plugin->is_initialized())
321361
continue;
322362

323-
if (auto Res = P->destroy())
363+
if (auto Res = Platform->destroy())
324364
Result = llvm::joinErrors(std::move(Result), std::move(Res));
325365
}
326366

@@ -334,9 +374,14 @@ Error olGetPlatformInfoImplDetail(ol_platform_handle_t Platform,
334374
InfoWriter Info(PropSize, PropValue, PropSizeRet);
335375
bool IsHost = Platform->BackendType == OL_PLATFORM_BACKEND_HOST;
336376

377+
auto PluginOrErr = Platform->getPlugin();
378+
if (!PluginOrErr)
379+
return PluginOrErr.takeError();
380+
GenericPluginTy *Plugin = *PluginOrErr;
381+
337382
switch (PropName) {
338383
case OL_PLATFORM_INFO_NAME:
339-
return Info.writeString(IsHost ? "Host" : Platform->Plugin->getName());
384+
return Info.writeString(IsHost ? "Host" : Plugin->getName());
340385
case OL_PLATFORM_INFO_VENDOR_NAME:
341386
// TODO: Implement this
342387
return Info.writeString("Unknown platform vendor");
@@ -373,7 +418,7 @@ Error olGetPlatformInfoSize_impl(ol_platform_handle_t Platform,
373418
Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
374419
ol_device_info_t PropName, size_t PropSize,
375420
void *PropValue, size_t *PropSizeRet) {
376-
assert(Device != OffloadContext::get().HostDevice());
421+
assert(Device != OffloadContext::get().HostDevice);
377422
InfoWriter Info(PropSize, PropValue, PropSizeRet);
378423

379424
auto makeError = [&](ErrorCode Code, StringRef Err) {
@@ -511,7 +556,7 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
511556
Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device,
512557
ol_device_info_t PropName, size_t PropSize,
513558
void *PropValue, size_t *PropSizeRet) {
514-
assert(Device == OffloadContext::get().HostDevice());
559+
assert(Device == OffloadContext::get().HostDevice);
515560
InfoWriter Info(PropSize, PropValue, PropSizeRet);
516561

517562
constexpr auto uint32_max = std::numeric_limits<uint32_t>::max();
@@ -579,7 +624,7 @@ Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device,
579624

580625
Error olGetDeviceInfo_impl(ol_device_handle_t Device, ol_device_info_t PropName,
581626
size_t PropSize, void *PropValue) {
582-
if (Device == OffloadContext::get().HostDevice())
627+
if (Device == OffloadContext::get().HostDevice)
583628
return olGetDeviceInfoImplDetailHost(Device, PropName, PropSize, PropValue,
584629
nullptr);
585630
return olGetDeviceInfoImplDetail(Device, PropName, PropSize, PropValue,
@@ -588,17 +633,20 @@ Error olGetDeviceInfo_impl(ol_device_handle_t Device, ol_device_info_t PropName,
588633

589634
Error olGetDeviceInfoSize_impl(ol_device_handle_t Device,
590635
ol_device_info_t PropName, size_t *PropSizeRet) {
591-
if (Device == OffloadContext::get().HostDevice())
636+
if (Device == OffloadContext::get().HostDevice)
592637
return olGetDeviceInfoImplDetailHost(Device, PropName, 0, nullptr,
593638
PropSizeRet);
594639
return olGetDeviceInfoImplDetail(Device, PropName, 0, nullptr, PropSizeRet);
595640
}
596641

597642
Error olIterateDevices_impl(ol_device_iterate_cb_t Callback, void *UserData) {
598643
for (auto &Platform : OffloadContext::get().Platforms) {
599-
for (auto &Device : Platform->Devices) {
644+
auto DevicesOrErr = Platform->getDevices();
645+
if (!DevicesOrErr)
646+
return DevicesOrErr.takeError();
647+
for (auto &Device : *DevicesOrErr) {
600648
if (!Callback(Device.get(), UserData)) {
601-
break;
649+
return Error::success();
602650
}
603651
}
604652
}
@@ -949,7 +997,7 @@ Error olCreateEvent_impl(ol_queue_handle_t Queue, ol_event_handle_t *EventOut) {
949997
Error olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr,
950998
ol_device_handle_t DstDevice, const void *SrcPtr,
951999
ol_device_handle_t SrcDevice, size_t Size) {
952-
auto Host = OffloadContext::get().HostDevice();
1000+
auto Host = OffloadContext::get().HostDevice;
9531001
if (DstDevice == Host && SrcDevice == Host) {
9541002
if (!Queue) {
9551003
std::memcpy(DstPtr, SrcPtr, Size);

0 commit comments

Comments
 (0)