@@ -41,6 +41,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
41
41
*OutEvent // /< [in,out][optional] return an event object that identifies
42
42
///< this particular kernel execution instance.
43
43
) {
44
+ auto ZeDevice = Queue->Device ->ZeDevice ;
45
+
46
+ ze_kernel_handle_t ZeKernel{};
47
+ if (Kernel->ZeKernelMap .empty ()) {
48
+ ZeKernel = Kernel->ZeKernel ;
49
+ } else {
50
+ auto It = Kernel->ZeKernelMap .find (ZeDevice);
51
+ ZeKernel = It->second ;
52
+ }
44
53
// Lock automatically releases when this goes out of scope.
45
54
std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock (
46
55
Queue->Mutex , Kernel->Mutex , Kernel->Program ->Mutex );
@@ -51,7 +60,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
51
60
}
52
61
53
62
ZE2UR_CALL (zeKernelSetGlobalOffsetExp,
54
- (Kernel-> ZeKernel , GlobalWorkOffset[0 ], GlobalWorkOffset[1 ],
63
+ (ZeKernel, GlobalWorkOffset[0 ], GlobalWorkOffset[1 ],
55
64
GlobalWorkOffset[2 ]));
56
65
}
57
66
@@ -65,18 +74,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
65
74
Queue->Device ));
66
75
}
67
76
ZE2UR_CALL (zeKernelSetArgumentValue,
68
- (Kernel-> ZeKernel , Arg.Index , Arg.Size , ZeHandlePtr));
77
+ (ZeKernel, Arg.Index , Arg.Size , ZeHandlePtr));
69
78
}
70
79
Kernel->PendingArguments .clear ();
71
80
72
81
ze_group_count_t ZeThreadGroupDimensions{1 , 1 , 1 };
73
82
uint32_t WG[3 ]{};
74
83
75
84
// global_work_size of unused dimensions must be set to 1
76
- UR_ASSERT (WorkDim == 3 || GlobalWorkSize[2 ] == 1 ,
77
- UR_RESULT_ERROR_INVALID_VALUE);
78
- UR_ASSERT (WorkDim >= 2 || GlobalWorkSize[1 ] == 1 ,
79
- UR_RESULT_ERROR_INVALID_VALUE);
85
+ if (WorkDim >= 2 ) {
86
+ UR_ASSERT (WorkDim >= 2 || GlobalWorkSize[1 ] == 1 ,
87
+ UR_RESULT_ERROR_INVALID_VALUE);
88
+ if (WorkDim == 3 ) {
89
+ UR_ASSERT (WorkDim == 3 || GlobalWorkSize[2 ] == 1 ,
90
+ UR_RESULT_ERROR_INVALID_VALUE);
91
+ }
92
+ }
80
93
if (LocalWorkSize) {
81
94
// L0
82
95
UR_ASSERT (LocalWorkSize[0 ] < (std::numeric_limits<uint32_t >::max)(),
@@ -99,7 +112,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
99
112
}
100
113
if (SuggestGroupSize) {
101
114
ZE2UR_CALL (zeKernelSuggestGroupSize,
102
- (Kernel-> ZeKernel , GlobalWorkSize[0 ], GlobalWorkSize[1 ],
115
+ (ZeKernel, GlobalWorkSize[0 ], GlobalWorkSize[1 ],
103
116
GlobalWorkSize[2 ], &WG[0 ], &WG[1 ], &WG[2 ]));
104
117
} else {
105
118
for (int I : {0 , 1 , 2 }) {
@@ -175,7 +188,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
175
188
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
176
189
}
177
190
178
- ZE2UR_CALL (zeKernelSetGroupSize, (Kernel-> ZeKernel , WG[0 ], WG[1 ], WG[2 ]));
191
+ ZE2UR_CALL (zeKernelSetGroupSize, (ZeKernel, WG[0 ], WG[1 ], WG[2 ]));
179
192
180
193
bool UseCopyEngine = false ;
181
194
_ur_ze_event_list_t TmpWaitList;
@@ -227,18 +240,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
227
240
Queue->CaptureIndirectAccesses ();
228
241
// Add the command to the command list, which implies submission.
229
242
ZE2UR_CALL (zeCommandListAppendLaunchKernel,
230
- (CommandList->first , Kernel->ZeKernel , &ZeThreadGroupDimensions,
231
- ZeEvent, (*Event)->WaitList .Length ,
232
- (*Event)->WaitList .ZeEventList ));
243
+ (CommandList->first , ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
244
+ (*Event)->WaitList .Length , (*Event)->WaitList .ZeEventList ));
233
245
} else {
234
246
// Add the command to the command list for later submission.
235
247
// No lock is needed here, unlike the immediate commandlist case above,
236
248
// because the kernels are not actually submitted yet. Kernels will be
237
249
// submitted only when the command list is closed. Then, a lock is held.
238
250
ZE2UR_CALL (zeCommandListAppendLaunchKernel,
239
- (CommandList->first , Kernel->ZeKernel , &ZeThreadGroupDimensions,
240
- ZeEvent, (*Event)->WaitList .Length ,
241
- (*Event)->WaitList .ZeEventList ));
251
+ (CommandList->first , ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
252
+ (*Event)->WaitList .Length , (*Event)->WaitList .ZeEventList ));
242
253
}
243
254
244
255
urPrint (" calling zeCommandListAppendLaunchKernel() with"
@@ -363,23 +374,46 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelCreate(
363
374
return UR_RESULT_ERROR_INVALID_PROGRAM_EXECUTABLE;
364
375
}
365
376
366
- ZeStruct<ze_kernel_desc_t > ZeKernelDesc;
367
- ZeKernelDesc.flags = 0 ;
368
- ZeKernelDesc.pKernelName = KernelName;
369
-
370
- ze_kernel_handle_t ZeKernel;
371
- ZE2UR_CALL (zeKernelCreate, (Program->ZeModule , &ZeKernelDesc, &ZeKernel));
372
-
373
377
try {
374
- ur_kernel_handle_t_ *UrKernel =
375
- new ur_kernel_handle_t_ (ZeKernel, true , Program);
378
+ ur_kernel_handle_t_ *UrKernel = new ur_kernel_handle_t_ (true , Program);
376
379
*RetKernel = reinterpret_cast <ur_kernel_handle_t >(UrKernel);
377
380
} catch (const std::bad_alloc &) {
378
381
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
379
382
} catch (...) {
380
383
return UR_RESULT_ERROR_UNKNOWN;
381
384
}
382
385
386
+ for (auto It : Program->ZeModuleMap ) {
387
+ auto ZeModule = It.second ;
388
+ ZeStruct<ze_kernel_desc_t > ZeKernelDesc;
389
+ ZeKernelDesc.flags = 0 ;
390
+ ZeKernelDesc.pKernelName = KernelName;
391
+
392
+ ze_kernel_handle_t ZeKernel;
393
+ ZE2UR_CALL (zeKernelCreate, (ZeModule, &ZeKernelDesc, &ZeKernel));
394
+
395
+ auto ZeDevice = It.first ;
396
+
397
+ // Store the kernel in the ZeKernelMap so the correct
398
+ // kernel can be retrieved later for a specific device
399
+ // where a queue is being submitted.
400
+ (*RetKernel)->ZeKernelMap [ZeDevice] = ZeKernel;
401
+ (*RetKernel)->ZeKernels .push_back (ZeKernel);
402
+
403
+ // If the device used to create the module's kernel is a root-device
404
+ // then store the kernel also using the sub-devices, since application
405
+ // could submit the root-device's kernel to a sub-device's queue.
406
+ uint32_t SubDevicesCount = 0 ;
407
+ zeDeviceGetSubDevices (ZeDevice, &SubDevicesCount, nullptr );
408
+ std::vector<ze_device_handle_t > ZeSubDevices (SubDevicesCount);
409
+ zeDeviceGetSubDevices (ZeDevice, &SubDevicesCount, ZeSubDevices.data ());
410
+ for (auto ZeSubDevice : ZeSubDevices) {
411
+ (*RetKernel)->ZeKernelMap [ZeSubDevice] = ZeKernel;
412
+ }
413
+ }
414
+
415
+ (*RetKernel)->ZeKernel = (*RetKernel)->ZeKernelMap .begin ()->second ;
416
+
383
417
UR_CALL ((*RetKernel)->initialize ());
384
418
385
419
return UR_RESULT_SUCCESS;
@@ -396,6 +430,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
396
430
) {
397
431
std::ignore = Properties;
398
432
433
+ UR_ASSERT (Kernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
434
+
399
435
// OpenCL: "the arg_value pointer can be NULL or point to a NULL value
400
436
// in which case a NULL value will be used as the value for the argument
401
437
// declared as a pointer to global or constant memory in the kernel"
@@ -409,8 +445,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
409
445
}
410
446
411
447
std::scoped_lock<ur_shared_mutex> Guard (Kernel->Mutex );
412
- ZE2UR_CALL (zeKernelSetArgumentValue,
413
- (Kernel->ZeKernel , ArgIndex, ArgSize, PArgValue));
448
+ for (auto It : Kernel->ZeKernelMap ) {
449
+ auto ZeKernel = It.second ;
450
+ ZE2UR_CALL (zeKernelSetArgumentValue,
451
+ (ZeKernel, ArgIndex, ArgSize, PArgValue));
452
+ }
414
453
415
454
return UR_RESULT_SUCCESS;
416
455
}
@@ -596,16 +635,21 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelRelease(
596
635
597
636
auto KernelProgram = Kernel->Program ;
598
637
if (Kernel->OwnNativeHandle ) {
599
- auto ZeResult = ZE_CALL_NOCHECK (zeKernelDestroy, (Kernel->ZeKernel ));
600
- // Gracefully handle the case that L0 was already unloaded.
601
- if (ZeResult && ZeResult != ZE_RESULT_ERROR_UNINITIALIZED)
602
- return ze2urResult (ZeResult);
638
+ for (auto &ZeKernel : Kernel->ZeKernels ) {
639
+ auto ZeResult = ZE_CALL_NOCHECK (zeKernelDestroy, (ZeKernel));
640
+ // Gracefully handle the case that L0 was already unloaded.
641
+ if (ZeResult && ZeResult != ZE_RESULT_ERROR_UNINITIALIZED)
642
+ return ze2urResult (ZeResult);
643
+ }
603
644
}
645
+ Kernel->ZeKernelMap .clear ();
604
646
if (IndirectAccessTrackingEnabled) {
605
647
UR_CALL (urContextRelease (KernelProgram->Context ));
606
648
}
607
- // do a release on the program this kernel was part of
608
- UR_CALL (urProgramRelease (KernelProgram));
649
+ // Release the program this kernel was part of, without deleting the
650
+ // program handle
651
+ KernelProgram->ur_release_program_resources (false );
652
+
609
653
delete Kernel;
610
654
611
655
return UR_RESULT_SUCCESS;
@@ -639,6 +683,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetExecInfo(
639
683
std::ignore = PropSize;
640
684
std::ignore = Properties;
641
685
686
+ auto ZeKernel = Kernel->ZeKernel ;
642
687
std::scoped_lock<ur_shared_mutex> Guard (Kernel->Mutex );
643
688
if (PropName == UR_KERNEL_EXEC_INFO_USM_INDIRECT_ACCESS &&
644
689
*(static_cast <const ur_bool_t *>(PropValue)) == true ) {
@@ -649,7 +694,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetExecInfo(
649
694
ZE_KERNEL_INDIRECT_ACCESS_FLAG_HOST |
650
695
ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE |
651
696
ZE_KERNEL_INDIRECT_ACCESS_FLAG_SHARED;
652
- ZE2UR_CALL (zeKernelSetIndirectAccess, (Kernel-> ZeKernel , IndirectFlags));
697
+ ZE2UR_CALL (zeKernelSetIndirectAccess, (ZeKernel, IndirectFlags));
653
698
} else if (PropName == UR_KERNEL_EXEC_INFO_CACHE_CONFIG) {
654
699
ze_cache_config_flag_t ZeCacheConfig{};
655
700
auto CacheConfig =
@@ -663,7 +708,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetExecInfo(
663
708
else
664
709
// Unexpected cache configuration value.
665
710
return UR_RESULT_ERROR_INVALID_VALUE;
666
- ZE2UR_CALL (zeKernelSetCacheConfig, (Kernel-> ZeKernel , ZeCacheConfig););
711
+ ZE2UR_CALL (zeKernelSetCacheConfig, (ZeKernel, ZeCacheConfig););
667
712
} else {
668
713
urPrint (" urKernelSetExecInfo: unsupported ParamName\n " );
669
714
return UR_RESULT_ERROR_INVALID_VALUE;
0 commit comments