Skip to content

Commit 9cf7e2f

Browse files
committed
Change CreateWithNative impl for multiple handles
1 parent 9d0d48e commit 9cf7e2f

File tree

9 files changed

+101
-14
lines changed

9 files changed

+101
-14
lines changed

source/adapters/opencl/context.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextCreateWithNativeHandle(
145145
cl_context NativeHandle = reinterpret_cast<cl_context>(hNativeContext);
146146
auto URContext = std::make_unique<ur_context_handle_t_>(
147147
NativeHandle, numDevices, phDevices);
148+
UR_RETURN_ON_FAILURE(URContext->initWithNative());
148149
*phContext = URContext.release();
149150
return UR_RESULT_SUCCESS;
150151
}

source/adapters/opencl/context.hpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,25 @@ struct ur_context_handle_t_ {
3333
Devices.emplace_back(phDevices[i]);
3434
}
3535
}
36+
ur_result_t initWithNative() {
37+
if (!DeviceCount) {
38+
CL_RETURN_ON_FAILURE(clGetContextInfo(Context, CL_CONTEXT_NUM_DEVICES,
39+
sizeof(DeviceCount), &DeviceCount,
40+
nullptr));
41+
std::vector<cl_device_id> CLDevices(DeviceCount);
42+
CL_RETURN_ON_FAILURE(clGetContextInfo(Context, CL_CONTEXT_DEVICES,
43+
sizeof(CLDevices), CLDevices.data(),
44+
nullptr));
45+
Devices.resize(DeviceCount);
46+
for (uint32_t i = 0; i < DeviceCount; i++) {
47+
ur_native_handle_t NativeDevice =
48+
reinterpret_cast<ur_native_handle_t>(CLDevices[i]);
49+
UR_RETURN_ON_FAILURE(urDeviceCreateWithNativeHandle(
50+
NativeDevice, nullptr, nullptr, &Devices[i]));
51+
}
52+
}
53+
return UR_RESULT_SUCCESS;
54+
}
3655

3756
~ur_context_handle_t_() {}
3857

source/adapters/opencl/device.cpp

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,13 +1032,32 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetNativeHandle(
10321032
}
10331033

10341034
UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
1035-
ur_native_handle_t hNativeDevice, ur_platform_handle_t hPlatform,
1035+
ur_native_handle_t hNativeDevice, ur_platform_handle_t,
10361036
const ur_device_native_properties_t *, ur_device_handle_t *phDevice) {
10371037
cl_device_id NativeHandle = reinterpret_cast<cl_device_id>(hNativeDevice);
1038-
auto URDevice =
1039-
std::make_unique<ur_device_handle_t_>(NativeHandle, hPlatform, nullptr);
1040-
*phDevice = URDevice.release();
1041-
return UR_RESULT_SUCCESS;
1038+
1039+
uint32_t NumPlatforms = 0;
1040+
UR_RETURN_ON_FAILURE(urPlatformGet(nullptr, 0, 0, nullptr, &NumPlatforms));
1041+
std::vector<ur_platform_handle_t> Platforms(NumPlatforms);
1042+
UR_RETURN_ON_FAILURE(
1043+
urPlatformGet(nullptr, 0, NumPlatforms, Platforms.data(), nullptr));
1044+
1045+
for (uint32_t i = 0; i < NumPlatforms; i++) {
1046+
uint32_t NumDevices = 0;
1047+
UR_RETURN_ON_FAILURE(
1048+
urDeviceGet(Platforms[i], UR_DEVICE_TYPE_ALL, 0, nullptr, &NumDevices));
1049+
std::vector<ur_device_handle_t> Devices(NumDevices);
1050+
UR_RETURN_ON_FAILURE(urDeviceGet(Platforms[i], UR_DEVICE_TYPE_ALL,
1051+
NumDevices, Devices.data(), nullptr));
1052+
1053+
for (auto &Device : Devices) {
1054+
if (Device->get() == NativeHandle) {
1055+
*phDevice = Device;
1056+
return UR_RESULT_SUCCESS;
1057+
}
1058+
}
1059+
}
1060+
return UR_RESULT_ERROR_INVALID_DEVICE;
10421061
}
10431062

