Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion offload/liboffload/API/Common.td
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,10 @@ def ol_dimensions_t : Struct {
}

def olInit : Function {
let desc = "Perform initialization of the Offload library and plugins";
let desc = "Perform initialization of the Offload library";
let details = [
"This must be the first API call made by a user of the Offload library",
"The underlying platforms are lazily initialized on their first use",
"Each call will increment an internal reference count that is decremented by `olShutDown`"
];
let params = [];
Expand Down
122 changes: 85 additions & 37 deletions offload/liboffload/src/OffloadImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,40 @@ using namespace error;
/// Implementation object behind a platform handle: pairs one plugin with its
/// backend type and the device objects created for it.
///
/// getPlugin() and getDevices() both call init() before handing anything out.
/// The public `Plugin` member bypasses that step, so it may be observed before
/// initialization has run.
struct ol_platform_impl_t {
ol_platform_impl_t(std::unique_ptr<GenericPluginTy> Plugin,
ol_platform_backend_t BackendType)
: Plugin(std::move(Plugin)), BackendType(BackendType) {}
std::unique_ptr<GenericPluginTy> Plugin;
llvm::SmallVector<std::unique_ptr<ol_device_impl_t>> Devices;
: BackendType(BackendType), Plugin(std::move(Plugin)) {}
ol_platform_backend_t BackendType;

/// Get the plugin, lazily initializing it if necessary.
/// Returns the error from init() instead of the pointer when init() fails.
llvm::Expected<GenericPluginTy *> getPlugin() {
if (llvm::Error Err = init())
return Err;
return Plugin.get();
}

/// Get the device list, lazily initializing it if necessary.
/// The list is returned by reference, so callers can append to it.
llvm::Expected<llvm::SmallVector<std::unique_ptr<ol_device_impl_t>> &>
getDevices() {
if (llvm::Error Err = init())
return Err;
return Devices;
}

/// Complete all pending work for this platform and perform any needed
/// cleanup.
///
/// After calling this function, no liboffload functions should be called with
/// this platform handle.
llvm::Error destroy();

/// Initialize the associated plugin and devices.
llvm::Error init();

/// Direct access to the plugin, may be uninitialized if accessed here.
std::unique_ptr<GenericPluginTy> Plugin;

private:
// Handed to std::call_once inside init() so the initialization body runs a
// single time.
std::once_flag Initialized;
// Populated by init(); reach it through getDevices().
llvm::SmallVector<std::unique_ptr<ol_device_impl_t>> Devices;
};

// Handle type definitions. Ideally these would be 1:1 with the plugins, but
Expand Down Expand Up @@ -130,6 +153,42 @@ llvm::Error ol_platform_impl_t::destroy() {
return Result;
}

/// Initialize the plugin and build this platform's device list.
///
/// The work is wrapped in std::call_once on `Initialized`, so the body runs a
/// single time however many callers arrive. When `Plugin` is null the body
/// returns immediately and success is reported (see the FIXME about the host
/// platform below).
///
/// \returns the first error produced by Plugin->init(), Plugin->initDevice()
/// or obtainInfoImpl(); otherwise success.
///
/// NOTE(review): `Storage` is local to each call and is only written inside
/// the call_once body. If that body returns normally after recording an error,
/// later calls appear to skip the body and return success even though
/// initialization did not complete. Confirm whether the failure should be
/// remembered on the object instead.
llvm::Error ol_platform_impl_t::init() {
// Stays null unless the call_once body records a failure; a nullable holder
// is used so that no llvm::Error object exists on the success path.
std::unique_ptr<llvm::Error> Storage;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it needed to use dynamic memory?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I'm aware, if you pass in a pointer to local you'd get errors for that local not being handled. So you need something nullable. It will only dynamically allocate in the case of an error, so I don't think performance is an issue there.


// This can be called concurrently, make sure we only do the actual
// initialization once.
std::call_once(Initialized, [&]() {
// FIXME: Need better handling for the host platform.
if (!Plugin)
return;

llvm::Error Err = Plugin->init();
if (Err) {
Storage = std::make_unique<llvm::Error>(std::move(Err));
return;
}

for (auto Id = 0, End = Plugin->getNumDevices(); Id != End; Id++) {
if (llvm::Error Err = Plugin->initDevice(Id)) {
Storage = std::make_unique<llvm::Error>(std::move(Err));
return;
}

auto Device = &Plugin->getDevice(Id);
auto Info = Device->obtainInfoImpl();
if (llvm::Error Err = Info.takeError()) {
Storage = std::make_unique<llvm::Error>(std::move(Err));
return;
}
Devices.emplace_back(std::make_unique<ol_device_impl_t>(
Id, Device, *this, std::move(*Info)));
}
});

return Storage ? std::move(*Storage) : llvm::Error::success();
}

struct ol_queue_impl_t {
ol_queue_impl_t(__tgt_async_info *AsyncInfo, ol_device_handle_t Device)
: AsyncInfo(AsyncInfo), Device(Device), Id(IdCounter++) {}
Expand Down Expand Up @@ -209,13 +268,9 @@ struct OffloadContext {
// key in AllocInfoMap
llvm::SmallVector<void *> AllocBases{};
SmallVector<std::unique_ptr<ol_platform_impl_t>, 4> Platforms{};
ol_device_handle_t HostDevice;
size_t RefCount;

ol_device_handle_t HostDevice() {
// The host platform is always inserted last
return Platforms.back()->Devices[0].get();
}

static OffloadContext &get() {
assert(OffloadContextVal);
return *OffloadContextVal;
Expand Down Expand Up @@ -259,28 +314,16 @@ Error initPlugins(OffloadContext &Context) {
} while (false);
#include "Shared/Targets.def"

// Preemptively initialize all devices in the plugin
for (auto &Platform : Context.Platforms) {
auto Err = Platform->Plugin->init();
[[maybe_unused]] std::string InfoMsg = toString(std::move(Err));
for (auto DevNum = 0; DevNum < Platform->Plugin->number_of_devices();
DevNum++) {
if (Platform->Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) {
auto Device = &Platform->Plugin->getDevice(DevNum);
auto Info = Device->obtainInfoImpl();
if (auto Err = Info.takeError())
return Err;
Platform->Devices.emplace_back(std::make_unique<ol_device_impl_t>(
DevNum, Device, *Platform, std::move(*Info)));
}
}
}

// Add the special host device
auto &HostPlatform = Context.Platforms.emplace_back(
std::make_unique<ol_platform_impl_t>(nullptr, OL_PLATFORM_BACKEND_HOST));
HostPlatform->Devices.emplace_back(std::make_unique<ol_device_impl_t>(
-1, nullptr, *HostPlatform, InfoTreeNode{}));
auto DevicesOrErr = HostPlatform->getDevices();
if (!DevicesOrErr)
return DevicesOrErr.takeError();
Context.HostDevice = DevicesOrErr
->emplace_back(std::make_unique<ol_device_impl_t>(
-1, nullptr, *HostPlatform, InfoTreeNode{}))
.get();

Context.TracingEnabled = std::getenv("OFFLOAD_TRACE");
Context.ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION");
Expand Down Expand Up @@ -315,12 +358,12 @@ Error olShutDown_impl() {
llvm::Error Result = Error::success();
auto *OldContext = OffloadContextVal.exchange(nullptr);

for (auto &P : OldContext->Platforms) {
for (auto &Platform : OldContext->Platforms) {
// Host plugin is nullptr and has no deinit
if (!P->Plugin || !P->Plugin->is_initialized())
if (!Platform->Plugin || !Platform->Plugin->is_initialized())
continue;

if (auto Res = P->destroy())
if (auto Res = Platform->destroy())
Result = llvm::joinErrors(std::move(Result), std::move(Res));
}

Expand All @@ -334,6 +377,8 @@ Error olGetPlatformInfoImplDetail(ol_platform_handle_t Platform,
InfoWriter Info(PropSize, PropValue, PropSizeRet);
bool IsHost = Platform->BackendType == OL_PLATFORM_BACKEND_HOST;

// Note that the plugin is potentially uninitialized here. It will need to be
// initialized once info is added that requires it to be initialized.
switch (PropName) {
case OL_PLATFORM_INFO_NAME:
return Info.writeString(IsHost ? "Host" : Platform->Plugin->getName());
Expand Down Expand Up @@ -373,7 +418,7 @@ Error olGetPlatformInfoSize_impl(ol_platform_handle_t Platform,
Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
ol_device_info_t PropName, size_t PropSize,
void *PropValue, size_t *PropSizeRet) {
assert(Device != OffloadContext::get().HostDevice());
assert(Device != OffloadContext::get().HostDevice);
InfoWriter Info(PropSize, PropValue, PropSizeRet);

auto makeError = [&](ErrorCode Code, StringRef Err) {
Expand Down Expand Up @@ -511,7 +556,7 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device,
ol_device_info_t PropName, size_t PropSize,
void *PropValue, size_t *PropSizeRet) {
assert(Device == OffloadContext::get().HostDevice());
assert(Device == OffloadContext::get().HostDevice);
InfoWriter Info(PropSize, PropValue, PropSizeRet);

constexpr auto uint32_max = std::numeric_limits<uint32_t>::max();
Expand Down Expand Up @@ -579,7 +624,7 @@ Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device,

Error olGetDeviceInfo_impl(ol_device_handle_t Device, ol_device_info_t PropName,
size_t PropSize, void *PropValue) {
if (Device == OffloadContext::get().HostDevice())
if (Device == OffloadContext::get().HostDevice)
return olGetDeviceInfoImplDetailHost(Device, PropName, PropSize, PropValue,
nullptr);
return olGetDeviceInfoImplDetail(Device, PropName, PropSize, PropValue,
Expand All @@ -588,17 +633,20 @@ Error olGetDeviceInfo_impl(ol_device_handle_t Device, ol_device_info_t PropName,

/// Size-only variant of the device info query.
///
/// Forwards to a detail helper with a zero PropSize and a null PropValue, so
/// PropSizeRet is the only output supplied. A device equal to the context's
/// host device is routed to olGetDeviceInfoImplDetailHost; every other device
/// goes to olGetDeviceInfoImplDetail.
///
/// \returns whatever Error the selected helper returns.
Error olGetDeviceInfoSize_impl(ol_device_handle_t Device,
ol_device_info_t PropName, size_t *PropSizeRet) {
if (Device == OffloadContext::get().HostDevice())
if (Device == OffloadContext::get().HostDevice)
return olGetDeviceInfoImplDetailHost(Device, PropName, 0, nullptr,
PropSizeRet);
return olGetDeviceInfoImplDetail(Device, PropName, 0, nullptr, PropSizeRet);
}

Error olIterateDevices_impl(ol_device_iterate_cb_t Callback, void *UserData) {
for (auto &Platform : OffloadContext::get().Platforms) {
for (auto &Device : Platform->Devices) {
auto DevicesOrErr = Platform->getDevices();
if (!DevicesOrErr)
return DevicesOrErr.takeError();
for (auto &Device : *DevicesOrErr) {
if (!Callback(Device.get(), UserData)) {
break;
return Error::success();
}
}
}
Expand Down Expand Up @@ -949,7 +997,7 @@ Error olCreateEvent_impl(ol_queue_handle_t Queue, ol_event_handle_t *EventOut) {
Error olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr,
ol_device_handle_t DstDevice, const void *SrcPtr,
ol_device_handle_t SrcDevice, size_t Size) {
auto Host = OffloadContext::get().HostDevice();
auto Host = OffloadContext::get().HostDevice;
if (DstDevice == Host && SrcDevice == Host) {
if (!Queue) {
std::memcpy(DstPtr, SrcPtr, Size);
Expand Down
Loading