@@ -70,8 +70,10 @@ ur_kernel_handle_t_::ur_kernel_handle_t_(ur_program_handle_t hProgram,
7070
7171ur_kernel_handle_t_::ur_kernel_handle_t_ (
7272 ur_native_handle_t hNativeKernel, ur_program_handle_t hProgram,
73+ ur_context_handle_t context,
7374 const ur_kernel_native_properties_t *pProperties)
74- : hProgram(hProgram), deviceKernels(1 ) {
75+ : hProgram(hProgram),
76+ deviceKernels(context ? context->getPlatform ()->getNumDevices() : 0) {
7577 ur::level_zero::urProgramRetain (hProgram);
7678
7779 auto ownZeHandle = pProperties ? pProperties->isNativeHandleOwned : false ;
@@ -82,7 +84,12 @@ ur_kernel_handle_t_::ur_kernel_handle_t_(
8284 throw UR_RESULT_ERROR_INVALID_KERNEL;
8385 }
8486
85- deviceKernels.back ().emplace (nullptr , zeKernel, ownZeHandle);
87+ for (auto &Dev : context->getDevices ()) {
88+ deviceKernels[*Dev->Id ].emplace (Dev, zeKernel, ownZeHandle);
89+
90+ // owned only by the first entry
91+ ownZeHandle = false ;
92+ }
8693 completeInitialization ();
8794}
8895
@@ -128,20 +135,6 @@ size_t ur_kernel_handle_t_::deviceIndex(ur_device_handle_t hDevice) const {
128135 hDevice = hDevice->RootDevice ;
129136 }
130137
131- // supports kernels created from native handle
132- if (deviceKernels.size () == 1 ) {
133- assert (deviceKernels[0 ].has_value ());
134- assert (deviceKernels[0 ].value ().hKernel .get ());
135-
136- auto &kernel = deviceKernels[0 ].value ();
137-
138- if (kernel.hDevice != hDevice) {
139- throw UR_RESULT_ERROR_INVALID_DEVICE;
140- }
141-
142- return 0 ;
143- }
144-
145138 if (!deviceKernels[hDevice->Id .value ()].has_value ()) {
146139 throw UR_RESULT_ERROR_INVALID_DEVICE;
147140 }
@@ -341,8 +334,12 @@ urKernelCreateWithNativeHandle(ur_native_handle_t hNativeKernel,
341334 ur_program_handle_t hProgram,
342335 const ur_kernel_native_properties_t *pProperties,
343336 ur_kernel_handle_t *phKernel) {
344- std::ignore = hContext;
345- *phKernel = new ur_kernel_handle_t_ (hNativeKernel, hProgram, pProperties);
337+ if (!hProgram) {
338+ return UR_RESULT_ERROR_INVALID_NULL_HANDLE;
339+ }
340+
341+ *phKernel =
342+ new ur_kernel_handle_t_ (hNativeKernel, hProgram, hContext, pProperties);
346343 return UR_RESULT_SUCCESS;
347344}
348345
0 commit comments