10441063
UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetGlobalTimestamps(

source/adapters/opencl/device.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#pragma once
1111

1212
#include "common.hpp"
13+
#include "platform.hpp"
1314

1415
namespace cl_adapter {
1516
ur_result_t getDeviceVersion(cl_device_id Dev, oclv::OpenCLVersion &Version);

source/adapters/opencl/memory.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreateWithNativeHandle(
349349
const ur_mem_native_properties_t *pProperties, ur_mem_handle_t *phMem) {
350350
cl_mem NativeHandle = reinterpret_cast<cl_mem>(hNativeMem);
351351
auto URMem = std::make_unique<ur_mem_handle_t_>(NativeHandle, hContext);
352+
UR_RETURN_ON_FAILURE(URMem->initWithNative());
352353
*phMem = URMem.release();
353354
if (!pProperties || !pProperties->isNativeHandleOwned) {
354355
return urMemRetain(*phMem);
@@ -363,6 +364,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemImageCreateWithNativeHandle(
363364
const ur_mem_native_properties_t *pProperties, ur_mem_handle_t *phMem) {
364365
cl_mem NativeHandle = reinterpret_cast<cl_mem>(hNativeMem);
365366
auto URMem = std::make_unique<ur_mem_handle_t_>(NativeHandle, hContext);
367+
UR_RETURN_ON_FAILURE(URMem->initWithNative());
366368
*phMem = URMem.release();
367369
if (!pProperties || !pProperties->isNativeHandleOwned) {
368370
return urMemRetain(*phMem);

source/adapters/opencl/memory.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,18 @@ struct ur_mem_handle_t_ {
2323

2424
~ur_mem_handle_t_() {}
2525

26+
ur_result_t initWithNative() {
27+
if (!Context) {
28+
cl_context CLContext;
29+
CL_RETURN_ON_FAILURE(clGetMemObjectInfo(
30+
Memory, CL_MEM_CONTEXT, sizeof(CLContext), &CLContext, nullptr));
31+
ur_native_handle_t NativeContext =
32+
reinterpret_cast<ur_native_handle_t>(CLContext);
33+
UR_RETURN_ON_FAILURE(urContextCreateWithNativeHandle(
34+
NativeContext, 0, nullptr, nullptr, &Context));
35+
}
36+
return UR_RESULT_SUCCESS;
37+
}
38+
2639
native_type get() { return Memory; }
2740
};

source/adapters/opencl/platform.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformCreateWithNativeHandle(
143143
ur_platform_handle_t *phPlatform) {
144144
cl_platform_id NativeHandle =
145145
reinterpret_cast<cl_platform_id>(hNativePlatform);
146-
auto URPlatform = std::make_unique<ur_platform_handle_t_>(NativeHandle);
147-
*phPlatform = URPlatform.release();
148-
return UR_RESULT_SUCCESS;
146+
147+
uint32_t NumPlatforms = 0;
148+
UR_RETURN_ON_FAILURE(urPlatformGet(nullptr, 0, 0, nullptr, &NumPlatforms));
149+
std::vector<ur_platform_handle_t> Platforms(NumPlatforms);
150+
UR_RETURN_ON_FAILURE(
151+
urPlatformGet(nullptr, 0, NumPlatforms, Platforms.data(), nullptr));
152+
153+
for (uint32_t i = 0; i < NumPlatforms; i++) {
154+
if (Platforms[i]->get() == NativeHandle) {
155+
*phPlatform = Platforms[i];
156+
return UR_RESULT_SUCCESS;
157+
}
158+
}
159+
return UR_RESULT_ERROR_INVALID_PLATFORM;
149160
}
150161

151162
// Returns plugin specific backend option.

source/adapters/opencl/queue.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -162,21 +162,20 @@ urQueueGetNativeHandle(ur_queue_handle_t hQueue, ur_queue_native_desc_t *,
162162
}
163163

164164
UR_APIEXPORT ur_result_t UR_APICALL urQueueCreateWithNativeHandle(
165-
ur_native_handle_t hNativeQueue,
166-
[[maybe_unused]] ur_context_handle_t hContext,
167-
[[maybe_unused]] ur_device_handle_t hDevice,
165+
ur_native_handle_t hNativeQueue, ur_context_handle_t hContext,
166+
ur_device_handle_t hDevice,
168167
[[maybe_unused]] const ur_queue_native_properties_t *pProperties,
169168
ur_queue_handle_t *phQueue) {
170169

171170
cl_command_queue NativeHandle =
172171
reinterpret_cast<cl_command_queue>(hNativeQueue);
173172
auto URQueue =
174173
std::make_unique<ur_queue_handle_t_>(NativeHandle, hContext, hDevice);
174+
UR_RETURN_ON_FAILURE(URQueue->initWithNative());
175175
*phQueue = URQueue.release();
176176

177-
cl_int RetErr =
178-
clRetainCommandQueue(cl_adapter::cast<cl_command_queue>(hNativeQueue));
179-
CL_RETURN_ON_FAILURE(RetErr);
177+
CL_RETURN_ON_FAILURE(clRetainCommandQueue(NativeHandle));
178+
180179
return UR_RESULT_SUCCESS;
181180
}
182181

source/adapters/opencl/queue.hpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,28 @@ struct ur_queue_handle_t_ {
2323
ur_device_handle_t Dev)
2424
: Queue(Queue), Context(Ctx), Device(Dev) {}
2525

26+
ur_result_t initWithNative() {
27+
if (!Context) {
28+
cl_context CLContext;
29+
CL_RETURN_ON_FAILURE(clGetCommandQueueInfo(
30+
Queue, CL_QUEUE_CONTEXT, sizeof(CLContext), &CLContext, nullptr));
31+
ur_native_handle_t NativeContext =
32+
reinterpret_cast<ur_native_handle_t>(CLContext);
33+
UR_RETURN_ON_FAILURE(urContextCreateWithNativeHandle(
34+
NativeContext, 0, nullptr, nullptr, &Context));
35+
}
36+
if (!Device) {
37+
cl_device_id CLDevice;
38+
CL_RETURN_ON_FAILURE(clGetCommandQueueInfo(
39+
Queue, CL_QUEUE_DEVICE, sizeof(CLDevice), &CLDevice, nullptr));
40+
ur_native_handle_t NativeDevice =
41+
reinterpret_cast<ur_native_handle_t>(CLDevice);
42+
UR_RETURN_ON_FAILURE(urDeviceCreateWithNativeHandle(NativeDevice, nullptr,
43+
nullptr, &Device));
44+
}
45+
return UR_RESULT_SUCCESS;
46+
}
47+
2648
~ur_queue_handle_t_() {}
2749

2850
native_type get() { return Queue; }

0 commit comments

Comments
 (0)