@@ -41,6 +41,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
41
41
*OutEvent // /< [in,out][optional] return an event object that identifies
42
42
// /< this particular kernel execution instance.
43
43
) {
44
+ auto ZeDevice = Queue->Device ->ZeDevice ;
45
+
46
+ ze_kernel_handle_t ZeKernel{};
47
+ if (Kernel->ZeKernelMap .empty ()) {
48
+ ZeKernel = Kernel->ZeKernel ;
49
+ } else {
50
+ auto It = Kernel->ZeKernelMap .find (ZeDevice);
51
+ ZeKernel = It->second ;
52
+ }
44
53
// Lock automatically releases when this goes out of scope.
45
54
std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock (
46
55
Queue->Mutex , Kernel->Mutex , Kernel->Program ->Mutex );
@@ -51,7 +60,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
51
60
}
52
61
53
62
ZE2UR_CALL (zeKernelSetGlobalOffsetExp,
54
- (Kernel-> ZeKernel , GlobalWorkOffset[0 ], GlobalWorkOffset[1 ],
63
+ (ZeKernel, GlobalWorkOffset[0 ], GlobalWorkOffset[1 ],
55
64
GlobalWorkOffset[2 ]));
56
65
}
57
66
@@ -65,7 +74,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
65
74
Queue->Device ));
66
75
}
67
76
ZE2UR_CALL (zeKernelSetArgumentValue,
68
- (Kernel-> ZeKernel , Arg.Index , Arg.Size , ZeHandlePtr));
77
+ (ZeKernel, Arg.Index , Arg.Size , ZeHandlePtr));
69
78
}
70
79
Kernel->PendingArguments .clear ();
71
80
@@ -99,7 +108,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
99
108
}
100
109
if (SuggestGroupSize) {
101
110
ZE2UR_CALL (zeKernelSuggestGroupSize,
102
- (Kernel-> ZeKernel , GlobalWorkSize[0 ], GlobalWorkSize[1 ],
111
+ (ZeKernel, GlobalWorkSize[0 ], GlobalWorkSize[1 ],
103
112
GlobalWorkSize[2 ], &WG[0 ], &WG[1 ], &WG[2 ]));
104
113
} else {
105
114
for (int I : {0 , 1 , 2 }) {
@@ -175,7 +184,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
175
184
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
176
185
}
177
186
178
- ZE2UR_CALL (zeKernelSetGroupSize, (Kernel-> ZeKernel , WG[0 ], WG[1 ], WG[2 ]));
187
+ ZE2UR_CALL (zeKernelSetGroupSize, (ZeKernel, WG[0 ], WG[1 ], WG[2 ]));
179
188
180
189
bool UseCopyEngine = false ;
181
190
_ur_ze_event_list_t TmpWaitList;
@@ -227,18 +236,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
227
236
Queue->CaptureIndirectAccesses ();
228
237
// Add the command to the command list, which implies submission.
229
238
ZE2UR_CALL (zeCommandListAppendLaunchKernel,
230
- (CommandList->first , Kernel->ZeKernel , &ZeThreadGroupDimensions,
231
- ZeEvent, (*Event)->WaitList .Length ,
232
- (*Event)->WaitList .ZeEventList ));
239
+ (CommandList->first , ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
240
+ (*Event)->WaitList .Length , (*Event)->WaitList .ZeEventList ));
233
241
} else {
234
242
// Add the command to the command list for later submission.
235
243
// No lock is needed here, unlike the immediate commandlist case above,
236
244
// because the kernels are not actually submitted yet. Kernels will be
237
245
// submitted only when the comamndlist is closed. Then, a lock is held.
238
246
ZE2UR_CALL (zeCommandListAppendLaunchKernel,
239
- (CommandList->first , Kernel->ZeKernel , &ZeThreadGroupDimensions,
240
- ZeEvent, (*Event)->WaitList .Length ,
241
- (*Event)->WaitList .ZeEventList ));
247
+ (CommandList->first , ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
248
+ (*Event)->WaitList .Length , (*Event)->WaitList .ZeEventList ));
242
249
}
243
250
244
251
urPrint (" calling zeCommandListAppendLaunchKernel() with"
@@ -363,23 +370,46 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelCreate(
363
370
return UR_RESULT_ERROR_INVALID_PROGRAM_EXECUTABLE;
364
371
}
365
372
366
- ZeStruct<ze_kernel_desc_t > ZeKernelDesc;
367
- ZeKernelDesc.flags = 0 ;
368
- ZeKernelDesc.pKernelName = KernelName;
369
-
370
- ze_kernel_handle_t ZeKernel;
371
- ZE2UR_CALL (zeKernelCreate, (Program->ZeModule , &ZeKernelDesc, &ZeKernel));
372
-
373
373
try {
374
- ur_kernel_handle_t_ *UrKernel =
375
- new ur_kernel_handle_t_ (ZeKernel, true , Program);
374
+ ur_kernel_handle_t_ *UrKernel = new ur_kernel_handle_t_ (true , Program);
376
375
*RetKernel = reinterpret_cast <ur_kernel_handle_t >(UrKernel);
377
376
} catch (const std::bad_alloc &) {
378
377
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
379
378
} catch (...) {
380
379
return UR_RESULT_ERROR_UNKNOWN;
381
380
}
382
381
382
+ for (auto It : Program->ZeModuleMap ) {
383
+ auto ZeModule = It.second ;
384
+ ZeStruct<ze_kernel_desc_t > ZeKernelDesc;
385
+ ZeKernelDesc.flags = 0 ;
386
+ ZeKernelDesc.pKernelName = KernelName;
387
+
388
+ ze_kernel_handle_t ZeKernel;
389
+ ZE2UR_CALL (zeKernelCreate, (ZeModule, &ZeKernelDesc, &ZeKernel));
390
+
391
+ auto ZeDevice = It.first ;
392
+
393
+ // Store the kernel in the ZeKernelMap so the correct
394
+ // kernel can be retrieved later for a specific device
395
+ // where a queue is being submitted.
396
+ (*RetKernel)->ZeKernelMap [ZeDevice] = ZeKernel;
397
+ (*RetKernel)->ZeKernels .push_back (ZeKernel);
398
+
399
+ // If the device used to create the module's kernel is a root-device
400
+ // then store the kernel also using the sub-devices, since application
401
+ // could submit the root-device's kernel to a sub-device's queue.
402
+ uint32_t SubDevicesCount = 0 ;
403
+ zeDeviceGetSubDevices (ZeDevice, &SubDevicesCount, nullptr );
404
+ std::vector<ze_device_handle_t > ZeSubDevices (SubDevicesCount);
405
+ zeDeviceGetSubDevices (ZeDevice, &SubDevicesCount, ZeSubDevices.data ());
406
+ for (auto ZeSubDevice : ZeSubDevices) {
407
+ (*RetKernel)->ZeKernelMap [ZeSubDevice] = ZeKernel;
408
+ }
409
+ }
410
+
411
+ (*RetKernel)->ZeKernel = (*RetKernel)->ZeKernelMap .begin ()->second ;
412
+
383
413
UR_CALL ((*RetKernel)->initialize ());
384
414
385
415
return UR_RESULT_SUCCESS;
@@ -396,6 +426,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
396
426
) {
397
427
std::ignore = Properties;
398
428
429
+ UR_ASSERT (Kernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
430
+
399
431
// OpenCL: "the arg_value pointer can be NULL or point to a NULL value
400
432
// in which case a NULL value will be used as the value for the argument
401
433
// declared as a pointer to global or constant memory in the kernel"
@@ -409,8 +441,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
409
441
}
410
442
411
443
std::scoped_lock<ur_shared_mutex> Guard (Kernel->Mutex );
412
- ZE2UR_CALL (zeKernelSetArgumentValue,
413
- (Kernel->ZeKernel , ArgIndex, ArgSize, PArgValue));
444
+ for (auto It : Kernel->ZeKernelMap ) {
445
+ auto ZeKernel = It.second ;
446
+ ZE2UR_CALL (zeKernelSetArgumentValue,
447
+ (ZeKernel, ArgIndex, ArgSize, PArgValue));
448
+ }
414
449
415
450
return UR_RESULT_SUCCESS;
416
451
}
@@ -596,11 +631,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelRelease(
596
631
597
632
auto KernelProgram = Kernel->Program ;
598
633
if (Kernel->OwnNativeHandle ) {
599
- auto ZeResult = ZE_CALL_NOCHECK (zeKernelDestroy, (Kernel->ZeKernel ));
600
- // Gracefully handle the case that L0 was already unloaded.
601
- if (ZeResult && ZeResult != ZE_RESULT_ERROR_UNINITIALIZED)
602
- return ze2urResult (ZeResult);
634
+ for (auto &ZeKernel : Kernel->ZeKernels ) {
635
+ auto ZeResult = ZE_CALL_NOCHECK (zeKernelDestroy, (ZeKernel));
636
+ // Gracefully handle the case that L0 was already unloaded.
637
+ if (ZeResult && ZeResult != ZE_RESULT_ERROR_UNINITIALIZED)
638
+ return ze2urResult (ZeResult);
639
+ }
603
640
}
641
+ Kernel->ZeKernelMap .clear ();
604
642
if (IndirectAccessTrackingEnabled) {
605
643
UR_CALL (urContextRelease (KernelProgram->Context ));
606
644
}
@@ -639,6 +677,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetExecInfo(
639
677
std::ignore = PropSize;
640
678
std::ignore = Properties;
641
679
680
+ auto ZeKernel = Kernel->ZeKernel ;
642
681
std::scoped_lock<ur_shared_mutex> Guard (Kernel->Mutex );
643
682
if (PropName == UR_KERNEL_EXEC_INFO_USM_INDIRECT_ACCESS &&
644
683
*(static_cast <const ur_bool_t *>(PropValue)) == true ) {
@@ -649,7 +688,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetExecInfo(
649
688
ZE_KERNEL_INDIRECT_ACCESS_FLAG_HOST |
650
689
ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE |
651
690
ZE_KERNEL_INDIRECT_ACCESS_FLAG_SHARED;
652
- ZE2UR_CALL (zeKernelSetIndirectAccess, (Kernel-> ZeKernel , IndirectFlags));
691
+ ZE2UR_CALL (zeKernelSetIndirectAccess, (ZeKernel, IndirectFlags));
653
692
} else if (PropName == UR_KERNEL_EXEC_INFO_CACHE_CONFIG) {
654
693
ze_cache_config_flag_t ZeCacheConfig{};
655
694
auto CacheConfig =
@@ -663,7 +702,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetExecInfo(
663
702
else
664
703
// Unexpected cache configuration value.
665
704
return UR_RESULT_ERROR_INVALID_VALUE;
666
- ZE2UR_CALL (zeKernelSetCacheConfig, (Kernel-> ZeKernel , ZeCacheConfig););
705
+ ZE2UR_CALL (zeKernelSetCacheConfig, (ZeKernel, ZeCacheConfig););
667
706
} else {
668
707
urPrint (" urKernelSetExecInfo: unsupported ParamName\n " );
669
708
return UR_RESULT_ERROR_INVALID_VALUE;
0 commit comments