Skip to content

Commit 2b29c7f

Browse files
committed
Add devices init to platform
1 parent 94fffb4 commit 2b29c7f

File tree

4 files changed

+68
-31
lines changed

4 files changed

+68
-31
lines changed

source/adapters/opencl/device.cpp

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ ur_result_t cl_adapter::checkDeviceExtensions(
5454

5555
UR_APIEXPORT ur_result_t UR_APICALL urDeviceGet(ur_platform_handle_t hPlatform,
5656
ur_device_type_t DeviceType,
57-
uint32_t NumEntries,
57+
[[maybe_unused]] uint32_t NumEntries,
5858
ur_device_handle_t *phDevices,
5959
uint32_t *pNumDevices) {
6060

@@ -75,24 +75,26 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGet(ur_platform_handle_t hPlatform,
7575
Type = CL_DEVICE_TYPE_ACCELERATOR;
7676
break;
7777
case UR_DEVICE_TYPE_DEFAULT:
78-
Type = UR_DEVICE_TYPE_DEFAULT;
78+
Type = CL_DEVICE_TYPE_DEFAULT;
7979
break;
8080
default:
8181
return UR_RESULT_ERROR_INVALID_ENUMERATION;
8282
}
83-
84-
CL_RETURN_ON_FAILURE(hPlatform->GetDevices(Type));
85-
size_t NumDevices = hPlatform->Devices.size();
8683
try {
87-
if (pNumDevices) {
88-
*pNumDevices = NumDevices;
89-
}
90-
91-
if (phDevices) {
92-
for (size_t i = 0; i < std::min(size_t(NumEntries), NumDevices); ++i) {
93-
phDevices[i] = hPlatform->Devices[i];
84+
uint32_t AllDevicesNum = hPlatform->Devices.size();
85+
uint32_t DeviceNumIter = 0;
86+
for (uint32_t i = 0; i < AllDevicesNum; i++) {
87+
cl_device_type DeviceType = hPlatform->Devices[i]->Type;
88+
if (DeviceType == Type || Type == CL_DEVICE_TYPE_ALL) {
89+
if (phDevices) {
90+
phDevices[DeviceNumIter] = hPlatform->Devices[i];
91+
}
92+
DeviceNumIter++;
9493
}
9594
}
95+
if (pNumDevices) {
96+
*pNumDevices = DeviceNumIter;
97+
}
9698

9799
return UR_RESULT_SUCCESS;
98100
} catch (ur_result_t Err) {
@@ -329,9 +331,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
329331
* to UR */
330332
switch (static_cast<uint32_t>(propName)) {
331333
case UR_DEVICE_INFO_TYPE: {
332-
cl_device_type CLType;
333-
CL_RETURN_ON_FAILURE(clGetDeviceInfo(
334-
hDevice->get(), CLPropName, sizeof(cl_device_type), &CLType, nullptr));
334+
cl_device_type CLType = hDevice->Type;
335335

336336
/* TODO UR: If the device is an Accelerator (FPGA, VPU, etc.), there is not
337337
* enough information in the OpenCL runtime to know exactly which type it
@@ -861,7 +861,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
861861
case UR_DEVICE_INFO_MAX_PARAMETER_SIZE:
862862
case UR_DEVICE_INFO_PROFILING_TIMER_RESOLUTION:
863863
case UR_DEVICE_INFO_PRINTF_BUFFER_SIZE:
864-
case UR_DEVICE_INFO_PARENT_DEVICE:
865864
case UR_DEVICE_INFO_IL_VERSION:
866865
case UR_DEVICE_INFO_NAME:
867866
case UR_DEVICE_INFO_VENDOR:
@@ -895,6 +894,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
895894
}
896895
return UR_RESULT_ERROR_INVALID_DEVICE;
897896
}
897+
case UR_DEVICE_INFO_PARENT_DEVICE: {
898+
return ReturnValue(hDevice->ParentDevice);
899+
}
898900
case UR_DEVICE_INFO_EXTENSIONS: {
899901
cl_device_id Dev = hDevice->get();
900902
size_t ExtSize = 0;
@@ -997,9 +999,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urDevicePartition(
997999
CL_RETURN_ON_FAILURE(clCreateSubDevices(hDevice->get(), CLProperties.data(),
9981000
CLNumDevicesRet,
9991001
CLSubDevices.data(), nullptr));
1000-
1001-
std::memcpy(phSubDevices, CLSubDevices.data(),
1002-
sizeof(cl_device_id) * NumDevices);
1002+
for (uint32_t i = 0; i < NumDevices; i++) {
1003+
phSubDevices[i] = new ur_device_handle_t_(CLSubDevices[i], hDevice->Platform, hDevice);
1004+
}
10031005
}
10041006

10051007
return UR_RESULT_SUCCESS;
@@ -1031,7 +1033,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
10311033
ur_native_handle_t hNativeDevice, ur_platform_handle_t hPlatform,
10321034
const ur_device_native_properties_t *, ur_device_handle_t *phDevice) {
10331035
cl_device_id NativeHandle = reinterpret_cast<cl_device_id>(hNativeDevice);
1034-
*phDevice = new ur_device_handle_t_(NativeHandle, hPlatform);
1036+
*phDevice = new ur_device_handle_t_(NativeHandle, hPlatform, nullptr);
10351037
return UR_RESULT_SUCCESS;
10361038
}
10371039

source/adapters/opencl/device.hpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,17 @@ struct ur_device_handle_t_ {
2323
using native_type = cl_device_id;
2424
native_type Device;
2525
ur_platform_handle_t Platform;
26+
cl_device_type Type = 0;
27+
ur_device_handle_t ParentDevice = nullptr;
2628

27-
ur_device_handle_t_(native_type Dev, ur_platform_handle_t Plat)
28-
: Device(Dev), Platform(Plat) {}
29+
ur_device_handle_t_(native_type Dev, ur_platform_handle_t Plat, ur_device_handle_t Parent)
30+
: Device(Dev), Platform(Plat), ParentDevice(Parent) {
31+
if (Parent) {
32+
Type = Parent->Type;
33+
} else {
34+
clGetDeviceInfo(Device, CL_DEVICE_TYPE, sizeof(cl_device_type), &Type, nullptr);
35+
}
36+
}
2937

3038
~ur_device_handle_t_() {}
3139

source/adapters/opencl/platform.cpp

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,20 +87,46 @@ UR_APIEXPORT ur_result_t UR_APICALL
8787
urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries,
8888
ur_platform_handle_t *phPlatforms, uint32_t *pNumPlatforms) {
8989

90-
std::vector<cl_platform_id> CLPlatforms(NumEntries);
91-
cl_int Result = clGetPlatformIDs(cl_adapter::cast<cl_uint>(NumEntries),
92-
CLPlatforms.data(),
93-
cl_adapter::cast<cl_uint *>(pNumPlatforms));
90+
static std::vector<ur_platform_handle_t> URPlatforms;
91+
static std::once_flag InitFlag;
92+
static uint32_t NumPlatforms = 0;
93+
cl_int Result = CL_SUCCESS;
94+
95+
std::call_once(
96+
InitFlag,
97+
[](cl_int &Result) {
98+
Result = clGetPlatformIDs(0, nullptr, &NumPlatforms);
99+
if (Result != CL_SUCCESS) {
100+
return Result;
101+
}
102+
std::vector<cl_platform_id> CLPlatforms(NumPlatforms);
103+
Result = clGetPlatformIDs(cl_adapter::cast<cl_uint>(NumPlatforms),
104+
CLPlatforms.data(),
105+
nullptr);
106+
if (Result != CL_SUCCESS) {
107+
return Result;
108+
}
109+
URPlatforms.resize(NumPlatforms);
110+
for (uint32_t i = 0; i < NumPlatforms; i++) {
111+
URPlatforms[i] = new ur_platform_handle_t_(CLPlatforms[i]);
112+
}
113+
return Result;
114+
},
115+
Result);
116+
94117
/* Absorb the CL_PLATFORM_NOT_FOUND_KHR and just return 0 in num_platforms */
95118
if (Result == CL_PLATFORM_NOT_FOUND_KHR) {
96119
Result = CL_SUCCESS;
97120
if (pNumPlatforms) {
98121
*pNumPlatforms = 0;
99122
}
100123
}
124+
if (pNumPlatforms != nullptr) {
125+
*pNumPlatforms = NumPlatforms;
126+
}
101127
if (NumEntries && phPlatforms) {
102128
for (uint32_t i = 0; i < NumEntries; i++) {
103-
phPlatforms[i] = new ur_platform_handle_t_(CLPlatforms[i]);
129+
phPlatforms[i] = URPlatforms[i];
104130
}
105131
}
106132
return mapCLErrorToUR(Result);

source/adapters/opencl/platform.hpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ struct ur_platform_handle_t_ {
179179

180180
ur_platform_handle_t_(native_type Plat) : Platform(Plat) {
181181
ExtFuncPtr = std::make_unique<cl_adapter::ExtFuncPtrT>();
182+
InitDevices();
182183
}
183184

184185
~ur_platform_handle_t_() { ExtFuncPtr.reset(); }
@@ -199,16 +200,16 @@ struct ur_platform_handle_t_ {
199200

200201
native_type get() { return Platform; }
201202

202-
ur_result_t GetDevices(cl_device_type Type) {
203+
ur_result_t InitDevices() {
203204
cl_uint DeviceNum = 0;
204-
CL_RETURN_ON_FAILURE(clGetDeviceIDs(Platform, Type, 0, nullptr, &DeviceNum));
205+
CL_RETURN_ON_FAILURE(clGetDeviceIDs(Platform, CL_DEVICE_TYPE_ALL, 0, nullptr, &DeviceNum));
205206

206207
std::vector<cl_device_id> CLDevices(DeviceNum);
207-
CL_RETURN_ON_FAILURE(clGetDeviceIDs(Platform, Type, DeviceNum, CLDevices.data(), nullptr));
208+
CL_RETURN_ON_FAILURE(clGetDeviceIDs(Platform, CL_DEVICE_TYPE_ALL, DeviceNum, CLDevices.data(), nullptr));
208209

209210
Devices = std::vector<ur_device_handle_t>(DeviceNum);
210211
for (size_t i = 0; i < DeviceNum; i++) {
211-
Devices[i] = new ur_device_handle_t_(CLDevices[i], this);
212+
Devices[i] = new ur_device_handle_t_(CLDevices[i], this, nullptr);
212213
}
213214

214215
return UR_RESULT_SUCCESS;

0 commit comments

Comments
 (0)