diff --git a/offload/liboffload/API/Common.td b/offload/liboffload/API/Common.td index 79c3bd46f1984..669dfd3cca7c6 100644 --- a/offload/liboffload/API/Common.td +++ b/offload/liboffload/API/Common.td @@ -176,7 +176,7 @@ def : Function { let desc = "Release the resources in use by Offload"; let details = [ "This decrements an internal reference count. When this reaches 0, all resources will be released", - "Subsequent API calls made after this are not valid" + "Subsequent API calls to methods other than `olInit` made after resources are released will return OL_ERRC_UNINITIALIZED" ]; let params = []; let returns = []; diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp index 6adebb25a2db0..9d4f4f54a8217 100644 --- a/offload/liboffload/src/OffloadImpl.cpp +++ b/offload/liboffload/src/OffloadImpl.cpp @@ -96,7 +96,10 @@ struct AllocInfo { // Global shared state for liboffload struct OffloadContext; -static OffloadContext *OffloadContextVal; +// This pointer is non-null if and only if the context is valid and fully +// initialized +static std::atomic OffloadContextVal; +std::mutex OffloadContextValMutex; struct OffloadContext { OffloadContext(OffloadContext &) = delete; OffloadContext(OffloadContext &&) = delete; @@ -107,6 +110,7 @@ struct OffloadContext { bool ValidationEnabled = true; DenseMap AllocInfoMap{}; SmallVector Platforms{}; + size_t RefCount; ol_device_handle_t HostDevice() { // The host platform is always inserted last @@ -145,20 +149,18 @@ constexpr ol_platform_backend_t pluginNameToBackend(StringRef Name) { #define PLUGIN_TARGET(Name) extern "C" GenericPluginTy *createPlugin_##Name(); #include "Shared/Targets.def" -Error initPlugins() { - auto *Context = new OffloadContext{}; - +Error initPlugins(OffloadContext &Context) { // Attempt to create an instance of each supported plugin. #define PLUGIN_TARGET(Name) \ do { \ - Context->Platforms.emplace_back(ol_platform_impl_t{ \ + Context.Platforms.emplace_back(ol_platform_impl_t{ \ std::unique_ptr(createPlugin_##Name()), \ pluginNameToBackend(#Name)}); \ } while (false); #include "Shared/Targets.def" // Preemptively initialize all devices in the plugin - for (auto &Platform : Context->Platforms) { + for (auto &Platform : Context.Platforms) { // Do not use the host plugin - it isn't supported. if (Platform.BackendType == OL_PLATFORM_BACKEND_UNKNOWN) continue; @@ -178,31 +180,56 @@ Error initPlugins() { } // Add the special host device - auto &HostPlatform = Context->Platforms.emplace_back( + auto &HostPlatform = Context.Platforms.emplace_back( ol_platform_impl_t{nullptr, OL_PLATFORM_BACKEND_HOST}); HostPlatform.Devices.emplace_back(-1, nullptr, nullptr, InfoTreeNode{}); - Context->HostDevice()->Platform = &HostPlatform; - - Context->TracingEnabled = std::getenv("OFFLOAD_TRACE"); - Context->ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION"); + Context.HostDevice()->Platform = &HostPlatform; - OffloadContextVal = Context; + Context.TracingEnabled = std::getenv("OFFLOAD_TRACE"); + Context.ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION"); return Plugin::success(); } -// TODO: We can properly reference count here and manage the resources in a more -// clever way Error olInit_impl() { - static std::once_flag InitFlag; - std::optional InitResult{}; - std::call_once(InitFlag, [&] { InitResult = initPlugins(); }); + std::lock_guard Lock{OffloadContextValMutex}; - if (InitResult) - return std::move(*InitResult); - return Error::success(); + if (isOffloadInitialized()) { + OffloadContext::get().RefCount++; + return Plugin::success(); + } + + // Use a temporary to ensure that entry points querying OffloadContextVal do + // not get a partially initialized context + auto *NewContext = new OffloadContext{}; + Error InitResult = initPlugins(*NewContext); + OffloadContextVal.store(NewContext); + OffloadContext::get().RefCount++; + + return InitResult; +} + +Error olShutDown_impl() { + std::lock_guard Lock{OffloadContextValMutex}; + + if (--OffloadContext::get().RefCount != 0) + return Error::success(); + + llvm::Error Result = Error::success(); + auto *OldContext = OffloadContextVal.exchange(nullptr); + + for (auto &P : OldContext->Platforms) { + // Host plugin is nullptr and has no deinit + if (!P.Plugin) + continue; + + if (auto Res = P.Plugin->deinit()) + Result = llvm::joinErrors(std::move(Result), std::move(Res)); + } + + delete OldContext; + return Result; } -Error olShutDown_impl() { return Error::success(); } Error olGetPlatformInfoImplDetail(ol_platform_handle_t Platform, ol_platform_info_t PropName, size_t PropSize, diff --git a/offload/unittests/OffloadAPI/init/olInit.cpp b/offload/unittests/OffloadAPI/init/olInit.cpp index 8e27e77cd0fb5..508615152b4f1 100644 --- a/offload/unittests/OffloadAPI/init/olInit.cpp +++ b/offload/unittests/OffloadAPI/init/olInit.cpp @@ -15,8 +15,20 @@ struct olInitTest : ::testing::Test {}; +TEST_F(olInitTest, Success) { + ASSERT_SUCCESS(olInit()); + ASSERT_SUCCESS(olShutDown()); +} + TEST_F(olInitTest, Uninitialized) { ASSERT_ERROR(OL_ERRC_UNINITIALIZED, olIterateDevices( [](ol_device_handle_t, void *) { return false; }, nullptr)); } + +TEST_F(olInitTest, RepeatedInit) { + for (size_t I = 0; I < 10; I++) { + ASSERT_SUCCESS(olInit()); + ASSERT_SUCCESS(olShutDown()); + } +}