@@ -41,6 +41,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
4141 *OutEvent // /< [in,out][optional] return an event object that identifies
4242 // /< this particular kernel execution instance.
4343) {
44+ auto ZeDevice = Queue->Device ->ZeDevice ;
45+
46+ ze_kernel_handle_t ZeKernel{};
47+ if (Kernel->ZeKernelMap .empty ()) {
48+ ZeKernel = Kernel->ZeKernel ;
49+ } else {
50+ auto It = Kernel->ZeKernelMap .find (ZeDevice);
51+ ZeKernel = It->second ;
52+ }
4453 // Lock automatically releases when this goes out of scope.
4554 std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock (
4655 Queue->Mutex , Kernel->Mutex , Kernel->Program ->Mutex );
@@ -51,7 +60,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
5160 }
5261
5362 ZE2UR_CALL (zeKernelSetGlobalOffsetExp,
54- (Kernel-> ZeKernel , GlobalWorkOffset[0 ], GlobalWorkOffset[1 ],
63+ (ZeKernel, GlobalWorkOffset[0 ], GlobalWorkOffset[1 ],
5564 GlobalWorkOffset[2 ]));
5665 }
5766
@@ -65,7 +74,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
6574 Queue->Device ));
6675 }
6776 ZE2UR_CALL (zeKernelSetArgumentValue,
68- (Kernel-> ZeKernel , Arg.Index , Arg.Size , ZeHandlePtr));
77+ (ZeKernel, Arg.Index , Arg.Size , ZeHandlePtr));
6978 }
7079 Kernel->PendingArguments .clear ();
7180
@@ -99,7 +108,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
99108 }
100109 if (SuggestGroupSize) {
101110 ZE2UR_CALL (zeKernelSuggestGroupSize,
102- (Kernel-> ZeKernel , GlobalWorkSize[0 ], GlobalWorkSize[1 ],
111+ (ZeKernel, GlobalWorkSize[0 ], GlobalWorkSize[1 ],
103112 GlobalWorkSize[2 ], &WG[0 ], &WG[1 ], &WG[2 ]));
104113 } else {
105114 for (int I : {0 , 1 , 2 }) {
@@ -175,7 +184,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
175184 return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
176185 }
177186
178- ZE2UR_CALL (zeKernelSetGroupSize, (Kernel-> ZeKernel , WG[0 ], WG[1 ], WG[2 ]));
187+ ZE2UR_CALL (zeKernelSetGroupSize, (ZeKernel, WG[0 ], WG[1 ], WG[2 ]));
179188
180189 bool UseCopyEngine = false ;
181190 _ur_ze_event_list_t TmpWaitList;
@@ -227,18 +236,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
227236 Queue->CaptureIndirectAccesses ();
228237 // Add the command to the command list, which implies submission.
229238 ZE2UR_CALL (zeCommandListAppendLaunchKernel,
230- (CommandList->first , Kernel->ZeKernel , &ZeThreadGroupDimensions,
231- ZeEvent, (*Event)->WaitList .Length ,
232- (*Event)->WaitList .ZeEventList ));
239+ (CommandList->first , ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
240+ (*Event)->WaitList .Length , (*Event)->WaitList .ZeEventList ));
233241 } else {
234242 // Add the command to the command list for later submission.
235243 // No lock is needed here, unlike the immediate commandlist case above,
236244 // because the kernels are not actually submitted yet. Kernels will be
237245 // submitted only when the commandlist is closed. Then, a lock is held.
238246 ZE2UR_CALL (zeCommandListAppendLaunchKernel,
239- (CommandList->first , Kernel->ZeKernel , &ZeThreadGroupDimensions,
240- ZeEvent, (*Event)->WaitList .Length ,
241- (*Event)->WaitList .ZeEventList ));
247+ (CommandList->first , ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
248+ (*Event)->WaitList .Length , (*Event)->WaitList .ZeEventList ));
242249 }
243250
244251 urPrint (" calling zeCommandListAppendLaunchKernel() with"
@@ -363,23 +370,46 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelCreate(
363370 return UR_RESULT_ERROR_INVALID_PROGRAM_EXECUTABLE;
364371 }
365372
366- ZeStruct<ze_kernel_desc_t > ZeKernelDesc;
367- ZeKernelDesc.flags = 0 ;
368- ZeKernelDesc.pKernelName = KernelName;
369-
370- ze_kernel_handle_t ZeKernel;
371- ZE2UR_CALL (zeKernelCreate, (Program->ZeModule , &ZeKernelDesc, &ZeKernel));
372-
373373 try {
374- ur_kernel_handle_t_ *UrKernel =
375- new ur_kernel_handle_t_ (ZeKernel, true , Program);
374+ ur_kernel_handle_t_ *UrKernel = new ur_kernel_handle_t_ (true , Program);
376375 *RetKernel = reinterpret_cast <ur_kernel_handle_t >(UrKernel);
377376 } catch (const std::bad_alloc &) {
378377 return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
379378 } catch (...) {
380379 return UR_RESULT_ERROR_UNKNOWN;
381380 }
382381
382+ for (auto It : Program->ZeModuleMap ) {
383+ auto ZeModule = It.second ;
384+ ZeStruct<ze_kernel_desc_t > ZeKernelDesc;
385+ ZeKernelDesc.flags = 0 ;
386+ ZeKernelDesc.pKernelName = KernelName;
387+
388+ ze_kernel_handle_t ZeKernel;
389+ ZE2UR_CALL (zeKernelCreate, (ZeModule, &ZeKernelDesc, &ZeKernel));
390+
391+ auto ZeDevice = It.first ;
392+
393+ // Store the kernel in the ZeKernelMap so the correct
394+ // kernel can be retrieved later for a specific device
395+ // where a queue is being submitted.
396+ (*RetKernel)->ZeKernelMap [ZeDevice] = ZeKernel;
397+ (*RetKernel)->ZeKernels .push_back (ZeKernel);
398+
399+ // If the device used to create the module's kernel is a root-device
400+ // then store the kernel also using the sub-devices, since application
401+ // could submit the root-device's kernel to a sub-device's queue.
402+ uint32_t SubDevicesCount = 0 ;
403+ zeDeviceGetSubDevices (ZeDevice, &SubDevicesCount, nullptr );
404+ std::vector<ze_device_handle_t > ZeSubDevices (SubDevicesCount);
405+ zeDeviceGetSubDevices (ZeDevice, &SubDevicesCount, ZeSubDevices.data ());
406+ for (auto ZeSubDevice : ZeSubDevices) {
407+ (*RetKernel)->ZeKernelMap [ZeSubDevice] = ZeKernel;
408+ }
409+ }
410+
411+ (*RetKernel)->ZeKernel = (*RetKernel)->ZeKernelMap .begin ()->second ;
412+
383413 UR_CALL ((*RetKernel)->initialize ());
384414
385415 return UR_RESULT_SUCCESS;
@@ -396,6 +426,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
396426) {
397427 std::ignore = Properties;
398428
429+ UR_ASSERT (Kernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
430+
399431 // OpenCL: "the arg_value pointer can be NULL or point to a NULL value
400432 // in which case a NULL value will be used as the value for the argument
401433 // declared as a pointer to global or constant memory in the kernel"
@@ -409,8 +441,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
409441 }
410442
411443 std::scoped_lock<ur_shared_mutex> Guard (Kernel->Mutex );
412- ZE2UR_CALL (zeKernelSetArgumentValue,
413- (Kernel->ZeKernel , ArgIndex, ArgSize, PArgValue));
444+ for (auto It : Kernel->ZeKernelMap ) {
445+ auto ZeKernel = It.second ;
446+ ZE2UR_CALL (zeKernelSetArgumentValue,
447+ (ZeKernel, ArgIndex, ArgSize, PArgValue));
448+ }
414449
415450 return UR_RESULT_SUCCESS;
416451}
@@ -596,11 +631,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelRelease(
596631
597632 auto KernelProgram = Kernel->Program ;
598633 if (Kernel->OwnNativeHandle ) {
599- auto ZeResult = ZE_CALL_NOCHECK (zeKernelDestroy, (Kernel->ZeKernel ));
600- // Gracefully handle the case that L0 was already unloaded.
601- if (ZeResult && ZeResult != ZE_RESULT_ERROR_UNINITIALIZED)
602- return ze2urResult (ZeResult);
634+ for (auto &ZeKernel : Kernel->ZeKernels ) {
635+ auto ZeResult = ZE_CALL_NOCHECK (zeKernelDestroy, (ZeKernel));
636+ // Gracefully handle the case that L0 was already unloaded.
637+ if (ZeResult && ZeResult != ZE_RESULT_ERROR_UNINITIALIZED)
638+ return ze2urResult (ZeResult);
639+ }
603640 }
641+ Kernel->ZeKernelMap .clear ();
604642 if (IndirectAccessTrackingEnabled) {
605643 UR_CALL (urContextRelease (KernelProgram->Context ));
606644 }
@@ -639,6 +677,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetExecInfo(
639677 std::ignore = PropSize;
640678 std::ignore = Properties;
641679
680+ auto ZeKernel = Kernel->ZeKernel ;
642681 std::scoped_lock<ur_shared_mutex> Guard (Kernel->Mutex );
643682 if (PropName == UR_KERNEL_EXEC_INFO_USM_INDIRECT_ACCESS &&
644683 *(static_cast <const ur_bool_t *>(PropValue)) == true ) {
@@ -649,7 +688,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetExecInfo(
649688 ZE_KERNEL_INDIRECT_ACCESS_FLAG_HOST |
650689 ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE |
651690 ZE_KERNEL_INDIRECT_ACCESS_FLAG_SHARED;
652- ZE2UR_CALL (zeKernelSetIndirectAccess, (Kernel-> ZeKernel , IndirectFlags));
691+ ZE2UR_CALL (zeKernelSetIndirectAccess, (ZeKernel, IndirectFlags));
653692 } else if (PropName == UR_KERNEL_EXEC_INFO_CACHE_CONFIG) {
654693 ze_cache_config_flag_t ZeCacheConfig{};
655694 auto CacheConfig =
@@ -663,7 +702,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetExecInfo(
663702 else
664703 // Unexpected cache configuration value.
665704 return UR_RESULT_ERROR_INVALID_VALUE;
666- ZE2UR_CALL (zeKernelSetCacheConfig, (Kernel-> ZeKernel , ZeCacheConfig););
705+ ZE2UR_CALL (zeKernelSetCacheConfig, (ZeKernel, ZeCacheConfig););
667706 } else {
668707 urPrint (" urKernelSetExecInfo: unsupported ParamName\n " );
669708 return UR_RESULT_ERROR_INVALID_VALUE;
0 commit comments