1212#include " context.hpp"
1313#include " device.hpp"
1414#include " platform.hpp"
15+ #include " program.hpp"
1516
1617static ur_result_t getDevicesFromProgram (
1718 ur_program_handle_t hProgram,
1819 std::unique_ptr<std::vector<cl_device_id>> &DevicesInProgram) {
1920
20- cl_uint DeviceCount;
21- CL_RETURN_ON_FAILURE (clGetProgramInfo (cl_adapter::cast<cl_program>(hProgram),
22- CL_PROGRAM_NUM_DEVICES, sizeof (cl_uint),
23- &DeviceCount, nullptr ));
24-
25- if (DeviceCount < 1 ) {
26- return UR_RESULT_ERROR_INVALID_CONTEXT;
21+ if (!hProgram->Context || !hProgram->Context ->DeviceCount ) {
22+ return UR_RESULT_ERROR_INVALID_PROGRAM;
2723 }
28-
24+ cl_uint DeviceCount = hProgram-> Context -> DeviceCount ;
2925 DevicesInProgram = std::make_unique<std::vector<cl_device_id>>(DeviceCount);
30-
31- CL_RETURN_ON_FAILURE (clGetProgramInfo (
32- cl_adapter::cast<cl_program>(hProgram), CL_PROGRAM_DEVICES,
33- DeviceCount * sizeof (cl_device_id), (*DevicesInProgram).data (), nullptr ));
34-
26+ for (uint32_t i = 0 ; i < DeviceCount; i++) {
27+ (*DevicesInProgram)[i] = hProgram->Context ->Devices [i]->get ();
28+ }
3529 return UR_RESULT_SUCCESS;
3630}
3731
3832UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithIL (
3933 ur_context_handle_t hContext, const void *pIL, size_t length,
4034 const ur_program_properties_t *, ur_program_handle_t *phProgram) {
4135
42- std::unique_ptr<std::vector<cl_device_id>> DevicesInCtx;
43- CL_RETURN_ON_FAILURE_AND_SET_NULL (
44- cl_adapter::getDevicesFromContext (hContext, DevicesInCtx), phProgram);
45-
46- cl_platform_id CurPlatform;
47- CL_RETURN_ON_FAILURE_AND_SET_NULL (
48- clGetDeviceInfo ((*DevicesInCtx)[0 ], CL_DEVICE_PLATFORM,
49- sizeof (cl_platform_id), &CurPlatform, nullptr ),
50- phProgram);
36+ if (!hContext->DeviceCount || !hContext->Devices [0 ]->Platform ) {
37+ return UR_RESULT_ERROR_INVALID_CONTEXT;
38+ }
39+ cl_platform_id CurPlatform = hContext->Devices [0 ]->Platform ->get ();
5140
5241 oclv::OpenCLVersion PlatVer;
5342 CL_RETURN_ON_FAILURE_AND_SET_NULL (
@@ -57,7 +46,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithIL(
5746 if (PlatVer >= oclv::V2_1) {
5847
5948 /* Make sure all devices support CL 2.1 or newer as well. */
60- for (cl_device_id Dev : *DevicesInCtx) {
49+ for (ur_device_handle_t URDev : hContext->Devices ) {
50+ cl_device_id Dev = URDev->get ();
6151 oclv::OpenCLVersion DevVer;
6252
6353 CL_RETURN_ON_FAILURE_AND_SET_NULL (
@@ -79,15 +69,17 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithIL(
7969 }
8070 }
8171
82- *phProgram = cl_adapter::cast<ur_program_handle_t >(clCreateProgramWithIL (
83- hContext->get (), pIL, length, &Err));
72+ cl_program Program = clCreateProgramWithIL (hContext->get (), pIL, length, &Err);
8473 CL_RETURN_ON_FAILURE (Err);
74+
75+ *phProgram = new ur_program_handle_t_ (Program, hContext);
8576 } else {
8677
8778 /* If none of the devices conform with CL 2.1 or newer make sure they all
8879 * support the cl_khr_il_program extension.
8980 */
90- for (cl_device_id Dev : *DevicesInCtx) {
81+ for (ur_device_handle_t URDev : hContext->Devices ) {
82+ cl_device_id Dev = URDev->get ();
9183 bool Supported = false ;
9284 CL_RETURN_ON_FAILURE_AND_SET_NULL (
9385 cl_adapter::checkDeviceExtensions (Dev, {" cl_khr_il_program" },
@@ -106,9 +98,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithIL(
10698 CurPlatform, " clCreateProgramWithILKHR" ));
10799
108100 assert (FuncPtr != nullptr );
101+ cl_program Program = FuncPtr (hContext->get (), pIL, length, &Err);
102+ *phProgram = new ur_program_handle_t_ (Program, hContext);
109103
110- *phProgram = cl_adapter::cast<ur_program_handle_t >(
111- FuncPtr (hContext->get (), pIL, length, &Err));
112104 CL_RETURN_ON_FAILURE (Err);
113105 }
114106
@@ -124,9 +116,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary(
124116 const size_t Lengths[1 ] = {size};
125117 cl_int BinaryStatus[1 ];
126118 cl_int CLResult;
127- *phProgram = cl_adapter::cast< ur_program_handle_t >( clCreateProgramWithBinary (
119+ cl_program Program = clCreateProgramWithBinary (
128120 hContext->get (), cl_adapter::cast<cl_uint>(1u ),
129- Devices, Lengths, &pBinary, BinaryStatus, &CLResult));
121+ Devices, Lengths, &pBinary, BinaryStatus, &CLResult);
122+ *phProgram = new ur_program_handle_t_ (Program, hContext);
130123 CL_RETURN_ON_FAILURE (BinaryStatus[0 ]);
131124 CL_RETURN_ON_FAILURE (CLResult);
132125
@@ -140,7 +133,7 @@ urProgramCompile([[maybe_unused]] ur_context_handle_t hContext,
140133 std::unique_ptr<std::vector<cl_device_id>> DevicesInProgram;
141134 CL_RETURN_ON_FAILURE (getDevicesFromProgram (hProgram, DevicesInProgram));
142135
143- CL_RETURN_ON_FAILURE (clCompileProgram (cl_adapter::cast<cl_program>(hProgram ),
136+ CL_RETURN_ON_FAILURE (clCompileProgram (hProgram-> get ( ),
144137 DevicesInProgram->size (),
145138 DevicesInProgram->data (), pOptions, 0 ,
146139 nullptr , nullptr , nullptr , nullptr ));
@@ -178,7 +171,7 @@ UR_APIEXPORT ur_result_t UR_APICALL
178171urProgramGetInfo (ur_program_handle_t hProgram, ur_program_info_t propName,
179172 size_t propSize, void *pPropValue, size_t *pPropSizeRet) {
180173 size_t CheckPropSize = 0 ;
181- auto ClResult = clGetProgramInfo (cl_adapter::cast<cl_program>(hProgram ),
174+ auto ClResult = clGetProgramInfo (hProgram-> get ( ),
182175 mapURProgramInfoToCL (propName), propSize,
183176 pPropValue, &CheckPropSize);
184177 if (pPropValue && CheckPropSize != propSize) {
@@ -199,7 +192,7 @@ urProgramBuild([[maybe_unused]] ur_context_handle_t hContext,
199192 CL_RETURN_ON_FAILURE (getDevicesFromProgram (hProgram, DevicesInProgram));
200193
201194 CL_RETURN_ON_FAILURE (clBuildProgram (
202- cl_adapter::cast<cl_program>(hProgram ), DevicesInProgram->size (),
195+ hProgram-> get ( ), DevicesInProgram->size (),
203196 DevicesInProgram->data (), pOptions, nullptr , nullptr ));
204197 return UR_RESULT_SUCCESS;
205198}
@@ -210,11 +203,16 @@ urProgramLink(ur_context_handle_t hContext, uint32_t count,
210203 ur_program_handle_t *phProgram) {
211204
212205 cl_int CLResult;
213- *phProgram = cl_adapter::cast<ur_program_handle_t >(
206+ std::vector<cl_program> CLPrograms (count);
207+ for (uint32_t i = 0 ; i < count; i++) {
208+ CLPrograms[i] = phPrograms[i]->get ();
209+ }
210+ cl_program Program =
214211 clLinkProgram (hContext->get (), 0 , nullptr ,
215212 pOptions, cl_adapter::cast<cl_uint>(count),
216- cl_adapter::cast<const cl_program *>(phPrograms), nullptr ,
217- nullptr , &CLResult));
213+ CLPrograms.data (), nullptr ,
214+ nullptr , &CLResult);
215+ *phProgram = new ur_program_handle_t_ (Program, hContext);
218216 CL_RETURN_ON_FAILURE (CLResult);
219217
220218 return UR_RESULT_SUCCESS;
@@ -280,14 +278,14 @@ urProgramGetBuildInfo(ur_program_handle_t hProgram, ur_device_handle_t hDevice,
280278 UrReturnHelper ReturnValue (propSize, pPropValue, pPropSizeRet);
281279 cl_program_binary_type BinaryType;
282280 CL_RETURN_ON_FAILURE (clGetProgramBuildInfo (
283- cl_adapter::cast<cl_program>(hProgram ), hDevice->get (),
281+ hProgram-> get ( ), hDevice->get (),
284282 mapURProgramBuildInfoToCL (propName), sizeof (cl_program_binary_type),
285283 &BinaryType, nullptr ));
286284 return ReturnValue (mapCLBinaryTypeToUR (BinaryType));
287285 }
288286 size_t CheckPropSize = 0 ;
289287 cl_int ClErr =
290- clGetProgramBuildInfo (cl_adapter::cast<cl_program>(hProgram ),
288+ clGetProgramBuildInfo (hProgram-> get ( ),
291289 hDevice->get (), mapURProgramBuildInfoToCL (propName),
292290 propSize, pPropValue, &CheckPropSize);
293291 if (pPropValue && CheckPropSize != propSize) {
@@ -304,30 +302,32 @@ urProgramGetBuildInfo(ur_program_handle_t hProgram, ur_device_handle_t hDevice,
304302UR_APIEXPORT ur_result_t UR_APICALL
305303urProgramRetain (ur_program_handle_t hProgram) {
306304
307- CL_RETURN_ON_FAILURE (clRetainProgram (cl_adapter::cast<cl_program>(hProgram )));
305+ CL_RETURN_ON_FAILURE (clRetainProgram (hProgram-> get ( )));
308306 return UR_RESULT_SUCCESS;
309307}
310308
311309UR_APIEXPORT ur_result_t UR_APICALL
312310urProgramRelease (ur_program_handle_t hProgram) {
313311
314312 CL_RETURN_ON_FAILURE (
315- clReleaseProgram (cl_adapter::cast<cl_program>(hProgram )));
313+ clReleaseProgram (hProgram-> get ( )));
316314 return UR_RESULT_SUCCESS;
317315}
318316
319317UR_APIEXPORT ur_result_t UR_APICALL urProgramGetNativeHandle (
320318 ur_program_handle_t hProgram, ur_native_handle_t *phNativeProgram) {
321319
322- *phNativeProgram = reinterpret_cast <ur_native_handle_t >(hProgram);
320+ *phNativeProgram = reinterpret_cast <ur_native_handle_t >(hProgram-> get () );
323321 return UR_RESULT_SUCCESS;
324322}
325323
326324UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithNativeHandle (
327- ur_native_handle_t hNativeProgram, ur_context_handle_t ,
325+ ur_native_handle_t hNativeProgram, ur_context_handle_t hContext ,
328326 const ur_program_native_properties_t *pProperties,
329327 ur_program_handle_t *phProgram) {
330- *phProgram = reinterpret_cast <ur_program_handle_t >(hNativeProgram);
328+ cl_program NativeHandle =
329+ reinterpret_cast <cl_program>(hNativeProgram);
330+ *phProgram = new ur_program_handle_t_ (NativeHandle, hContext);
331331 if (!pProperties || !pProperties->isNativeHandleOwned ) {
332332 return urProgramRetain (*phProgram);
333333 }
@@ -338,20 +338,19 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramSetSpecializationConstants(
338338 ur_program_handle_t hProgram, uint32_t count,
339339 const ur_specialization_constant_info_t *pSpecConstants) {
340340
341- cl_program CLProg = cl_adapter::cast<cl_program>(hProgram);
342- cl_context Ctx = nullptr ;
343- size_t RetSize = 0 ;
344-
345- CL_RETURN_ON_FAILURE (clGetProgramInfo (CLProg, CL_PROGRAM_CONTEXT, sizeof (Ctx),
346- &Ctx, &RetSize));
341+ cl_program CLProg = hProgram->get ();
342+ if (!hProgram->Context ) {
343+ return UR_RESULT_ERROR_INVALID_PROGRAM;
344+ }
345+ ur_context_handle_t Ctx = hProgram->Context ;
346+ if (!Ctx->DeviceCount || !Ctx->Devices [0 ]->Platform ) {
347+ return UR_RESULT_ERROR_INVALID_CONTEXT;
348+ }
347349
348350 std::unique_ptr<std::vector<cl_device_id>> DevicesInCtx;
349- cl_adapter::getDevicesFromContext (cl_adapter::cast<ur_context_handle_t >(Ctx),
350- DevicesInCtx);
351+ cl_adapter::getDevicesFromContext (Ctx, DevicesInCtx);
351352
352- cl_platform_id CurPlatform;
353- clGetDeviceInfo ((*DevicesInCtx)[0 ], CL_DEVICE_PLATFORM,
354- sizeof (cl_platform_id), &CurPlatform, nullptr );
353+ cl_platform_id CurPlatform = Ctx->Devices [0 ]->Platform ->get ();
355354
356355 oclv::OpenCLVersion PlatVer;
357356 cl_adapter::getPlatformVersion (CurPlatform, PlatVer);
@@ -383,7 +382,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramSetSpecializationConstants(
383382 SetProgramSpecializationConstant = nullptr ;
384383 const ur_result_t URResult = cl_ext::getExtFuncFromContext<
385384 decltype (SetProgramSpecializationConstant)>(
386- Ctx, cl_ext::ExtFuncPtrCache->clSetProgramSpecializationConstantCache ,
385+ Ctx-> get () , cl_ext::ExtFuncPtrCache->clSetProgramSpecializationConstantCache ,
387386 cl_ext::SetProgramSpecializationConstantName,
388387 &SetProgramSpecializationConstant);
389388
@@ -430,10 +429,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetFunctionPointer(
430429 ur_device_handle_t hDevice, ur_program_handle_t hProgram,
431430 const char *pFunctionName, void **ppFunctionPointer) {
432431
433- cl_context CLContext = nullptr ;
434- CL_RETURN_ON_FAILURE (clGetProgramInfo (cl_adapter::cast<cl_program>(hProgram),
435- CL_PROGRAM_CONTEXT, sizeof (CLContext),
436- &CLContext, nullptr ));
432+ cl_context CLContext = hProgram->Context ->get ();
437433
438434 cl_ext::clGetDeviceFunctionPointer_fn FuncT = nullptr ;
439435
@@ -453,14 +449,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetFunctionPointer(
453449 // throws exceptions.
454450 *ppFunctionPointer = 0 ;
455451 size_t Size;
456- CL_RETURN_ON_FAILURE (clGetProgramInfo (cl_adapter::cast<cl_program>(hProgram ),
452+ CL_RETURN_ON_FAILURE (clGetProgramInfo (hProgram-> get ( ),
457453 CL_PROGRAM_KERNEL_NAMES, 0 , nullptr ,
458454 &Size));
459455
460456 std::string KernelNames (Size, ' ' );
461457
462458 CL_RETURN_ON_FAILURE (clGetProgramInfo (
463- cl_adapter::cast<cl_program>(hProgram ), CL_PROGRAM_KERNEL_NAMES,
459+ hProgram-> get ( ), CL_PROGRAM_KERNEL_NAMES,
464460 KernelNames.size (), &KernelNames[0 ], nullptr ));
465461
466462 // Get rid of the null terminator and search for the kernel name. If the
@@ -471,7 +467,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetFunctionPointer(
471467 }
472468
473469 const cl_int CLResult =
474- FuncT (hDevice->get (), cl_adapter::cast<cl_program>(hProgram ),
470+ FuncT (hDevice->get (), hProgram-> get ( ),
475471 pFunctionName, reinterpret_cast <cl_ulong *>(ppFunctionPointer));
476472 // GPU runtime sometimes returns CL_INVALID_ARG_VALUE if the function address
477473 // cannot be found but the kernel exists. As the kernel does exist, return
0 commit comments