@@ -718,31 +718,28 @@ urCommandBufferFinalizeExp(ur_exp_command_buffer_handle_t CommandBuffer) {
718718 return UR_RESULT_SUCCESS;
719719}
720720
721- UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp (
722- ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel,
723- uint32_t WorkDim, const size_t *GlobalWorkOffset,
724- const size_t *GlobalWorkSize, const size_t *LocalWorkSize,
725- uint32_t NumSyncPointsInWaitList,
726- const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
727- ur_exp_command_buffer_sync_point_t *RetSyncPoint,
728- ur_exp_command_buffer_command_handle_t *Command) {
729- UR_ASSERT (Kernel->Program , UR_RESULT_ERROR_INVALID_NULL_POINTER);
730- // Lock automatically releases when this goes out of scope.
731- std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock (
732- Kernel->Mutex , Kernel->Program ->Mutex , CommandBuffer->Mutex );
733-
734- if (GlobalWorkOffset != NULL ) {
735- if (!CommandBuffer->Context ->getPlatform ()
736- ->ZeDriverGlobalOffsetExtensionFound ) {
737- logger::debug (" No global offset extension found on this driver" );
738- return UR_RESULT_ERROR_INVALID_VALUE;
739- }
721+ static ur_result_t
722+ setKernelGlobalOffset (ur_exp_command_buffer_handle_t CommandBuffer,
723+ ur_kernel_handle_t Kernel,
724+ const size_t *GlobalWorkOffset) {
740725
741- ZE2UR_CALL (zeKernelSetGlobalOffsetExp,
742- (Kernel->ZeKernel , GlobalWorkOffset[0 ], GlobalWorkOffset[1 ],
743- GlobalWorkOffset[2 ]));
726+ if (!CommandBuffer->Context ->getPlatform ()
727+ ->ZeDriverGlobalOffsetExtensionFound ) {
728+ logger::debug (" No global offset extension found on this driver" );
729+ return UR_RESULT_ERROR_INVALID_VALUE;
744730 }
745731
732+ ZE2UR_CALL (zeKernelSetGlobalOffsetExp,
733+ (Kernel->ZeKernel , GlobalWorkOffset[0 ], GlobalWorkOffset[1 ],
734+ GlobalWorkOffset[2 ]));
735+
736+ return UR_RESULT_SUCCESS;
737+ }
738+
739+ static ur_result_t
740+ setKernelPendingArguments (ur_exp_command_buffer_handle_t CommandBuffer,
741+ ur_kernel_handle_t Kernel) {
742+
746743 // If there are any pending arguments set them now.
747744 for (auto &Arg : Kernel->PendingArguments ) {
748745 // The ArgValue may be a NULL pointer in which case a NULL value is used for
@@ -757,25 +754,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
757754 }
758755 Kernel->PendingArguments .clear ();
759756
760- ze_group_count_t ZeThreadGroupDimensions{1 , 1 , 1 };
761- uint32_t WG[3 ];
762-
763- UR_CALL (calculateKernelWorkDimensions (Kernel, CommandBuffer->Device ,
764- ZeThreadGroupDimensions, WG, WorkDim,
765- GlobalWorkSize, LocalWorkSize));
766-
767- ZE2UR_CALL (zeKernelSetGroupSize, (Kernel->ZeKernel , WG[0 ], WG[1 ], WG[2 ]));
757+ return UR_RESULT_SUCCESS;
758+ }
768759
769- CommandBuffer->KernelsList .push_back (Kernel);
770- // Increment the reference count of the Kernel and indicate that the Kernel
771- // is in use. Once the event has been signaled, the code in
772- // CleanupCompletedEvent(Event) will do a urKernelRelease to update the
773- // reference count on the kernel, using the kernel saved in CommandData.
774- UR_CALL (urKernelRetain (Kernel));
760+ static ur_result_t
761+ createCommandHandle (ur_exp_command_buffer_handle_t CommandBuffer,
762+ ur_kernel_handle_t Kernel, uint32_t WorkDim,
763+ const size_t *LocalWorkSize,
764+ ur_exp_command_buffer_command_handle_t & Command) {
775765
776766 // If command-buffer is updatable then get command id which is going to be
777767 // used if command is updated in the future. This
778- // zeCommandListGetNextCommandIdExp can be called only if command is
768+ // zeCommandListGetNextCommandIdExp can be called only if the command is
779769 // updatable.
780770 uint64_t CommandId = 0 ;
781771 if (CommandBuffer->IsUpdatable ) {
@@ -794,15 +784,60 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
794784 DEBUG_LOG (CommandId);
795785 }
796786 try {
797- if (Command)
798- *Command = new ur_exp_command_buffer_command_handle_t_ (
799- CommandBuffer, CommandId, WorkDim, LocalWorkSize != nullptr , Kernel);
787+ Command = new ur_exp_command_buffer_command_handle_t_ (
788+ CommandBuffer, CommandId, WorkDim, LocalWorkSize != nullptr , Kernel);
800789 } catch (const std::bad_alloc &) {
801790 return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
802791 } catch (...) {
803792 return UR_RESULT_ERROR_UNKNOWN;
804793 }
805794
795+ return UR_RESULT_SUCCESS;
796+ }
797+ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp (
798+ ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel,
799+ uint32_t WorkDim, const size_t *GlobalWorkOffset,
800+ const size_t *GlobalWorkSize, const size_t *LocalWorkSize,
801+ uint32_t NumSyncPointsInWaitList,
802+ const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
803+ ur_exp_command_buffer_sync_point_t *RetSyncPoint,
804+ ur_exp_command_buffer_command_handle_t *Command) {
805+ UR_ASSERT (Kernel->Program , UR_RESULT_ERROR_INVALID_NULL_POINTER);
806+
807+ // Lock automatically releases when this goes out of scope.
808+ std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock (
809+ Kernel->Mutex , Kernel->Program ->Mutex , CommandBuffer->Mutex );
810+
811+ if (GlobalWorkOffset != NULL ) {
812+ UR_CALL (setKernelGlobalOffset (CommandBuffer, Kernel, GlobalWorkOffset));
813+ }
814+
815+ // If there are any pending arguments set them now.
816+ if (!Kernel->PendingArguments .empty ()) {
817+ UR_CALL (setKernelPendingArguments (CommandBuffer, Kernel));
818+ }
819+
820+ ze_group_count_t ZeThreadGroupDimensions{1 , 1 , 1 };
821+ uint32_t WG[3 ];
822+ UR_CALL (calculateKernelWorkDimensions (Kernel, CommandBuffer->Device ,
823+ ZeThreadGroupDimensions, WG, WorkDim,
824+ GlobalWorkSize, LocalWorkSize));
825+
826+ ZE2UR_CALL (zeKernelSetGroupSize, (Kernel->ZeKernel , WG[0 ], WG[1 ], WG[2 ]));
827+
828+ CommandBuffer->KernelsList .push_back (Kernel);
829+
830+ // Increment the reference count of the Kernel and indicate that the Kernel
831+ // is in use. Once the event has been signaled, the code in
832+ // CleanupCompletedEvent(Event) will do a urKernelRelease to update the
833+ // reference count on the kernel, using the kernel saved in CommandData.
834+ UR_CALL (urKernelRetain (Kernel));
835+
836+ if (Command && CommandBuffer->IsUpdatable ) {
837+ UR_CALL (createCommandHandle (CommandBuffer, Kernel, WorkDim, LocalWorkSize,
838+ *Command));
839+ }
840+
806841 if (CommandBuffer->IsInOrderCmdList ) {
807842 ZE2UR_CALL (zeCommandListAppendLaunchKernel,
808843 (CommandBuffer->ZeComputeCommandList , Kernel->ZeKernel ,
0 commit comments