Skip to content

Commit 94fffb4

Browse files
committed
Match cuda
1 parent b977d7c commit 94fffb4

File tree

11 files changed

+119
-80
lines changed

11 files changed

+119
-80
lines changed

source/adapters/opencl/adapter.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
2222
uint32_t *pNumAdapters) {
2323
if (NumEntries > 0 && phAdapters) {
2424
std::lock_guard<std::mutex> Lock{adapter.Mutex};
25-
adapter.RefCount++;
25+
// adapter.RefCount++;
26+
if (adapter.RefCount++ == 0) {
27+
cl_ext::ExtFuncPtrCache = std::make_unique<cl_ext::ExtFuncPtrCacheT>();
28+
}
2629
*phAdapters = &adapter;
2730
}
2831

@@ -40,7 +43,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) {
4043

4144
// Drops one reference on the global OpenCL adapter object.
//
// The handle parameter is unnamed and unused; the function always operates on
// the global `adapter`. The decrement is performed while holding
// `adapter.Mutex`. When the count reaches zero the extension function pointer
// cache is destroyed (urAdapterGet re-creates it when the count goes from
// zero to one).
//
// Always returns UR_RESULT_SUCCESS.
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
  std::lock_guard<std::mutex> Lock{adapter.Mutex};
  // Exactly one decrement per call; the last release tears the cache down.
  if (--adapter.RefCount == 0) {
    cl_ext::ExtFuncPtrCache.reset();
  }
  return UR_RESULT_SUCCESS;
}
4652

source/adapters/opencl/command_buffer.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "command_buffer.hpp"
1212
#include "common.hpp"
13+
#include "context.hpp"
1314

