diff --git a/offload/liboffload/API/Common.td b/offload/liboffload/API/Common.td index ac27d85b6c964..b47223612479a 100644 --- a/offload/liboffload/API/Common.td +++ b/offload/liboffload/API/Common.td @@ -140,9 +140,10 @@ def ol_dimensions_t : Struct { } def olInit : Function { - let desc = "Perform initialization of the Offload library and plugins"; + let desc = "Perform initialization of the Offload library"; let details = [ "This must be the first API call made by a user of the Offload library", + "The underlying platforms are lazily initialized on their first use" "Each call will increment an internal reference count that is decremented by `olShutDown`" ]; let params = []; diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp index c549ae04361d0..e5fd68ee3d668 100644 --- a/offload/liboffload/src/OffloadImpl.cpp +++ b/offload/liboffload/src/OffloadImpl.cpp @@ -42,17 +42,40 @@ using namespace error; struct ol_platform_impl_t { ol_platform_impl_t(std::unique_ptr Plugin, ol_platform_backend_t BackendType) - : Plugin(std::move(Plugin)), BackendType(BackendType) {} - std::unique_ptr Plugin; - llvm::SmallVector> Devices; + : BackendType(BackendType), Plugin(std::move(Plugin)) {} ol_platform_backend_t BackendType; + /// Get the plugin, lazily initializing it if necessary. + llvm::Expected getPlugin() { + if (llvm::Error Err = init()) + return Err; + return Plugin.get(); + } + + /// Get the device list, lazily initializing it if necessary. + llvm::Expected> &> + getDevices() { + if (llvm::Error Err = init()) + return Err; + return Devices; + } + /// Complete all pending work for this platform and perform any needed /// cleanup. /// /// After calling this function, no liboffload functions should be called with /// this platform handle. llvm::Error destroy(); + + /// Initialize the associated plugin and devices. + llvm::Error init(); + + /// Direct access to the plugin, may be uninitialized if accessed here. + std::unique_ptr Plugin; + +private: + std::once_flag Initialized; + llvm::SmallVector> Devices; }; // Handle type definitions. Ideally these would be 1:1 with the plugins, but @@ -130,6 +153,42 @@ llvm::Error ol_platform_impl_t::destroy() { return Result; } +llvm::Error ol_platform_impl_t::init() { + std::unique_ptr Storage; + + // This can be called concurrently, make sure we only do the actual + // initialization once. + std::call_once(Initialized, [&]() { + // FIXME: Need better handling for the host platform. + if (!Plugin) + return; + + llvm::Error Err = Plugin->init(); + if (Err) { + Storage = std::make_unique(std::move(Err)); + return; + } + + for (auto Id = 0, End = Plugin->getNumDevices(); Id != End; Id++) { + if (llvm::Error Err = Plugin->initDevice(Id)) { + Storage = std::make_unique(std::move(Err)); + return; + } + + auto Device = &Plugin->getDevice(Id); + auto Info = Device->obtainInfoImpl(); + if (llvm::Error Err = Info.takeError()) { + Storage = std::make_unique(std::move(Err)); + return; + } + Devices.emplace_back(std::make_unique( + Id, Device, *this, std::move(*Info))); + } + }); + + return Storage ? std::move(*Storage) : llvm::Error::success(); +} + struct ol_queue_impl_t { ol_queue_impl_t(__tgt_async_info *AsyncInfo, ol_device_handle_t Device) : AsyncInfo(AsyncInfo), Device(Device), Id(IdCounter++) {} @@ -209,13 +268,9 @@ struct OffloadContext { // key in AllocInfoMap llvm::SmallVector AllocBases{}; SmallVector, 4> Platforms{}; + ol_device_handle_t HostDevice; size_t RefCount; - ol_device_handle_t HostDevice() { - // The host platform is always inserted last - return Platforms.back()->Devices[0].get(); - } - static OffloadContext &get() { assert(OffloadContextVal); return *OffloadContextVal; @@ -259,28 +314,16 @@ Error initPlugins(OffloadContext &Context) { } while (false); #include "Shared/Targets.def" - // Preemptively initialize all devices in the plugin - for (auto &Platform : Context.Platforms) { - auto Err = Platform->Plugin->init(); - [[maybe_unused]] std::string InfoMsg = toString(std::move(Err)); - for (auto DevNum = 0; DevNum < Platform->Plugin->number_of_devices(); - DevNum++) { - if (Platform->Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) { - auto Device = &Platform->Plugin->getDevice(DevNum); - auto Info = Device->obtainInfoImpl(); - if (auto Err = Info.takeError()) - return Err; - Platform->Devices.emplace_back(std::make_unique( - DevNum, Device, *Platform, std::move(*Info))); - } - } - } - // Add the special host device auto &HostPlatform = Context.Platforms.emplace_back( std::make_unique(nullptr, OL_PLATFORM_BACKEND_HOST)); - HostPlatform->Devices.emplace_back(std::make_unique( - -1, nullptr, *HostPlatform, InfoTreeNode{})); + auto DevicesOrErr = HostPlatform->getDevices(); + if (!DevicesOrErr) + return DevicesOrErr.takeError(); + Context.HostDevice = DevicesOrErr + ->emplace_back(std::make_unique( + -1, nullptr, *HostPlatform, InfoTreeNode{})) + .get(); Context.TracingEnabled = std::getenv("OFFLOAD_TRACE"); Context.ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION"); @@ -315,12 +358,12 @@ Error olShutDown_impl() { llvm::Error Result = Error::success(); auto *OldContext = OffloadContextVal.exchange(nullptr); - for (auto &P : OldContext->Platforms) { + for (auto &Platform : OldContext->Platforms) { // Host plugin is nullptr and has no deinit - if (!P->Plugin || !P->Plugin->is_initialized()) + if (!Platform->Plugin || !Platform->Plugin->is_initialized()) continue; - if (auto Res = P->destroy()) + if (auto Res = Platform->destroy()) Result = llvm::joinErrors(std::move(Result), std::move(Res)); } @@ -334,6 +377,8 @@ Error olGetPlatformInfoImplDetail(ol_platform_handle_t Platform, InfoWriter Info(PropSize, PropValue, PropSizeRet); bool IsHost = Platform->BackendType == OL_PLATFORM_BACKEND_HOST; + // Note that the plugin is potentially uninitialized here. It will need to be + // initialized once info is added that requires it to be initialized. switch (PropName) { case OL_PLATFORM_INFO_NAME: return Info.writeString(IsHost ? "Host" : Platform->Plugin->getName()); @@ -373,7 +418,7 @@ Error olGetPlatformInfoSize_impl(ol_platform_handle_t Platform, Error olGetDeviceInfoImplDetail(ol_device_handle_t Device, ol_device_info_t PropName, size_t PropSize, void *PropValue, size_t *PropSizeRet) { - assert(Device != OffloadContext::get().HostDevice()); + assert(Device != OffloadContext::get().HostDevice); InfoWriter Info(PropSize, PropValue, PropSizeRet); auto makeError = [&](ErrorCode Code, StringRef Err) { @@ -511,7 +556,7 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device, Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device, ol_device_info_t PropName, size_t PropSize, void *PropValue, size_t *PropSizeRet) { - assert(Device == OffloadContext::get().HostDevice()); + assert(Device == OffloadContext::get().HostDevice); InfoWriter Info(PropSize, PropValue, PropSizeRet); constexpr auto uint32_max = std::numeric_limits::max(); @@ -579,7 +624,7 @@ Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device, Error olGetDeviceInfo_impl(ol_device_handle_t Device, ol_device_info_t PropName, size_t PropSize, void *PropValue) { - if (Device == OffloadContext::get().HostDevice()) + if (Device == OffloadContext::get().HostDevice) return olGetDeviceInfoImplDetailHost(Device, PropName, PropSize, PropValue, nullptr); return olGetDeviceInfoImplDetail(Device, PropName, PropSize, PropValue, @@ -588,7 +633,7 @@ Error olGetDeviceInfo_impl(ol_device_handle_t Device, ol_device_info_t PropName, Error olGetDeviceInfoSize_impl(ol_device_handle_t Device, ol_device_info_t PropName, size_t *PropSizeRet) { - if (Device == OffloadContext::get().HostDevice()) + if (Device == OffloadContext::get().HostDevice) return olGetDeviceInfoImplDetailHost(Device, PropName, 0, nullptr, PropSizeRet); return olGetDeviceInfoImplDetail(Device, PropName, 0, nullptr, PropSizeRet); @@ -596,9 +641,12 @@ Error olGetDeviceInfoSize_impl(ol_device_handle_t Device, Error olIterateDevices_impl(ol_device_iterate_cb_t Callback, void *UserData) { for (auto &Platform : OffloadContext::get().Platforms) { - for (auto &Device : Platform->Devices) { + auto DevicesOrErr = Platform->getDevices(); + if (!DevicesOrErr) + return DevicesOrErr.takeError(); + for (auto &Device : *DevicesOrErr) { if (!Callback(Device.get(), UserData)) { - break; + return Error::success(); } } } @@ -949,7 +997,7 @@ Error olCreateEvent_impl(ol_queue_handle_t Queue, ol_event_handle_t *EventOut) { Error olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr, ol_device_handle_t DstDevice, const void *SrcPtr, ol_device_handle_t SrcDevice, size_t Size) { - auto Host = OffloadContext::get().HostDevice(); + auto Host = OffloadContext::get().HostDevice; if (DstDevice == Host && SrcDevice == Host) { if (!Queue) { std::memcpy(DstPtr, SrcPtr, Size);