@@ -41,6 +41,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
4141 *OutEvent // /< [in,out][optional] return an event object that identifies
4242 // /< this particular kernel execution instance.
4343) {
44+ auto ZeDevice = Queue->Device ->ZeDevice ;
45+
46+ ze_kernel_handle_t ZeKernel{};
47+ if (Kernel->ZeKernelMap .empty ()) {
48+ ZeKernel = Kernel->ZeKernel ;
49+ } else {
50+ auto It = Kernel->ZeKernelMap .find (ZeDevice);
51+ ZeKernel = It->second ;
52+ }
4453 // Lock automatically releases when this goes out of scope.
4554 std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock (
4655 Queue->Mutex , Kernel->Mutex , Kernel->Program ->Mutex );
@@ -51,7 +60,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
5160 }
5261
5362 ZE2UR_CALL (zeKernelSetGlobalOffsetExp,
54- (Kernel-> ZeKernel , GlobalWorkOffset[0 ], GlobalWorkOffset[1 ],
63+ (ZeKernel, GlobalWorkOffset[0 ], GlobalWorkOffset[1 ],
5564 GlobalWorkOffset[2 ]));
5665 }
5766
@@ -65,18 +74,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
6574 Queue->Device ));
6675 }
6776 ZE2UR_CALL (zeKernelSetArgumentValue,
68- (Kernel-> ZeKernel , Arg.Index , Arg.Size , ZeHandlePtr));
77+ (ZeKernel, Arg.Index , Arg.Size , ZeHandlePtr));
6978 }
7079 Kernel->PendingArguments .clear ();
7180
7281 ze_group_count_t ZeThreadGroupDimensions{1 , 1 , 1 };
7382 uint32_t WG[3 ]{};
7483
7584 // global_work_size of unused dimensions must be set to 1
76- UR_ASSERT (WorkDim == 3 || GlobalWorkSize[2 ] == 1 ,
77- UR_RESULT_ERROR_INVALID_VALUE);
78- UR_ASSERT (WorkDim >= 2 || GlobalWorkSize[1 ] == 1 ,
79- UR_RESULT_ERROR_INVALID_VALUE);
85+ if (WorkDim >= 2 ) {
86+ UR_ASSERT (WorkDim >= 2 || GlobalWorkSize[1 ] == 1 ,
87+ UR_RESULT_ERROR_INVALID_VALUE);
88+ if (WorkDim == 3 ) {
89+ UR_ASSERT (WorkDim == 3 || GlobalWorkSize[2 ] == 1 ,
90+ UR_RESULT_ERROR_INVALID_VALUE);
91+ }
92+ }
8093 if (LocalWorkSize) {
8194 // L0
8295 UR_ASSERT (LocalWorkSize[0 ] < (std::numeric_limits<uint32_t >::max)(),
@@ -99,7 +112,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
99112 }
100113 if (SuggestGroupSize) {
101114 ZE2UR_CALL (zeKernelSuggestGroupSize,
102- (Kernel-> ZeKernel , GlobalWorkSize[0 ], GlobalWorkSize[1 ],
115+ (ZeKernel, GlobalWorkSize[0 ], GlobalWorkSize[1 ],
103116 GlobalWorkSize[2 ], &WG[0 ], &WG[1 ], &WG[2 ]));
104117 } else {
105118 for (int I : {0 , 1 , 2 }) {
@@ -175,7 +188,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
175188 return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
176189 }
177190
178- ZE2UR_CALL (zeKernelSetGroupSize, (Kernel-> ZeKernel , WG[0 ], WG[1 ], WG[2 ]));
191+ ZE2UR_CALL (zeKernelSetGroupSize, (ZeKernel, WG[0 ], WG[1 ], WG[2 ]));
179192
180193 bool UseCopyEngine = false ;
181194 _ur_ze_event_list_t TmpWaitList;
@@ -227,18 +240,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
227240 Queue->CaptureIndirectAccesses ();
228241 // Add the command to the command list, which implies submission.
229242 ZE2UR_CALL (zeCommandListAppendLaunchKernel,
230- (CommandList->first , Kernel->ZeKernel , &ZeThreadGroupDimensions,
231- ZeEvent, (*Event)->WaitList .Length ,
232- (*Event)->WaitList .ZeEventList ));
243+ (CommandList->first , ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
244+ (*Event)->WaitList .Length , (*Event)->WaitList .ZeEventList ));
233245 } else {
234246 // Add the command to the command list for later submission.
235247 // No lock is needed here, unlike the immediate commandlist case above,
236248 // because the kernels are not actually submitted yet. Kernels will be
237249 // submitted only when the commandlist is closed. Then, a lock is held.
238250 ZE2UR_CALL (zeCommandListAppendLaunchKernel,
239- (CommandList->first , Kernel->ZeKernel , &ZeThreadGroupDimensions,
240- ZeEvent, (*Event)->WaitList .Length ,
241- (*Event)->WaitList .ZeEventList ));
251+ (CommandList->first , ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
252+ (*Event)->WaitList .Length , (*Event)->WaitList .ZeEventList ));
242253 }
243254
244255 urPrint (" calling zeCommandListAppendLaunchKernel() with"
@@ -363,23 +374,46 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelCreate(
363374 return UR_RESULT_ERROR_INVALID_PROGRAM_EXECUTABLE;
364375 }
365376
366- ZeStruct<ze_kernel_desc_t > ZeKernelDesc;
367- ZeKernelDesc.flags = 0 ;
368- ZeKernelDesc.pKernelName = KernelName;
369-
370- ze_kernel_handle_t ZeKernel;
371- ZE2UR_CALL (zeKernelCreate, (Program->ZeModule , &ZeKernelDesc, &ZeKernel));
372-
373377 try {
374- ur_kernel_handle_t_ *UrKernel =
375- new ur_kernel_handle_t_ (ZeKernel, true , Program);
378+ ur_kernel_handle_t_ *UrKernel = new ur_kernel_handle_t_ (true , Program);
376379 *RetKernel = reinterpret_cast <ur_kernel_handle_t >(UrKernel);
377380 } catch (const std::bad_alloc &) {
378381 return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
379382 } catch (...) {
380383 return UR_RESULT_ERROR_UNKNOWN;
381384 }
382385
386+ for (auto It : Program->ZeModuleMap ) {
387+ auto ZeModule = It.second ;
388+ ZeStruct<ze_kernel_desc_t > ZeKernelDesc;
389+ ZeKernelDesc.flags = 0 ;
390+ ZeKernelDesc.pKernelName = KernelName;
391+
392+ ze_kernel_handle_t ZeKernel;
393+ ZE2UR_CALL (zeKernelCreate, (ZeModule, &ZeKernelDesc, &ZeKernel));
394+
395+ auto ZeDevice = It.first ;
396+
397+ // Store the kernel in the ZeKernelMap so the correct
398+ // kernel can be retrieved later for a specific device
399+ // where a queue is being submitted.
400+ (*RetKernel)->ZeKernelMap [ZeDevice] = ZeKernel;
401+ (*RetKernel)->ZeKernels .push_back (ZeKernel);
402+
403+ // If the device used to create the module's kernel is a root-device,
404+ // then also store the kernel under its sub-devices, since an application
405+ // could submit the root-device's kernel to a sub-device's queue.
406+ uint32_t SubDevicesCount = 0 ;
407+ zeDeviceGetSubDevices (ZeDevice, &SubDevicesCount, nullptr );
408+ std::vector<ze_device_handle_t > ZeSubDevices (SubDevicesCount);
409+ zeDeviceGetSubDevices (ZeDevice, &SubDevicesCount, ZeSubDevices.data ());
410+ for (auto ZeSubDevice : ZeSubDevices) {
411+ (*RetKernel)->ZeKernelMap [ZeSubDevice] = ZeKernel;
412+ }
413+ }
414+
415+ (*RetKernel)->ZeKernel = (*RetKernel)->ZeKernelMap .begin ()->second ;
416+
383417 UR_CALL ((*RetKernel)->initialize ());
384418
385419 return UR_RESULT_SUCCESS;
@@ -396,6 +430,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
396430) {
397431 std::ignore = Properties;
398432
433+ UR_ASSERT (Kernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
434+
399435 // OpenCL: "the arg_value pointer can be NULL or point to a NULL value
400436 // in which case a NULL value will be used as the value for the argument
401437 // declared as a pointer to global or constant memory in the kernel"
@@ -409,8 +445,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
409445 }
410446
411447 std::scoped_lock<ur_shared_mutex> Guard (Kernel->Mutex );
412- ZE2UR_CALL (zeKernelSetArgumentValue,
413- (Kernel->ZeKernel , ArgIndex, ArgSize, PArgValue));
448+ for (auto It : Kernel->ZeKernelMap ) {
449+ auto ZeKernel = It.second ;
450+ ZE2UR_CALL (zeKernelSetArgumentValue,
451+ (ZeKernel, ArgIndex, ArgSize, PArgValue));
452+ }
414453
415454 return UR_RESULT_SUCCESS;
416455}
@@ -596,16 +635,21 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelRelease(
596635
597636 auto KernelProgram = Kernel->Program ;
598637 if (Kernel->OwnNativeHandle ) {
599- auto ZeResult = ZE_CALL_NOCHECK (zeKernelDestroy, (Kernel->ZeKernel ));
600- // Gracefully handle the case that L0 was already unloaded.
601- if (ZeResult && ZeResult != ZE_RESULT_ERROR_UNINITIALIZED)
602- return ze2urResult (ZeResult);
638+ for (auto &ZeKernel : Kernel->ZeKernels ) {
639+ auto ZeResult = ZE_CALL_NOCHECK (zeKernelDestroy, (ZeKernel));
640+ // Gracefully handle the case that L0 was already unloaded.
641+ if (ZeResult && ZeResult != ZE_RESULT_ERROR_UNINITIALIZED)
642+ return ze2urResult (ZeResult);
643+ }
603644 }
645+ Kernel->ZeKernelMap .clear ();
604646 if (IndirectAccessTrackingEnabled) {
605647 UR_CALL (urContextRelease (KernelProgram->Context ));
606648 }
607- // do a release on the program this kernel was part of
608- UR_CALL (urProgramRelease (KernelProgram));
649+ // Release the program this kernel was part of, without deleting the
650+ // program handle.
651+ KernelProgram->ur_release_program_resources (false );
652+
609653 delete Kernel;
610654
611655 return UR_RESULT_SUCCESS;
@@ -639,6 +683,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetExecInfo(
639683 std::ignore = PropSize;
640684 std::ignore = Properties;
641685
686+ auto ZeKernel = Kernel->ZeKernel ;
642687 std::scoped_lock<ur_shared_mutex> Guard (Kernel->Mutex );
643688 if (PropName == UR_KERNEL_EXEC_INFO_USM_INDIRECT_ACCESS &&
644689 *(static_cast <const ur_bool_t *>(PropValue)) == true ) {
@@ -649,7 +694,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetExecInfo(
649694 ZE_KERNEL_INDIRECT_ACCESS_FLAG_HOST |
650695 ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE |
651696 ZE_KERNEL_INDIRECT_ACCESS_FLAG_SHARED;
652- ZE2UR_CALL (zeKernelSetIndirectAccess, (Kernel-> ZeKernel , IndirectFlags));
697+ ZE2UR_CALL (zeKernelSetIndirectAccess, (ZeKernel, IndirectFlags));
653698 } else if (PropName == UR_KERNEL_EXEC_INFO_CACHE_CONFIG) {
654699 ze_cache_config_flag_t ZeCacheConfig{};
655700 auto CacheConfig =
@@ -663,7 +708,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetExecInfo(
663708 else
664709 // Unexpected cache configuration value.
665710 return UR_RESULT_ERROR_INVALID_VALUE;
666- ZE2UR_CALL (zeKernelSetCacheConfig, (Kernel-> ZeKernel , ZeCacheConfig););
711+ ZE2UR_CALL (zeKernelSetCacheConfig, (ZeKernel, ZeCacheConfig););
667712 } else {
668713 urPrint (" urKernelSetExecInfo: unsupported ParamName\n " );
669714 return UR_RESULT_ERROR_INVALID_VALUE;
0 commit comments