@@ -15,18 +15,36 @@ struct ur_adapter_handle_t_ {
15
15
std::mutex Mutex;
16
16
};
17
17
18
- ur_adapter_handle_t_ adapter{};
18
+ static ur_adapter_handle_t_ *adapter = nullptr ;
19
+
20
+ static void globalAdapterShutdown () {
21
+ if (cl_ext::ExtFuncPtrCache) {
22
+ delete cl_ext::ExtFuncPtrCache;
23
+ cl_ext::ExtFuncPtrCache = nullptr ;
24
+ }
25
+ if (adapter) {
26
+ delete adapter;
27
+ adapter = nullptr ;
28
+ }
29
+ }
19
30
20
31
UR_APIEXPORT ur_result_t UR_APICALL
21
32
urAdapterGet (uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
22
33
uint32_t *pNumAdapters) {
23
34
if (NumEntries > 0 && phAdapters) {
24
- std::lock_guard<std::mutex> Lock{adapter.Mutex };
25
- if (adapter.RefCount ++ == 0 ) {
26
- cl_ext::ExtFuncPtrCache = std::make_unique<cl_ext::ExtFuncPtrCacheT>();
35
+ // Sometimes urAdaterGet may be called after the library already been torn
36
+ // down, we also need to create a temporary handle for it.
37
+ if (!adapter) {
38
+ adapter = new ur_adapter_handle_t_ ();
39
+ atexit (globalAdapterShutdown);
27
40
}
28
41
29
- *phAdapters = &adapter;
42
+ std::lock_guard<std::mutex> Lock{adapter->Mutex };
43
+ if (adapter->RefCount ++ == 0 ) {
44
+ cl_ext::ExtFuncPtrCache = new cl_ext::ExtFuncPtrCacheT ();
45
+ }
46
+
47
+ *phAdapters = adapter;
30
48
}
31
49
32
50
if (pNumAdapters) {
@@ -37,14 +55,20 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
37
55
}
38
56
39
57
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain (ur_adapter_handle_t ) {
40
- ++adapter. RefCount ;
58
+ ++adapter-> RefCount ;
41
59
return UR_RESULT_SUCCESS;
42
60
}
43
61
44
62
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease (ur_adapter_handle_t ) {
45
- std::lock_guard<std::mutex> Lock{adapter.Mutex };
46
- if (--adapter.RefCount == 0 ) {
47
- cl_ext::ExtFuncPtrCache.reset ();
63
+ // Check first if the adapter is valid pointer
64
+ if (adapter) {
65
+ std::lock_guard<std::mutex> Lock{adapter->Mutex };
66
+ if (--adapter->RefCount == 0 ) {
67
+ if (cl_ext::ExtFuncPtrCache) {
68
+ delete cl_ext::ExtFuncPtrCache;
69
+ cl_ext::ExtFuncPtrCache = nullptr ;
70
+ }
71
+ }
48
72
}
49
73
return UR_RESULT_SUCCESS;
50
74
}
@@ -68,7 +92,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t,
68
92
case UR_ADAPTER_INFO_BACKEND:
69
93
return ReturnValue (UR_ADAPTER_BACKEND_OPENCL);
70
94
case UR_ADAPTER_INFO_REFERENCE_COUNT:
71
- return ReturnValue (adapter. RefCount .load ());
95
+ return ReturnValue (adapter-> RefCount .load ());
72
96
default :
73
97
return UR_RESULT_ERROR_INVALID_ENUMERATION;
74
98
}
0 commit comments