77// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
88//
99// ===----------------------------------------------------------------------===//
10+ #include " kernel.hpp"
1011#include " common.hpp"
1112#include " device.hpp"
1213#include " memory.hpp"
@@ -21,19 +22,20 @@ urKernelCreate(ur_program_handle_t hProgram, const char *pKernelName,
2122 ur_kernel_handle_t *phKernel) {
2223
2324 cl_int CLResult;
24- *phKernel = cl_adapter::cast<ur_kernel_handle_t >(
25- clCreateKernel (hProgram->get (), pKernelName, &CLResult));
25+ cl_kernel Kernel = clCreateKernel (hProgram->get (), pKernelName, &CLResult);
2626 CL_RETURN_ON_FAILURE (CLResult);
27+ auto URKernel =
28+ std::make_unique<ur_kernel_handle_t_>(Kernel, hProgram, nullptr );
29+ *phKernel = URKernel.release ();
2730 return UR_RESULT_SUCCESS;
2831}
2932
// urKernelSetArgValue: forwards a by-value kernel argument to clSetKernelArg.
// The added ('+') lines pass the cl_kernel obtained from hKernel->get(), where
// the removed ('-') lines cast the UR kernel handle directly to cl_kernel.
// argIndex is converted with cl_adapter::cast<cl_uint>; argSize and pArgValue
// are passed through unchanged. The properties pointer parameter is unnamed
// and not used. Returns UR_RESULT_SUCCESS after the call.
// NOTE(review): CL_RETURN_ON_FAILURE is defined elsewhere; presumably it
// returns early with a translated error on a non-success CL status - confirm.
3033UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue (
3134 ur_kernel_handle_t hKernel, uint32_t argIndex, size_t argSize,
3235 const ur_kernel_arg_value_properties_t *, const void *pArgValue) {
3336
34- CL_RETURN_ON_FAILURE (clSetKernelArg (cl_adapter::cast<cl_kernel>(hKernel),
35- cl_adapter::cast<cl_uint>(argIndex),
36- argSize, pArgValue));
37+ CL_RETURN_ON_FAILURE (clSetKernelArg (
38+ hKernel->get (), cl_adapter::cast<cl_uint>(argIndex), argSize, pArgValue));
3739
3840 return UR_RESULT_SUCCESS;
3941}
@@ -42,9 +44,8 @@ UR_APIEXPORT ur_result_t UR_APICALL
4244urKernelSetArgLocal (ur_kernel_handle_t hKernel, uint32_t argIndex,
4345 size_t argSize, const ur_kernel_arg_local_properties_t *) {
4446
45- CL_RETURN_ON_FAILURE (clSetKernelArg (cl_adapter::cast<cl_kernel>(hKernel),
46- cl_adapter::cast<cl_uint>(argIndex),
47- argSize, nullptr ));
47+ CL_RETURN_ON_FAILURE (clSetKernelArg (
48+ hKernel->get (), cl_adapter::cast<cl_uint>(argIndex), argSize, nullptr ));
4849
4950 return UR_RESULT_SUCCESS;
5051}
@@ -76,26 +77,31 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetInfo(ur_kernel_handle_t hKernel,
7677 size_t propSize,
7778 void *pPropValue,
7879 size_t *pPropSizeRet) {
80+ UrReturnHelper ReturnValue (propSize, pPropValue, pPropSizeRet);
7981 // We need this little bit of ugliness because the UR NUM_ARGS property is
8082 // size_t whereas the CL one is cl_uint. We should consider changing that see
8183 // #1038
8284 if (propName == UR_KERNEL_INFO_NUM_ARGS) {
8385 if (pPropSizeRet)
8486 *pPropSizeRet = sizeof (size_t );
8587 cl_uint NumArgs = 0 ;
86- CL_RETURN_ON_FAILURE (clGetKernelInfo (cl_adapter::cast<cl_kernel>(hKernel ),
88+ CL_RETURN_ON_FAILURE (clGetKernelInfo (hKernel-> get ( ),
8789 mapURKernelInfoToCL (propName),
8890 sizeof (NumArgs), &NumArgs, nullptr ));
8991 if (pPropValue) {
9092 if (propSize != sizeof (size_t ))
9193 return UR_RESULT_ERROR_INVALID_SIZE;
9294 *static_cast <size_t *>(pPropValue) = static_cast <size_t >(NumArgs);
9395 }
96+ } else if (propName == UR_KERNEL_INFO_PROGRAM) {
97+ return ReturnValue (hKernel->Program );
98+ } else if (propName == UR_KERNEL_INFO_CONTEXT) {
99+ return ReturnValue (hKernel->Context );
94100 } else {
95101 size_t CheckPropSize = 0 ;
96- cl_int ClResult = clGetKernelInfo (cl_adapter::cast<cl_kernel>(hKernel),
97- mapURKernelInfoToCL (propName), propSize,
98- pPropValue, &CheckPropSize);
102+ cl_int ClResult =
103+ clGetKernelInfo (hKernel-> get (), mapURKernelInfoToCL (propName), propSize,
104+ pPropValue, &CheckPropSize);
99105 if (pPropValue && CheckPropSize != propSize) {
100106 return UR_RESULT_ERROR_INVALID_SIZE;
101107 }
@@ -147,8 +153,8 @@ urKernelGetGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
147153 }
148154 }
149155 CL_RETURN_ON_FAILURE (clGetKernelWorkGroupInfo (
150- cl_adapter::cast<cl_kernel>(hKernel ), hDevice->get (),
151- mapURKernelGroupInfoToCL (propName), propSize, pPropValue, pPropSizeRet));
156+ hKernel-> get ( ), hDevice->get (), mapURKernelGroupInfoToCL (propName ),
157+ propSize, pPropValue, pPropSizeRet));
152158
153159 return UR_RESULT_SUCCESS;
154160}
@@ -201,9 +207,8 @@ urKernelGetSubGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
201207 }
202208
203209 cl_int Ret = clGetKernelSubGroupInfo (
204- cl_adapter::cast<cl_kernel>(hKernel), hDevice->get (),
205- mapURKernelSubGroupInfoToCL (propName), InputValueSize, InputValue.get (),
206- sizeof (size_t ), &RetVal, pPropSizeRet);
210+ hKernel->get (), hDevice->get (), mapURKernelSubGroupInfoToCL (propName),
211+ InputValueSize, InputValue.get (), sizeof (size_t ), &RetVal, pPropSizeRet);
207212
208213 if (Ret == CL_INVALID_OPERATION) {
209214 // clGetKernelSubGroupInfo returns CL_INVALID_OPERATION if the device does
@@ -252,13 +257,13 @@ urKernelGetSubGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
252257}
253258
// urKernelRetain: calls clRetainKernel on the kernel and returns
// UR_RESULT_SUCCESS afterwards. The added ('+') line passes hKernel->get();
// the removed ('-') line cast the UR handle itself to cl_kernel.
// NOTE(review): only the native cl_kernel is retained in the visible code; no
// count is kept on the UR-side handle object here - confirm that is intended.
254259UR_APIEXPORT ur_result_t UR_APICALL urKernelRetain (ur_kernel_handle_t hKernel) {
255- CL_RETURN_ON_FAILURE (clRetainKernel (cl_adapter::cast<cl_kernel>(hKernel )));
260+ CL_RETURN_ON_FAILURE (clRetainKernel (hKernel-> get ( )));
256261 return UR_RESULT_SUCCESS;
257262}
258263
// urKernelRelease: calls clReleaseKernel on the kernel and returns
// UR_RESULT_SUCCESS afterwards. The added ('+') line passes hKernel->get();
// the removed ('-') line cast the UR handle itself to cl_kernel.
// NOTE(review): urKernelCreate in this same change allocates the handle object
// with std::make_unique<ur_kernel_handle_t_> and hands it out via release().
// Nothing visible in this function deletes that object, so unless it is freed
// elsewhere the wrapper may leak once the native kernel is released - confirm.
259264UR_APIEXPORT ur_result_t UR_APICALL
260265urKernelRelease (ur_kernel_handle_t hKernel) {
261- CL_RETURN_ON_FAILURE (clReleaseKernel (cl_adapter::cast<cl_kernel>(hKernel )));
266+ CL_RETURN_ON_FAILURE (clReleaseKernel (hKernel-> get ( )));
262267 return UR_RESULT_SUCCESS;
263268}
264269
@@ -276,41 +281,38 @@ static ur_result_t usmSetIndirectAccess(ur_kernel_handle_t hKernel) {
276281
277282 /* We test that each alloc type is supported before we actually try to set
278283 * KernelExecInfo. */
279- CL_RETURN_ON_FAILURE (clGetKernelInfo (cl_adapter::cast<cl_kernel>(hKernel) ,
280- CL_KERNEL_CONTEXT, sizeof (cl_context),
281- &CLContext, nullptr ));
284+ CL_RETURN_ON_FAILURE (clGetKernelInfo (hKernel-> get (), CL_KERNEL_CONTEXT ,
285+ sizeof (cl_context), &CLContext ,
286+ nullptr ));
282287
283288 UR_RETURN_ON_FAILURE (cl_ext::getExtFuncFromContext<clHostMemAllocINTEL_fn>(
284289 CLContext, cl_ext::ExtFuncPtrCache->clHostMemAllocINTELCache ,
285290 cl_ext::HostMemAllocName, &HFunc));
286291
287292 if (HFunc) {
288- CL_RETURN_ON_FAILURE (
289- clSetKernelExecInfo (cl_adapter::cast<cl_kernel>(hKernel),
290- CL_KERNEL_EXEC_INFO_INDIRECT_HOST_ACCESS_INTEL,
291- sizeof (cl_bool), &TrueVal));
293+ CL_RETURN_ON_FAILURE (clSetKernelExecInfo (
294+ hKernel->get (), CL_KERNEL_EXEC_INFO_INDIRECT_HOST_ACCESS_INTEL,
295+ sizeof (cl_bool), &TrueVal));
292296 }
293297
294298 UR_RETURN_ON_FAILURE (cl_ext::getExtFuncFromContext<clDeviceMemAllocINTEL_fn>(
295299 CLContext, cl_ext::ExtFuncPtrCache->clDeviceMemAllocINTELCache ,
296300 cl_ext::DeviceMemAllocName, &DFunc));
297301
298302 if (DFunc) {
299- CL_RETURN_ON_FAILURE (
300- clSetKernelExecInfo (cl_adapter::cast<cl_kernel>(hKernel),
301- CL_KERNEL_EXEC_INFO_INDIRECT_DEVICE_ACCESS_INTEL,
302- sizeof (cl_bool), &TrueVal));
303+ CL_RETURN_ON_FAILURE (clSetKernelExecInfo (
304+ hKernel->get (), CL_KERNEL_EXEC_INFO_INDIRECT_DEVICE_ACCESS_INTEL,
305+ sizeof (cl_bool), &TrueVal));
303306 }
304307
305308 UR_RETURN_ON_FAILURE (cl_ext::getExtFuncFromContext<clSharedMemAllocINTEL_fn>(
306309 CLContext, cl_ext::ExtFuncPtrCache->clSharedMemAllocINTELCache ,
307310 cl_ext::SharedMemAllocName, &SFunc));
308311
309312 if (SFunc) {
310- CL_RETURN_ON_FAILURE (
311- clSetKernelExecInfo (cl_adapter::cast<cl_kernel>(hKernel),
312- CL_KERNEL_EXEC_INFO_INDIRECT_SHARED_ACCESS_INTEL,
313- sizeof (cl_bool), &TrueVal));
313+ CL_RETURN_ON_FAILURE (clSetKernelExecInfo (
314+ hKernel->get (), CL_KERNEL_EXEC_INFO_INDIRECT_SHARED_ACCESS_INTEL,
315+ sizeof (cl_bool), &TrueVal));
314316 }
315317 return UR_RESULT_SUCCESS;
316318}
@@ -332,9 +334,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetExecInfo(
332334 return UR_RESULT_SUCCESS;
333335 }
334336 case UR_KERNEL_EXEC_INFO_USM_PTRS: {
335- CL_RETURN_ON_FAILURE (clSetKernelExecInfo (
336- cl_adapter::cast<cl_kernel>(hKernel) ,
337- CL_KERNEL_EXEC_INFO_USM_PTRS_INTEL, propSize, pPropValue));
337+ CL_RETURN_ON_FAILURE (clSetKernelExecInfo (hKernel-> get (),
338+ CL_KERNEL_EXEC_INFO_USM_PTRS_INTEL ,
339+ propSize, pPropValue));
338340 return UR_RESULT_SUCCESS;
339341 }
340342 default : {
@@ -348,9 +350,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgPointer(
348350 const ur_kernel_arg_pointer_properties_t *, const void *pArgValue) {
349351
350352 cl_context CLContext;
351- CL_RETURN_ON_FAILURE (clGetKernelInfo (cl_adapter::cast<cl_kernel>(hKernel) ,
352- CL_KERNEL_CONTEXT, sizeof (cl_context),
353- &CLContext, nullptr ));
353+ CL_RETURN_ON_FAILURE (clGetKernelInfo (hKernel-> get (), CL_KERNEL_CONTEXT ,
354+ sizeof (cl_context), &CLContext ,
355+ nullptr ));
354356
355357 clSetKernelArgMemPointerINTEL_fn FuncPtr = nullptr ;
356358 UR_RETURN_ON_FAILURE (
@@ -364,25 +366,30 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgPointer(
364366 * deref the arg to get the pointer value */
365367 auto PtrToPtr = reinterpret_cast <const intptr_t *>(pArgValue);
366368 auto DerefPtr = reinterpret_cast <void *>(*PtrToPtr);
367- CL_RETURN_ON_FAILURE (FuncPtr (cl_adapter::cast<cl_kernel>(hKernel),
368- cl_adapter::cast<cl_uint>(argIndex),
369- DerefPtr));
369+ CL_RETURN_ON_FAILURE (
370+ FuncPtr (hKernel->get (), cl_adapter::cast<cl_uint>(argIndex), DerefPtr));
370371 }
371372
372373 return UR_RESULT_SUCCESS;
373374}
// urKernelGetNativeHandle: writes the kernel's native handle to
// *phNativeKernel and returns UR_RESULT_SUCCESS. The added ('+') line stores
// the value returned by hKernel->get(), reinterpret_cast to
// ur_native_handle_t; the removed ('-') line stored the UR handle pointer
// itself. phNativeKernel is dereferenced without a null check in this function.
374375UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle (
375376 ur_kernel_handle_t hKernel, ur_native_handle_t *phNativeKernel) {
376377
377- *phNativeKernel = reinterpret_cast <ur_native_handle_t >(hKernel);
378+ *phNativeKernel = reinterpret_cast <ur_native_handle_t >(hKernel-> get () );
378379 return UR_RESULT_SUCCESS;
379380}
380381
381382UR_APIEXPORT ur_result_t UR_APICALL urKernelCreateWithNativeHandle (
382- ur_native_handle_t hNativeKernel, ur_context_handle_t , ur_program_handle_t ,
383+ ur_native_handle_t hNativeKernel, ur_context_handle_t hContext,
384+ ur_program_handle_t hProgram,
383385 const ur_kernel_native_properties_t *pProperties,
384386 ur_kernel_handle_t *phKernel) {
385- *phKernel = reinterpret_cast <ur_kernel_handle_t >(hNativeKernel);
387+ cl_kernel NativeHandle = reinterpret_cast <cl_kernel>(hNativeKernel);
388+ auto URKernel =
389+ std::make_unique<ur_kernel_handle_t_>(NativeHandle, hProgram, hContext);
390+ UR_RETURN_ON_FAILURE (URKernel->initWithNative ());
391+ *phKernel = URKernel.release ();
392+
386393 if (!pProperties || !pProperties->isNativeHandleOwned ) {
387394 return urKernelRetain (*phKernel);
388395 }
@@ -394,7 +401,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgMemObj(
394401 const ur_kernel_arg_mem_obj_properties_t *, ur_mem_handle_t hArgValue) {
395402
396403 cl_mem CLArgValue = hArgValue ? hArgValue->get () : nullptr ;
397- CL_RETURN_ON_FAILURE (clSetKernelArg (cl_adapter::cast<cl_kernel>(hKernel ),
404+ CL_RETURN_ON_FAILURE (clSetKernelArg (hKernel-> get ( ),
398405 cl_adapter::cast<cl_uint>(argIndex),
399406 sizeof (CLArgValue), &CLArgValue));
400407 return UR_RESULT_SUCCESS;
@@ -405,9 +412,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgSampler(
405412 const ur_kernel_arg_sampler_properties_t *, ur_sampler_handle_t hArgValue) {
406413
407414 cl_sampler CLArgSampler = hArgValue->get ();
408- cl_int RetErr = clSetKernelArg (cl_adapter::cast<cl_kernel>(hKernel),
409- cl_adapter::cast<cl_uint>(argIndex),
410- sizeof (CLArgSampler), &CLArgSampler);
415+ cl_int RetErr =
416+ clSetKernelArg (hKernel-> get (), cl_adapter::cast<cl_uint>(argIndex),
417+ sizeof (CLArgSampler), &CLArgSampler);
411418 CL_RETURN_ON_FAILURE (RetErr);
412419 return UR_RESULT_SUCCESS;
413420}
0 commit comments