1415
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferCreateExp(
1516
ur_context_handle_t hContext, ur_device_handle_t hDevice,
@@ -19,7 +20,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferCreateExp(
1920
ur_queue_handle_t Queue = nullptr;
2021
UR_RETURN_ON_FAILURE(urQueueCreate(hContext, hDevice, nullptr, &Queue));
2122

22-
cl_context CLContext = cl_adapter::cast<cl_context>(hContext);
23+
cl_context CLContext = hContext->get();
2324
cl_ext::clCreateCommandBufferKHR_fn clCreateCommandBufferKHR = nullptr;
2425
cl_int Res =
2526
cl_ext::getExtFuncFromContext<decltype(clCreateCommandBufferKHR)>(
@@ -49,7 +50,7 @@ UR_APIEXPORT ur_result_t UR_APICALL
4950
urCommandBufferRetainExp(ur_exp_command_buffer_handle_t hCommandBuffer) {
5051
UR_RETURN_ON_FAILURE(urQueueRetain(hCommandBuffer->hInternalQueue));
5152

52-
cl_context CLContext = cl_adapter::cast<cl_context>(hCommandBuffer->hContext);
53+
cl_context CLContext = hCommandBuffer->hContext->get();
5354
cl_ext::clRetainCommandBufferKHR_fn clRetainCommandBuffer = nullptr;
5455
cl_int Res = cl_ext::getExtFuncFromContext<decltype(clRetainCommandBuffer)>(
5556
CLContext, cl_ext::ExtFuncPtrCache->clRetainCommandBufferKHRCache,
@@ -66,7 +67,7 @@ UR_APIEXPORT ur_result_t UR_APICALL
6667
urCommandBufferReleaseExp(ur_exp_command_buffer_handle_t hCommandBuffer) {
6768
UR_RETURN_ON_FAILURE(urQueueRelease(hCommandBuffer->hInternalQueue));
6869

69-
cl_context CLContext = cl_adapter::cast<cl_context>(hCommandBuffer->hContext);
70+
cl_context CLContext = hCommandBuffer->hContext->get();
7071
cl_ext::clReleaseCommandBufferKHR_fn clReleaseCommandBufferKHR = nullptr;
7172
cl_int Res =
7273
cl_ext::getExtFuncFromContext<decltype(clReleaseCommandBufferKHR)>(
@@ -83,7 +84,7 @@ urCommandBufferReleaseExp(ur_exp_command_buffer_handle_t hCommandBuffer) {
8384

8485
UR_APIEXPORT ur_result_t UR_APICALL
8586
urCommandBufferFinalizeExp(ur_exp_command_buffer_handle_t hCommandBuffer) {
86-
cl_context CLContext = cl_adapter::cast<cl_context>(hCommandBuffer->hContext);
87+
cl_context CLContext = hCommandBuffer->hContext->get();
8788
cl_ext::clFinalizeCommandBufferKHR_fn clFinalizeCommandBufferKHR = nullptr;
8889
cl_int Res =
8990
cl_ext::getExtFuncFromContext<decltype(clFinalizeCommandBufferKHR)>(
@@ -106,7 +107,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
106107
const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
107108
ur_exp_command_buffer_sync_point_t *pSyncPoint) {
108109

109-
cl_context CLContext = cl_adapter::cast<cl_context>(hCommandBuffer->hContext);
110+
cl_context CLContext = hCommandBuffer->hContext->get();
110111
cl_ext::clCommandNDRangeKernelKHR_fn clCommandNDRangeKernelKHR = nullptr;
111112
cl_int Res =
112113
cl_ext::getExtFuncFromContext<decltype(clCommandNDRangeKernelKHR)>(
@@ -154,7 +155,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyExp(
154155
const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
155156
ur_exp_command_buffer_sync_point_t *pSyncPoint) {
156157

157-
cl_context CLContext = cl_adapter::cast<cl_context>(hCommandBuffer->hContext);
158+
cl_context CLContext = hCommandBuffer->hContext->get();
158159
cl_ext::clCommandCopyBufferKHR_fn clCommandCopyBufferKHR = nullptr;
159160
cl_int Res = cl_ext::getExtFuncFromContext<decltype(clCommandCopyBufferKHR)>(
160161
CLContext, cl_ext::ExtFuncPtrCache->clCommandCopyBufferKHRCache,
@@ -190,7 +191,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp(
190191
size_t OpenCLDstRect[3]{dstOrigin.x, dstOrigin.y, dstOrigin.z};
191192
size_t OpenCLRegion[3]{region.width, region.height, region.depth};
192193

193-
cl_context CLContext = cl_adapter::cast<cl_context>(hCommandBuffer->hContext);
194+
cl_context CLContext = hCommandBuffer->hContext->get();
194195
cl_ext::clCommandCopyBufferRectKHR_fn clCommandCopyBufferRectKHR = nullptr;
195196
cl_int Res =
196197
cl_ext::getExtFuncFromContext<decltype(clCommandCopyBufferRectKHR)>(
@@ -280,7 +281,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMembufferFillExp(
280281
const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
281282
ur_exp_command_buffer_sync_point_t *pSyncPoint) {
282283

283-
cl_context CLContext = cl_adapter::cast<cl_context>(hCommandBuffer->hContext);
284+
cl_context CLContext = hCommandBuffer->hContext->get();
284285
cl_ext::clCommandFillBufferKHR_fn clCommandFillBufferKHR = nullptr;
285286
cl_int Res = cl_ext::getExtFuncFromContext<decltype(clCommandFillBufferKHR)>(
286287
CLContext, cl_ext::ExtFuncPtrCache->clCommandFillBufferKHRCache,
@@ -302,7 +303,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
302303
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
303304
ur_event_handle_t *phEvent) {
304305

305-
cl_context CLContext = cl_adapter::cast<cl_context>(hCommandBuffer->hContext);
306+
cl_context CLContext = hCommandBuffer->hContext->get();
306307
cl_ext::clEnqueueCommandBufferKHR_fn clEnqueueCommandBufferKHR = nullptr;
307308
cl_int Res =
308309
cl_ext::getExtFuncFromContext<decltype(clEnqueueCommandBufferKHR)>(

source/adapters/opencl/context.cpp

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,16 @@ ur_result_t cl_adapter::getDevicesFromContext(
1818
ur_context_handle_t hContext,
1919
std::unique_ptr<std::vector<cl_device_id>> &DevicesInCtx) {
2020

21-
cl_uint DeviceCount;
22-
CL_RETURN_ON_FAILURE(clGetContextInfo(cl_adapter::cast<cl_context>(hContext),
23-
CL_CONTEXT_NUM_DEVICES, sizeof(cl_uint),
24-
&DeviceCount, nullptr));
21+
cl_uint DeviceCount = hContext->DeviceCount;
2522

2623
if (DeviceCount < 1) {
2724
return UR_RESULT_ERROR_INVALID_CONTEXT;
2825
}
2926

3027
DevicesInCtx = std::make_unique<std::vector<cl_device_id>>(DeviceCount);
31-
32-
CL_RETURN_ON_FAILURE(clGetContextInfo(
33-
cl_adapter::cast<cl_context>(hContext), CL_CONTEXT_DEVICES,
34-
DeviceCount * sizeof(cl_device_id), (*DevicesInCtx).data(), nullptr));
28+
for (size_t i = 0; i < DeviceCount; i++) {
29+
(*DevicesInCtx)[i] = hContext->Devices[i]->get();
30+
}
3531

3632
return UR_RESULT_SUCCESS;
3733
}
@@ -41,11 +37,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextCreate(
4137
const ur_context_properties_t *, ur_context_handle_t *phContext) {
4238

4339
cl_int Ret;
44-
*phContext = cl_adapter::cast<ur_context_handle_t>(
45-
clCreateContext(nullptr, cl_adapter::cast<cl_uint>(DeviceCount),
46-
cl_adapter::cast<const cl_device_id *>(phDevices),
47-
nullptr, nullptr, cl_adapter::cast<cl_int *>(&Ret)));
40+
std::vector<cl_device_id> CLDevices(DeviceCount);
41+
for (size_t i = 0; i < DeviceCount; i++) {
42+
CLDevices[i] = phDevices[i]->get();
43+
}
44+
45+
cl_context Ctx = clCreateContext(nullptr, cl_adapter::cast<cl_uint>(DeviceCount),
46+
CLDevices.data(),
47+
nullptr, nullptr, cl_adapter::cast<cl_int *>(&Ret));
4848

49+
*phContext = new ur_context_handle_t_(Ctx, DeviceCount, phDevices);
4950
return mapCLErrorToUR(Ret);
5051
}
5152

@@ -95,7 +96,7 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName,
9596
case UR_CONTEXT_INFO_REFERENCE_COUNT: {
9697
size_t CheckPropSize = 0;
9798
auto ClResult =
98-
clGetContextInfo(cl_adapter::cast<cl_context>(hContext), CLPropName,
99+
clGetContextInfo(hContext->get(), CLPropName,
99100
propSize, pPropValue, &CheckPropSize);
100101
if (pPropValue && CheckPropSize != propSize) {
101102
return UR_RESULT_ERROR_INVALID_SIZE;
@@ -114,29 +115,31 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName,
114115
// Releases one reference on the native cl_context wrapped by hContext and
// returns the OpenCL status mapped to a UR result.
//
// NOTE(review): the ur_context_handle_t_ wrapper is allocated with `new` when
// a context is created, but nothing here deletes it. Confirm where the wrapper
// is freed; otherwise it leaks once the native context is destroyed.
UR_APIEXPORT ur_result_t UR_APICALL
urContextRelease(ur_context_handle_t hContext) {
  return mapCLErrorToUR(clReleaseContext(hContext->get()));
}
120121

121122
// Adds one reference to the native cl_context wrapped by hContext and returns
// the OpenCL status mapped to a UR result.
UR_APIEXPORT ur_result_t UR_APICALL
urContextRetain(ur_context_handle_t hContext) {
  return mapCLErrorToUR(clRetainContext(hContext->get()));
}
127128

128129
// Writes the native cl_context held by hContext to *phNativeContext.
//
// The value handed back is the underlying OpenCL handle, not the address of
// the UR wrapper object. Always returns UR_RESULT_SUCCESS.
UR_APIEXPORT ur_result_t UR_APICALL urContextGetNativeHandle(
    ur_context_handle_t hContext, ur_native_handle_t *phNativeContext) {
  *phNativeContext = reinterpret_cast<ur_native_handle_t>(hContext->get());
  return UR_RESULT_SUCCESS;
}
134135

135136
// Wraps an existing native OpenCL context in a newly allocated UR context
// handle.
//
// hNativeContext is reinterpreted as a cl_context. numDevices / phDevices
// describe the devices recorded on the new wrapper. The properties argument
// is ignored. On success *phContext receives the new handle.
//
// Returns UR_RESULT_SUCCESS, or UR_RESULT_ERROR_OUT_OF_RESOURCES if creating
// the wrapper throws.
UR_APIEXPORT ur_result_t UR_APICALL urContextCreateWithNativeHandle(
    ur_native_handle_t hNativeContext, uint32_t numDevices,
    const ur_device_handle_t *phDevices,
    const ur_context_native_properties_t *, ur_context_handle_t *phContext) {

  cl_context NativeHandle = reinterpret_cast<cl_context>(hNativeContext);
  // The allocation and the wrapper's device vector may throw; exceptions must
  // not propagate out of this exported entry point.
  try {
    *phContext = new ur_context_handle_t_(NativeHandle, numDevices, phDevices);
  } catch (...) {
    return UR_RESULT_ERROR_OUT_OF_RESOURCES;
  }
  return UR_RESULT_SUCCESS;
}
142145

@@ -187,7 +190,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextSetExtendedDeleter(
187190
C->execute();
188191
};
189192
CL_RETURN_ON_FAILURE(clSetContextDestructorCallback(
190-
cl_adapter::cast<cl_context>(hContext), ClCallback, Callback));
193+
hContext->get(), ClCallback, Callback));
191194

192195
return UR_RESULT_SUCCESS;
193196
}

source/adapters/opencl/context.hpp

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,28 +10,29 @@
1010
#pragma once
1111

1212
#include "common.hpp"
13+
#include "device.hpp"
14+
15+
#include <vector>
1316

1417
namespace cl_adapter {
1518
ur_result_t
1619
getDevicesFromContext(ur_context_handle_t hContext,
1720
std::unique_ptr<std::vector<cl_device_id>> &DevicesInCtx);
1821
}
1922

20-
// struct ur_context_handle_t_ {
21-
// using native_type = cl_context;
22-
// native_type Context;
23-
// std::atomic_uint32_t RefCount;
24-
// ur_platform_handle_t Platform;
25-
26-
// ur_context_handle_t_(native_type Ctx):Context(Ctx) {}
27-
28-
// ~ur_context_handle_t_() {}
29-
30-
// native_type get() { return Context; }
23+
struct ur_context_handle_t_ {
24+
using native_type = cl_context;
25+
native_type Context;
26+
std::vector<ur_device_handle_t> Devices;
27+
uint32_t DeviceCount;
3128

32-
// uint32_t incrementReferenceCount() noexcept { return ++RefCount; }
29+
ur_context_handle_t_(native_type Ctx, uint32_t DevCount, const ur_device_handle_t *phDevices) : Context(Ctx), DeviceCount(DevCount) {
30+
for (uint32_t i = 0; i < DeviceCount; i++) {
31+
Devices.emplace_back(phDevices[i]);
32+
}
33+
}
3334

34-
// uint32_t decrementReferenceCount() noexcept { return --RefCount; }
35+
~ur_context_handle_t_() {}
3536

36-
// uint32_t getReferenceCount() const noexcept { return RefCount; }
37-
// };
37+
native_type get() { return Context; }
38+
};

source/adapters/opencl/device.cpp

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -75,31 +75,31 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGet(ur_platform_handle_t hPlatform,
7575
Type = CL_DEVICE_TYPE_ACCELERATOR;
7676
break;
7777
case UR_DEVICE_TYPE_DEFAULT:
78-
Type = CL_DEVICE_TYPE_DEFAULT;
78+
Type = UR_DEVICE_TYPE_DEFAULT;
7979
break;
8080
default:
8181
return UR_RESULT_ERROR_INVALID_ENUMERATION;
8282
}
8383

84-
std::vector<cl_device_id> CLDevices(NumEntries);
85-
cl_int Result = clGetDeviceIDs(
86-
hPlatform->get(), Type, cl_adapter::cast<cl_uint>(NumEntries),
87-
CLDevices.data(), cl_adapter::cast<cl_uint *>(pNumDevices));
88-
89-
// Absorb the CL_DEVICE_NOT_FOUND and just return 0 in num_devices
90-
if (Result == CL_DEVICE_NOT_FOUND) {
91-
Result = CL_SUCCESS;
84+
CL_RETURN_ON_FAILURE(hPlatform->GetDevices(Type));
85+
size_t NumDevices = hPlatform->Devices.size();
86+
try {
9287
if (pNumDevices) {
93-
*pNumDevices = 0;
88+
*pNumDevices = NumDevices;
9489
}
95-
}
96-
if (NumEntries && phDevices) {
97-
for (uint32_t i = 0; i < NumEntries; i++) {
98-
phDevices[i] = new ur_device_handle_t_(CLDevices[i], hPlatform);
90+
91+
if (phDevices) {
92+
for (size_t i = 0; i < std::min(size_t(NumEntries), NumDevices); ++i) {
93+
phDevices[i] = hPlatform->Devices[i];
94+
}
9995
}
100-
}
10196

102-
return mapCLErrorToUR(Result);
97+
return UR_RESULT_SUCCESS;
98+
} catch (ur_result_t Err) {
99+
return Err;
100+
} catch (...) {
101+
return UR_RESULT_ERROR_OUT_OF_RESOURCES;
102+
}
103103
}
104104

105105
static ur_device_fp_capability_flags_t
@@ -861,7 +861,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
861861
case UR_DEVICE_INFO_MAX_PARAMETER_SIZE:
862862
case UR_DEVICE_INFO_PROFILING_TIMER_RESOLUTION:
863863
case UR_DEVICE_INFO_PRINTF_BUFFER_SIZE:
864-
case UR_DEVICE_INFO_PLATFORM:
865864
case UR_DEVICE_INFO_PARENT_DEVICE:
866865
case UR_DEVICE_INFO_IL_VERSION:
867866
case UR_DEVICE_INFO_NAME:
@@ -890,6 +889,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
890889

891890
return UR_RESULT_SUCCESS;
892891
}
892+
case UR_DEVICE_INFO_PLATFORM: {
893+
if (hDevice->Platform && hDevice->Platform->get()) {
894+
return ReturnValue(hDevice->Platform);
895+
}
896+
return UR_RESULT_ERROR_INVALID_DEVICE;
897+
}
893898
case UR_DEVICE_INFO_EXTENSIONS: {
894899
cl_device_id Dev = hDevice->get();
895900
size_t ExtSize = 0;

source/adapters/opencl/memory.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
//===----------------------------------------------------------------------===//
1010

1111
#include "common.hpp"
12+
#include "context.hpp"
1213

1314
cl_image_format mapURImageFormatToCL(const ur_image_format_t *PImageFormat) {
1415
cl_image_format CLImageFormat;
@@ -230,7 +231,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreate(
230231
// TODO: need to check if all properties are supported by OpenCL RT and
231232
// ignore unsupported
232233
clCreateBufferWithPropertiesINTEL_fn FuncPtr = nullptr;
233-
cl_context CLContext = cl_adapter::cast<cl_context>(hContext);
234+
cl_context CLContext = hContext->get();
234235
// First we need to look up the function pointer
235236
RetErr =
236237
cl_ext::getExtFuncFromContext<clCreateBufferWithPropertiesINTEL_fn>(
@@ -270,7 +271,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreate(
270271

271272
void *HostPtr = pProperties ? pProperties->pHost : nullptr;
272273
*phBuffer = reinterpret_cast<ur_mem_handle_t>(clCreateBuffer(
273-
cl_adapter::cast<cl_context>(hContext), static_cast<cl_mem_flags>(flags),
274+
hContext->get(), static_cast<cl_mem_flags>(flags),
274275
size, HostPtr, cl_adapter::cast<cl_int *>(&RetErr)));
275276
CL_RETURN_ON_FAILURE(RetErr);
276277

@@ -289,7 +290,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemImageCreate(
289290
cl_map_flags MapFlags = convertURMemFlagsToCL(flags);
290291

291292
*phMem = reinterpret_cast<ur_mem_handle_t>(clCreateImage(
292-
cl_adapter::cast<cl_context>(hContext), MapFlags, &ImageFormat,
293+
hContext->get(), MapFlags, &ImageFormat,
293294
&ImageDesc, pHost, cl_adapter::cast<cl_int *>(&RetErr)));
294295
CL_RETURN_ON_FAILURE(RetErr);
295296

0 commit comments

Comments
 (0)