@@ -476,21 +476,14 @@ void ur_exp_command_buffer_handle_t_::cleanupCommandBufferResources() {
476476
// Constructor of the base command handle, shown as a diff hunk.
// Lines marked '-' are the previous signature/body: they additionally took
// WorkDim, UserDefinedLocalSize and an optional Kernel, and retained the
// kernel when it was non-null.
// Lines marked '+' are the replacement: the handle now stores only the owning
// command buffer and the command id.
477477ur_exp_command_buffer_command_handle_t_::
478478 ur_exp_command_buffer_command_handle_t_ (
479- ur_exp_command_buffer_handle_t CommandBuffer, uint64_t CommandId,
480- uint32_t WorkDim, bool UserDefinedLocalSize,
481- ur_kernel_handle_t Kernel = nullptr )
482- : CommandBuffer(CommandBuffer), CommandId(CommandId), WorkDim(WorkDim),
483- UserDefinedLocalSize(UserDefinedLocalSize), Kernel(Kernel) {
479+ ur_exp_command_buffer_handle_t CommandBuffer, uint64_t CommandId)
480+ : CommandBuffer(CommandBuffer), CommandId(CommandId) {
// Take a reference on the command buffer; the destructor calls the matching
// urCommandBufferReleaseExp.
484481 ur::level_zero::urCommandBufferRetainExp (CommandBuffer);
// Removed: the kernel is no longer a member of the base handle, so it is not
// retained here any more.
485- if (Kernel)
486- ur::level_zero::urKernelRetain (Kernel);
487482}
488483
// Destructor of the base command handle, shown as a diff hunk.
// It releases the command buffer reference; the lines marked '-' (the
// conditional kernel release) are removed by this change.
489484ur_exp_command_buffer_command_handle_t_::
490485 ~ur_exp_command_buffer_command_handle_t_ () {
// Drop the reference on the command buffer that the constructor retained.
491486 ur::level_zero::urCommandBufferReleaseExp (CommandBuffer);
// Removed: the base handle no longer holds a kernel to release.
492- if (Kernel)
493- ur::level_zero::urKernelRelease (Kernel);
494487}
495488
496489void ur_exp_command_buffer_handle_t_::registerSyncPoint (
@@ -527,6 +520,31 @@ ur_result_t ur_exp_command_buffer_handle_t_::getFenceForQueue(
527520 return UR_RESULT_SUCCESS;
528521}
529522
523+ kernel_command_handle::kernel_command_handle (
524+ ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel,
525+ uint64_t CommandId, uint32_t WorkDim, bool UserDefinedLocalSize,
526+ uint32_t NumKernelAlternatives, ur_kernel_handle_t *KernelAlternatives)
527+ : ur_exp_command_buffer_command_handle_t_(CommandBuffer, CommandId),
528+ WorkDim(WorkDim), UserDefinedLocalSize(UserDefinedLocalSize),
529+ Kernel(Kernel) {
530+ // Add the default kernel to the list of valid kernels
531+ ur::level_zero::urKernelRetain (Kernel);
532+ ValidKernelHandles.insert (Kernel);
533+ // Add alternative kernels if provided
534+ if (KernelAlternatives) {
535+ for (size_t i = 0 ; i < NumKernelAlternatives; i++) {
536+ ur::level_zero::urKernelRetain (KernelAlternatives[i]);
537+ ValidKernelHandles.insert (KernelAlternatives[i]);
538+ }
539+ }
540+ }
541+
542+ kernel_command_handle::~kernel_command_handle () {
543+ for (const ur_kernel_handle_t &KernelHandle : ValidKernelHandles) {
544+ ur::level_zero::urKernelRelease (KernelHandle);
545+ }
546+ }
547+
530548namespace ur ::level_zero {
531549
532550/* *
@@ -906,7 +924,8 @@ setKernelPendingArguments(ur_exp_command_buffer_handle_t CommandBuffer,
906924ur_result_t
907925createCommandHandle (ur_exp_command_buffer_handle_t CommandBuffer,
908926 ur_kernel_handle_t Kernel, uint32_t WorkDim,
909- const size_t *LocalWorkSize,
927+ const size_t *LocalWorkSize, uint32_t NumKernelAlternatives,
928+ ur_kernel_handle_t *KernelAlternatives,
910929 ur_exp_command_buffer_command_handle_t &Command) {
911930
912931 assert (CommandBuffer->IsUpdatable );
@@ -923,14 +942,41 @@ createCommandHandle(ur_exp_command_buffer_handle_t CommandBuffer,
923942 ZE_MUTABLE_COMMAND_EXP_FLAG_GLOBAL_OFFSET;
924943
925944 auto Platform = CommandBuffer->Context ->getPlatform ();
926- ZE2UR_CALL (Platform->ZeMutableCmdListExt .zexCommandListGetNextCommandIdExp ,
927- (CommandBuffer->ZeComputeCommandListTranslated ,
928- &ZeMutableCommandDesc, &CommandId));
945+ if (NumKernelAlternatives > 0 ) {
946+ ZeMutableCommandDesc.flags |=
947+ ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION;
948+
949+ std::vector<ze_kernel_handle_t > TranslatedKernelHandles (
950+ NumKernelAlternatives + 1 , nullptr );
951+
952+ // Translate main kernel first
953+ ZE2UR_CALL (zelLoaderTranslateHandle,
954+ (ZEL_HANDLE_KERNEL, Kernel->ZeKernel ,
955+ (void **)&TranslatedKernelHandles[0 ]));
956+
957+ for (size_t i = 0 ; i < NumKernelAlternatives; i++) {
958+ ZE2UR_CALL (zelLoaderTranslateHandle,
959+ (ZEL_HANDLE_KERNEL, KernelAlternatives[i]->ZeKernel ,
960+ (void **)&TranslatedKernelHandles[i + 1 ]));
961+ }
962+
963+ ZE2UR_CALL (Platform->ZeMutableCmdListExt
964+ .zexCommandListGetNextCommandIdWithKernelsExp ,
965+ (CommandBuffer->ZeComputeCommandListTranslated ,
966+ &ZeMutableCommandDesc, NumKernelAlternatives + 1 ,
967+ TranslatedKernelHandles.data (), &CommandId));
968+
969+ } else {
970+ ZE2UR_CALL (Platform->ZeMutableCmdListExt .zexCommandListGetNextCommandIdExp ,
971+ (CommandBuffer->ZeComputeCommandListTranslated ,
972+ &ZeMutableCommandDesc, &CommandId));
973+ }
929974 DEBUG_LOG (CommandId);
930975
931976 try {
932- Command = new ur_exp_command_buffer_command_handle_t_ (
933- CommandBuffer, CommandId, WorkDim, LocalWorkSize != nullptr , Kernel);
977+ Command = new kernel_command_handle (
978+ CommandBuffer, Kernel, CommandId, WorkDim, LocalWorkSize != nullptr ,
979+ NumKernelAlternatives, KernelAlternatives);
934980 } catch (const std::bad_alloc &) {
935981 return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
936982 } catch (...) {
@@ -944,8 +990,7 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
944990 ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel,
945991 uint32_t WorkDim, const size_t *GlobalWorkOffset,
946992 const size_t *GlobalWorkSize, const size_t *LocalWorkSize,
947- uint32_t /* numKernelAlternatives*/ ,
948- ur_kernel_handle_t * /* phKernelAlternatives*/ ,
993+ uint32_t NumKernelAlternatives, ur_kernel_handle_t *KernelAlternatives,
949994 uint32_t NumSyncPointsInWaitList,
950995 const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
951996 uint32_t NumEventsInWaitList, const ur_event_handle_t *EventWaitList,
@@ -960,6 +1005,10 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
9601005 UR_ASSERT (!(Command && !CommandBuffer->IsUpdatable ),
9611006 UR_RESULT_ERROR_INVALID_OPERATION);
9621007
1008+ for (uint32_t i = 0 ; i < NumKernelAlternatives; ++i) {
1009+ UR_ASSERT (KernelAlternatives[i] != Kernel, UR_RESULT_ERROR_INVALID_VALUE);
1010+ }
1011+
9631012 // Lock automatically releases when this goes out of scope.
9641013 std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock (
9651014 Kernel->Mutex , Kernel->Program ->Mutex , CommandBuffer->Mutex );
@@ -983,18 +1032,21 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
9831032 ZE2UR_CALL (zeKernelSetGroupSize, (Kernel->ZeKernel , WG[0 ], WG[1 ], WG[2 ]));
9841033
9851034 CommandBuffer->KernelsList .push_back (Kernel);
1035+ for (size_t i = 0 ; i < NumKernelAlternatives; i++) {
1036+ CommandBuffer->KernelsList .push_back (KernelAlternatives[i]);
1037+ }
9861038
987- // Increment the reference count of the Kernel and indicate that the Kernel
988- // is in use. Once the event has been signaled, the code in
989- // CleanupCompletedEvent(Event) will do a urKernelRelease to update the
990- // reference count on the kernel, using the kernel saved in CommandData.
991- UR_CALL ( ur::level_zero::urKernelRetain (Kernel));
1039+ ur::level_zero::urKernelRetain ( Kernel);
1040+ // Retain alternative kernels if provided
1041+ for ( size_t i = 0 ; i < NumKernelAlternatives; i++) {
1042+ ur::level_zero::urKernelRetain (KernelAlternatives[i]);
1043+ }
9921044
9931045 if (Command) {
9941046 UR_CALL (createCommandHandle (CommandBuffer, Kernel, WorkDim, LocalWorkSize,
1047+ NumKernelAlternatives, KernelAlternatives,
9951048 *Command));
9961049 }
997-
9981050 std::vector<ze_event_handle_t > ZeEventList;
9991051 ze_event_handle_t ZeLaunchEvent = nullptr ;
10001052 UR_CALL (createSyncPointAndGetZeEvents (
@@ -1690,7 +1742,7 @@ ur_result_t urCommandBufferReleaseCommandExp(
16901742 * @return UR_RESULT_SUCCESS or an error code on failure
16911743 */
16921744ur_result_t validateCommandDesc (
1693- ur_exp_command_buffer_command_handle_t Command,
1745+ kernel_command_handle * Command,
16941746 const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDesc) {
16951747
16961748 auto CommandBuffer = Command->CommandBuffer ;
@@ -1699,9 +1751,14 @@ ur_result_t validateCommandDesc(
16991751 ->mutableCommandFlags ;
17001752 logger::debug (" Mutable features supported by device {}" , SupportedFeatures);
17011753
1702- // Kernel handle updates are not yet supported.
1703- if (CommandDesc->hNewKernel && CommandDesc->hNewKernel != Command->Kernel ) {
1704- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
1754+ UR_ASSERT (
1755+ !CommandDesc->hNewKernel ||
1756+ (SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION),
1757+ UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
1758+ // Check if the provided new kernel is in the list of valid alternatives.
1759+ if (CommandDesc->hNewKernel &&
1760+ !Command->ValidKernelHandles .count (CommandDesc->hNewKernel )) {
1761+ return UR_RESULT_ERROR_INVALID_VALUE;
17051762 }
17061763
17071764 if (CommandDesc->newWorkDim != Command->WorkDim &&
@@ -1754,7 +1811,7 @@ ur_result_t validateCommandDesc(
17541811 * @return UR_RESULT_SUCCESS or an error code on failure
17551812 */
17561813ur_result_t updateKernelCommand (
1757- ur_exp_command_buffer_command_handle_t Command,
1814+ kernel_command_handle * Command,
17581815 const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDesc) {
17591816
17601817 // We need the created descriptors to live till the point when
@@ -1769,12 +1826,29 @@ ur_result_t updateKernelCommand(
17691826
17701827 const auto CommandBuffer = Command->CommandBuffer ;
17711828 const void *NextDesc = nullptr ;
1829+ auto Platform = CommandBuffer->Context ->getPlatform ();
17721830
17731831 uint32_t Dim = CommandDesc->newWorkDim ;
17741832 size_t *NewGlobalWorkOffset = CommandDesc->pNewGlobalWorkOffset ;
17751833 size_t *NewLocalWorkSize = CommandDesc->pNewLocalWorkSize ;
17761834 size_t *NewGlobalWorkSize = CommandDesc->pNewGlobalWorkSize ;
17771835
1836+ // Kernel handle must be updated first for a given CommandId if required
1837+ ur_kernel_handle_t NewKernel = CommandDesc->hNewKernel ;
1838+ if (NewKernel && Command->Kernel != NewKernel) {
1839+ ze_kernel_handle_t ZeKernelTranslated = nullptr ;
1840+ ZE2UR_CALL (
1841+ zelLoaderTranslateHandle,
1842+ (ZEL_HANDLE_KERNEL, NewKernel->ZeKernel , (void **)&ZeKernelTranslated));
1843+
1844+ ZE2UR_CALL (Platform->ZeMutableCmdListExt
1845+ .zexCommandListUpdateMutableCommandKernelsExp ,
1846+ (CommandBuffer->ZeComputeCommandListTranslated , 1 ,
1847+ &Command->CommandId , &ZeKernelTranslated));
1848+ // Set current kernel to be the new kernel
1849+ Command->Kernel = NewKernel;
1850+ }
1851+
17781852 // Check if a new global offset is provided.
17791853 if (NewGlobalWorkOffset && Dim > 0 ) {
17801854 auto MutableGroupOffestDesc =
@@ -1973,7 +2047,6 @@ ur_result_t updateKernelCommand(
19732047 MutableCommandDesc.pNext = NextDesc;
19742048 MutableCommandDesc.flags = 0 ;
19752049
1976- auto Platform = CommandBuffer->Context ->getPlatform ();
19772050 ZE2UR_CALL (
19782051 Platform->ZeMutableCmdListExt .zexCommandListUpdateMutableCommandsExp ,
19792052 (CommandBuffer->ZeComputeCommandListTranslated , &MutableCommandDesc));
@@ -2009,18 +2082,22 @@ ur_result_t urCommandBufferUpdateKernelLaunchExp(
20092082 const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDesc) {
20102083 UR_ASSERT (Command->CommandBuffer ->IsUpdatable ,
20112084 UR_RESULT_ERROR_INVALID_OPERATION);
2012- UR_ASSERT (Command->Kernel , UR_RESULT_ERROR_INVALID_NULL_HANDLE);
2085+
2086+ auto KernelCommandHandle = static_cast <kernel_command_handle *>(Command);
2087+
2088+ UR_ASSERT (KernelCommandHandle->Kernel , UR_RESULT_ERROR_INVALID_NULL_HANDLE);
20132089
20142090 // Lock command, kernel and command buffer for update.
20152091 std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Guard (
2016- Command->Mutex , Command->CommandBuffer ->Mutex , Command->Kernel ->Mutex );
2092+ Command->Mutex , Command->CommandBuffer ->Mutex ,
2093+ KernelCommandHandle->Kernel ->Mutex );
20172094
20182095 UR_ASSERT (Command->CommandBuffer ->IsFinalized ,
20192096 UR_RESULT_ERROR_INVALID_OPERATION);
20202097
2021- UR_CALL (validateCommandDesc (Command , CommandDesc));
2098+ UR_CALL (validateCommandDesc (KernelCommandHandle , CommandDesc));
20222099 UR_CALL (waitForOngoingExecution (Command->CommandBuffer ));
2023- UR_CALL (updateKernelCommand (Command , CommandDesc));
2100+ UR_CALL (updateKernelCommand (KernelCommandHandle , CommandDesc));
20242101
20252102 ZE2UR_CALL (zeCommandListClose,
20262103 (Command->CommandBuffer ->ZeComputeCommandList ));
0 commit comments