Skip to content

Commit 3c55cac

Browse files
committed
Merge branch 'main' into user-after-free
2 parents 3d9aa64 + d99d5f7 commit 3c55cac

File tree

4 files changed

+56
-26
lines changed

4 files changed

+56
-26
lines changed

source/adapters/cuda/image.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,17 +1006,23 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesMapExternalArrayExp(
10061006
ArrayDesc.Format = format;
10071007

10081008
CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC mipmapDesc = {};
1009-
mipmapDesc.numLevels = 1;
1009+
mipmapDesc.numLevels = pImageDesc->numMipLevel;
10101010
mipmapDesc.arrayDesc = ArrayDesc;
10111011

1012+
// External memory is mapped to a CUmipmappedArray
1013+
// If desired, a CUarray is retrieved from the mipmaps 0th level
10121014
CUmipmappedArray memMipMap;
10131015
UR_CHECK_ERROR(cuExternalMemoryGetMappedMipmappedArray(
10141016
&memMipMap, (CUexternalMemory)hInteropMem, &mipmapDesc));
10151017

1016-
CUarray memArray;
1017-
UR_CHECK_ERROR(cuMipmappedArrayGetLevel(&memArray, memMipMap, 0));
1018+
if (pImageDesc->numMipLevel > 1) {
1019+
*phImageMem = (ur_exp_image_mem_handle_t)memMipMap;
1020+
} else {
1021+
CUarray memArray;
1022+
UR_CHECK_ERROR(cuMipmappedArrayGetLevel(&memArray, memMipMap, 0));
10181023

1019-
*phImageMem = (ur_exp_image_mem_handle_t)memArray;
1024+
*phImageMem = (ur_exp_image_mem_handle_t)memArray;
1025+
}
10201026

10211027
} catch (ur_result_t Err) {
10221028
return Err;

source/adapters/cuda/tracing.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,20 @@
2727
using tracing_event_t = xpti_td *;
2828
using subscriber_handle_t = CUpti_SubscriberHandle;
2929

30-
using cuptiSubscribe_fn = CUPTIAPI
31-
CUptiResult (*)(CUpti_SubscriberHandle *subscriber, CUpti_CallbackFunc callback,
32-
void *userdata);
30+
using cuptiSubscribe_fn =
31+
CUptiResult(CUPTIAPI *)(CUpti_SubscriberHandle *subscriber,
32+
CUpti_CallbackFunc callback, void *userdata);
3333

34-
using cuptiUnsubscribe_fn = CUPTIAPI
35-
CUptiResult (*)(CUpti_SubscriberHandle subscriber);
34+
using cuptiUnsubscribe_fn =
35+
CUptiResult(CUPTIAPI *)(CUpti_SubscriberHandle subscriber);
3636

37-
using cuptiEnableDomain_fn = CUPTIAPI
38-
CUptiResult (*)(uint32_t enable, CUpti_SubscriberHandle subscriber,
39-
CUpti_CallbackDomain domain);
37+
using cuptiEnableDomain_fn =
38+
CUptiResult(CUPTIAPI *)(uint32_t enable, CUpti_SubscriberHandle subscriber,
39+
CUpti_CallbackDomain domain);
4040

41-
using cuptiEnableCallback_fn = CUPTIAPI
42-
CUptiResult (*)(uint32_t enable, CUpti_SubscriberHandle subscriber,
43-
CUpti_CallbackDomain domain, CUpti_CallbackId cbid);
41+
using cuptiEnableCallback_fn =
42+
CUptiResult(CUPTIAPI *)(uint32_t enable, CUpti_SubscriberHandle subscriber,
43+
CUpti_CallbackDomain domain, CUpti_CallbackId cbid);
4444

4545
#define LOAD_CUPTI_SYM(p, lib, x) \
4646
p.x = (cupti##x##_fn)ur_loader::LibLoader::getFunctionPtr(lib.get(), \

source/adapters/opencl/adapter.cpp

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,36 @@ struct ur_adapter_handle_t_ {
1515
std::mutex Mutex;
1616
};
1717

18-
ur_adapter_handle_t_ adapter{};
18+
static ur_adapter_handle_t_ *adapter = nullptr;
19+
20+
static void globalAdapterShutdown() {
21+
if (cl_ext::ExtFuncPtrCache) {
22+
delete cl_ext::ExtFuncPtrCache;
23+
cl_ext::ExtFuncPtrCache = nullptr;
24+
}
25+
if (adapter) {
26+
delete adapter;
27+
adapter = nullptr;
28+
}
29+
}
1930

2031
UR_APIEXPORT ur_result_t UR_APICALL
2132
urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
2233
uint32_t *pNumAdapters) {
2334
if (NumEntries > 0 && phAdapters) {
24-
std::lock_guard<std::mutex> Lock{adapter.Mutex};
25-
if (adapter.RefCount++ == 0) {
26-
cl_ext::ExtFuncPtrCache = std::make_unique<cl_ext::ExtFuncPtrCacheT>();
35+
// Sometimes urAdaterGet may be called after the library already been torn
36+
// down, we also need to create a temporary handle for it.
37+
if (!adapter) {
38+
adapter = new ur_adapter_handle_t_();
39+
atexit(globalAdapterShutdown);
2740
}
2841

29-
*phAdapters = &adapter;
42+
std::lock_guard<std::mutex> Lock{adapter->Mutex};
43+
if (adapter->RefCount++ == 0) {
44+
cl_ext::ExtFuncPtrCache = new cl_ext::ExtFuncPtrCacheT();
45+
}
46+
47+
*phAdapters = adapter;
3048
}
3149

3250
if (pNumAdapters) {
@@ -37,14 +55,20 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
3755
}
3856

3957
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) {
40-
++adapter.RefCount;
58+
++adapter->RefCount;
4159
return UR_RESULT_SUCCESS;
4260
}
4361

4462
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
45-
std::lock_guard<std::mutex> Lock{adapter.Mutex};
46-
if (--adapter.RefCount == 0) {
47-
cl_ext::ExtFuncPtrCache.reset();
63+
// Check first if the adapter is valid pointer
64+
if (adapter) {
65+
std::lock_guard<std::mutex> Lock{adapter->Mutex};
66+
if (--adapter->RefCount == 0) {
67+
if (cl_ext::ExtFuncPtrCache) {
68+
delete cl_ext::ExtFuncPtrCache;
69+
cl_ext::ExtFuncPtrCache = nullptr;
70+
}
71+
}
4872
}
4973
return UR_RESULT_SUCCESS;
5074
}
@@ -68,7 +92,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t,
6892
case UR_ADAPTER_INFO_BACKEND:
6993
return ReturnValue(UR_ADAPTER_BACKEND_OPENCL);
7094
case UR_ADAPTER_INFO_REFERENCE_COUNT:
71-
return ReturnValue(adapter.RefCount.load());
95+
return ReturnValue(adapter->RefCount.load());
7296
default:
7397
return UR_RESULT_ERROR_INVALID_ENUMERATION;
7498
}

source/adapters/opencl/common.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ struct ExtFuncPtrCacheT {
349349
// piTeardown to avoid issues with static destruction order (a user application
350350
// might have static objects that indirectly access this cache in their
351351
// destructor).
352-
inline std::unique_ptr<ExtFuncPtrCacheT> ExtFuncPtrCache;
352+
inline ExtFuncPtrCacheT *ExtFuncPtrCache;
353353

354354
// USM helper function to get an extension function pointer
355355
template <typename T>

0 commit comments

Comments
 (0)