1111#include " adapter.hpp"
1212#include " ur_level_zero.hpp"
1313
14+ // Due to multiple DLLMain definitions with SYCL, Global Adapter is init at
15+ // variable creation.
16+ #if defined(_WIN32)
17+ ur_adapter_handle_t_ *GlobalAdapter = new ur_adapter_handle_t_();
18+ #else
19+ ur_adapter_handle_t_ *GlobalAdapter;
20+ #endif
21+
1422ur_result_t initPlatforms (PlatformVec &platforms) noexcept try {
1523 uint32_t ZeDriverCount = 0 ;
1624 ZE2UR_CALL (zeDriverGet, (&ZeDriverCount, nullptr ));
@@ -37,8 +45,7 @@ ur_result_t initPlatforms(PlatformVec &platforms) noexcept try {
3745ur_result_t adapterStateInit () { return UR_RESULT_SUCCESS; }
3846
3947ur_adapter_handle_t_::ur_adapter_handle_t_ () {
40-
41- Adapter.PlatformCache .Compute = [](Result<PlatformVec> &result) {
48+ PlatformCache.Compute = [](Result<PlatformVec> &result) {
4249 static std::once_flag ZeCallCountInitialized;
4350 try {
4451 std::call_once (ZeCallCountInitialized, []() {
@@ -52,7 +59,7 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() {
5259 }
5360
5461 // initialize level zero only once.
55- if (Adapter. ZeResult == std::nullopt ) {
62+ if (GlobalAdapter-> ZeResult == std::nullopt ) {
5663 // Setting these environment variables before running zeInit will enable
5764 // the validation layer in the Level Zero loader.
5865 if (UrL0Debug & UR_L0_DEBUG_VALIDATION) {
@@ -71,20 +78,21 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() {
7178 // We must only initialize the driver once, even if urPlatformGet() is
7279 // called multiple times. Declaring the return value as "static" ensures
7380 // it's only called once.
74- Adapter.ZeResult = ZE_CALL_NOCHECK (zeInit, (ZE_INIT_FLAG_GPU_ONLY));
81+ GlobalAdapter->ZeResult =
82+ ZE_CALL_NOCHECK (zeInit, (ZE_INIT_FLAG_GPU_ONLY));
7583 }
76- assert (Adapter. ZeResult !=
84+ assert (GlobalAdapter-> ZeResult !=
7785 std::nullopt ); // verify that level-zero is initialized
7886 PlatformVec platforms;
7987
8088 // Absorb the ZE_RESULT_ERROR_UNINITIALIZED and just return 0 Platforms.
81- if (*Adapter. ZeResult == ZE_RESULT_ERROR_UNINITIALIZED) {
89+ if (*GlobalAdapter-> ZeResult == ZE_RESULT_ERROR_UNINITIALIZED) {
8290 result = std::move (platforms);
8391 return ;
8492 }
85- if (*Adapter. ZeResult != ZE_RESULT_SUCCESS) {
93+ if (*GlobalAdapter-> ZeResult != ZE_RESULT_SUCCESS) {
8694 urPrint (" zeInit: Level Zero initialization failure\n " );
87- result = ze2urResult (*Adapter. ZeResult );
95+ result = ze2urResult (*GlobalAdapter-> ZeResult );
8896 return ;
8997 }
9098
@@ -97,7 +105,11 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() {
97105 };
98106}
99107
100- ur_adapter_handle_t_ Adapter{};
108+ void globalAdapterOnDemandCleanup () {
109+ if (GlobalAdapter) {
110+ delete GlobalAdapter;
111+ }
112+ }
101113
102114ur_result_t adapterStateTeardown () {
103115 bool LeakFound = false ;
@@ -184,6 +196,11 @@ ur_result_t adapterStateTeardown() {
184196 }
185197 if (LeakFound)
186198 return UR_RESULT_ERROR_INVALID_MEM_OBJECT;
199+ // Due to multiple DLLMain definitions with SYCL, register to cleanup the
200+ // Global Adapter after refcnt is 0
201+ #if defined(_WIN32)
202+ std::atexit (globalAdapterOnDemandCleanup);
203+ #endif
187204
188205 return UR_RESULT_SUCCESS;
189206}
@@ -203,11 +220,23 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(
203220 // /< adapters available.
204221) {
205222 if (NumEntries > 0 && Adapters) {
206- std::lock_guard<std::mutex> Lock{Adapter.Mutex };
207- if (Adapter.RefCount ++ == 0 ) {
208- adapterStateInit ();
223+ if (GlobalAdapter) {
224+ std::lock_guard<std::mutex> Lock{GlobalAdapter->Mutex };
225+ if (GlobalAdapter->RefCount ++ == 0 ) {
226+ adapterStateInit ();
227+ }
228+ } else {
229+ // If the GetAdapter is called after the Library began or was torndown,
230+ // then temporarily create a new Adapter handle and register a new
231+ // cleanup.
232+ GlobalAdapter = new ur_adapter_handle_t_ ();
233+ std::lock_guard<std::mutex> Lock{GlobalAdapter->Mutex };
234+ if (GlobalAdapter->RefCount ++ == 0 ) {
235+ adapterStateInit ();
236+ }
237+ std::atexit (globalAdapterOnDemandCleanup);
209238 }
210- *Adapters = &Adapter ;
239+ *Adapters = GlobalAdapter ;
211240 }
212241
213242 if (NumAdapters) {
@@ -218,17 +247,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(
218247}
219248
220249UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease (ur_adapter_handle_t ) {
221- std::lock_guard<std::mutex> Lock{Adapter.Mutex };
222- if (--Adapter.RefCount == 0 ) {
223- return adapterStateTeardown ();
250+ // Check first if the Adapter pointer is valid
251+ if (GlobalAdapter) {
252+ std::lock_guard<std::mutex> Lock{GlobalAdapter->Mutex };
253+ if (--GlobalAdapter->RefCount == 0 ) {
254+ return adapterStateTeardown ();
255+ }
224256 }
225257
226258 return UR_RESULT_SUCCESS;
227259}
228260
229261UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain (ur_adapter_handle_t ) {
230- std::lock_guard<std::mutex> Lock{Adapter.Mutex };
231- Adapter.RefCount ++;
262+ if (GlobalAdapter) {
263+ std::lock_guard<std::mutex> Lock{GlobalAdapter->Mutex };
264+ GlobalAdapter->RefCount ++;
265+ }
232266
233267 return UR_RESULT_SUCCESS;
234268}
@@ -257,7 +291,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t,
257291 case UR_ADAPTER_INFO_BACKEND:
258292 return ReturnValue (UR_ADAPTER_BACKEND_LEVEL_ZERO);
259293 case UR_ADAPTER_INFO_REFERENCE_COUNT:
260- return ReturnValue (Adapter. RefCount .load ());
294+ return ReturnValue (GlobalAdapter-> RefCount .load ());
261295 default :
262296 return UR_RESULT_ERROR_INVALID_ENUMERATION;
263297 }
0 commit comments