@@ -96,7 +96,10 @@ struct AllocInfo {
9696
9797// Global shared state for liboffload
9898struct OffloadContext ;
99- static OffloadContext *OffloadContextVal;
99+ // This pointer is non-null if and only if the context is valid and fully
100+ // initialized
101+ static std::atomic<OffloadContext *> OffloadContextVal;
102+ std::mutex OffloadContextValMutex;
100103struct OffloadContext {
101104 OffloadContext (OffloadContext &) = delete ;
102105 OffloadContext (OffloadContext &&) = delete ;
@@ -107,6 +110,7 @@ struct OffloadContext {
107110 bool ValidationEnabled = true ;
108111 DenseMap<void *, AllocInfo> AllocInfoMap{};
109112 SmallVector<ol_platform_impl_t , 4 > Platforms{};
113+ size_t RefCount;
110114
111115 ol_device_handle_t HostDevice () {
112116 // The host platform is always inserted last
@@ -145,20 +149,18 @@ constexpr ol_platform_backend_t pluginNameToBackend(StringRef Name) {
145149#define PLUGIN_TARGET (Name ) extern " C" GenericPluginTy *createPlugin_##Name();
146150#include " Shared/Targets.def"
147151
148- Error initPlugins () {
149- auto *Context = new OffloadContext{};
150-
152+ Error initPlugins (OffloadContext &Context) {
151153 // Attempt to create an instance of each supported plugin.
152154#define PLUGIN_TARGET (Name ) \
153155 do { \
154- Context-> Platforms .emplace_back (ol_platform_impl_t { \
156+ Context. Platforms .emplace_back (ol_platform_impl_t { \
155157 std::unique_ptr<GenericPluginTy>(createPlugin_##Name ()), \
156158 pluginNameToBackend (#Name)}); \
157159 } while (false );
158160#include " Shared/Targets.def"
159161
160162 // Preemptively initialize all devices in the plugin
161- for (auto &Platform : Context-> Platforms ) {
163+ for (auto &Platform : Context. Platforms ) {
162164 // Do not use the host plugin - it isn't supported.
163165 if (Platform.BackendType == OL_PLATFORM_BACKEND_UNKNOWN)
164166 continue ;
@@ -178,31 +180,56 @@ Error initPlugins() {
178180 }
179181
180182 // Add the special host device
181- auto &HostPlatform = Context-> Platforms .emplace_back (
183+ auto &HostPlatform = Context. Platforms .emplace_back (
182184 ol_platform_impl_t {nullptr , OL_PLATFORM_BACKEND_HOST});
183185 HostPlatform.Devices .emplace_back (-1 , nullptr , nullptr , InfoTreeNode{});
184- Context->HostDevice ()->Platform = &HostPlatform;
185-
186- Context->TracingEnabled = std::getenv (" OFFLOAD_TRACE" );
187- Context->ValidationEnabled = !std::getenv (" OFFLOAD_DISABLE_VALIDATION" );
186+ Context.HostDevice ()->Platform = &HostPlatform;
188187
189- OffloadContextVal = Context;
188+ Context.TracingEnabled = std::getenv (" OFFLOAD_TRACE" );
189+ Context.ValidationEnabled = !std::getenv (" OFFLOAD_DISABLE_VALIDATION" );
190190
191191 return Plugin::success ();
192192}
193193
194- // TODO: We can properly reference count here and manage the resources in a more
195- // clever way
196194Error olInit_impl () {
197- static std::once_flag InitFlag;
198- std::optional<Error> InitResult{};
199- std::call_once (InitFlag, [&] { InitResult = initPlugins (); });
195+ std::lock_guard<std::mutex> Lock{OffloadContextValMutex};
200196
201- if (InitResult)
202- return std::move (*InitResult);
203- return Error::success ();
197+ if (isOffloadInitialized ()) {
198+ OffloadContext::get ().RefCount ++;
199+ return Plugin::success ();
200+ }
201+
202+ // Use a temporary to ensure that entry points querying OffloadContextVal do
203+ // not get a partially initialized context
204+ auto *NewContext = new OffloadContext{};
205+ Error InitResult = initPlugins (*NewContext);
206+ OffloadContextVal.store (NewContext);
207+ OffloadContext::get ().RefCount ++;
208+
209+ return InitResult;
210+ }
211+
212+ Error olShutDown_impl () {
213+ std::lock_guard<std::mutex> Lock{OffloadContextValMutex};
214+
215+ if (--OffloadContext::get ().RefCount != 0 )
216+ return Error::success ();
217+
218+ llvm::Error Result = Error::success ();
219+ auto *OldContext = OffloadContextVal.exchange (nullptr );
220+
221+ for (auto &P : OldContext->Platforms ) {
222+ // Host plugin is nullptr and has no deinit
223+ if (!P.Plugin )
224+ continue ;
225+
226+ if (auto Res = P.Plugin ->deinit ())
227+ Result = llvm::joinErrors (std::move (Result), std::move (Res));
228+ }
229+
230+ delete OldContext;
231+ return Result;
204232}
205- Error olShutDown_impl () { return Error::success (); }
206233
207234Error olGetPlatformInfoImplDetail (ol_platform_handle_t Platform,
208235 ol_platform_info_t PropName, size_t PropSize,
0 commit comments