@@ -895,28 +895,31 @@ urCommandBufferFinalizeExp(ur_exp_command_buffer_handle_t CommandBuffer) {
895895/* *
896896 * Sets the kernel arguments for a kernel command that will be appended to the
897897 * command buffer.
898- * @param[in] CommandBuffer The CommandBuffer where the command will be
898+ * @param[in] Device The Device associated with the command-buffer where the
899+ * kernel command will be appended.
900+ * @param[in,out] Arguments stored in the ur_kernel_handle_t object to be set
901+ * on the /p ZeKernel object.
902+ * @param[in] ZeKernel The handle to the Level-Zero kernel that will be
899903 * appended.
900- * @param[in] Kernel The handle to the kernel that will be appended.
901904 * @return UR_RESULT_SUCCESS or an error code on failure
902905 */
903- ur_result_t
904- setKernelPendingArguments ( ur_exp_command_buffer_handle_t CommandBuffer ,
905- ur_kernel_handle_t Kernel) {
906-
906+ ur_result_t setKernelPendingArguments (
907+ ur_device_handle_t Device ,
908+ std::vector<ur_kernel_handle_t_::ArgumentInfo> &PendingArguments,
909+ ze_kernel_handle_t ZeKernel) {
907910 // If there are any pending arguments set them now.
908- for (auto &Arg : Kernel-> PendingArguments ) {
911+ for (auto &Arg : PendingArguments) {
909912 // The ArgValue may be a NULL pointer in which case a NULL value is used for
910913 // the kernel argument declared as a pointer to global or constant memory.
911914 char **ZeHandlePtr = nullptr ;
912915 if (Arg.Value ) {
913- UR_CALL (Arg.Value ->getZeHandlePtr (ZeHandlePtr, Arg.AccessMode ,
914- CommandBuffer-> Device , nullptr , 0u ));
916+ UR_CALL (Arg.Value ->getZeHandlePtr (ZeHandlePtr, Arg.AccessMode , Device,
917+ nullptr , 0u ));
915918 }
916919 ZE2UR_CALL (zeKernelSetArgumentValue,
917- (Kernel-> ZeKernel , Arg.Index , Arg.Size , ZeHandlePtr));
920+ (ZeKernel, Arg.Index , Arg.Size , ZeHandlePtr));
918921 }
919- Kernel-> PendingArguments .clear ();
922+ PendingArguments.clear ();
920923
921924 return UR_RESULT_SUCCESS;
922925}
@@ -952,21 +955,29 @@ createCommandHandle(ur_exp_command_buffer_handle_t CommandBuffer,
952955 ZE_MUTABLE_COMMAND_EXP_FLAG_GLOBAL_OFFSET;
953956
954957 auto Platform = CommandBuffer->Context ->getPlatform ();
958+ auto ZeDevice = CommandBuffer->Device ->ZeDevice ;
959+
955960 if (NumKernelAlternatives > 0 ) {
956961 ZeMutableCommandDesc.flags |=
957962 ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION;
958963
959964 std::vector<ze_kernel_handle_t > TranslatedKernelHandles (
960965 NumKernelAlternatives + 1 , nullptr );
961966
967+ ze_kernel_handle_t ZeMainKernel{};
968+ UR_CALL (getZeKernel (ZeDevice, Kernel, &ZeMainKernel));
969+
962970 // Translate main kernel first
963971 ZE2UR_CALL (zelLoaderTranslateHandle,
964- (ZEL_HANDLE_KERNEL, Kernel-> ZeKernel ,
972+ (ZEL_HANDLE_KERNEL, ZeMainKernel ,
965973 (void **)&TranslatedKernelHandles[0 ]));
966974
967975 for (size_t i = 0 ; i < NumKernelAlternatives; i++) {
976+ ze_kernel_handle_t ZeAltKernel{};
977+ UR_CALL (getZeKernel (ZeDevice, KernelAlternatives[i], &ZeAltKernel));
978+
968979 ZE2UR_CALL (zelLoaderTranslateHandle,
969- (ZEL_HANDLE_KERNEL, KernelAlternatives[i]-> ZeKernel ,
980+ (ZEL_HANDLE_KERNEL, ZeAltKernel ,
970981 (void **)&TranslatedKernelHandles[i + 1 ]));
971982 }
972983
@@ -1023,23 +1034,28 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
10231034 std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock (
10241035 Kernel->Mutex , Kernel->Program ->Mutex , CommandBuffer->Mutex );
10251036
1037+ auto Device = CommandBuffer->Device ;
1038+ ze_kernel_handle_t ZeKernel{};
1039+ UR_CALL (getZeKernel (Device->ZeDevice , Kernel, &ZeKernel));
1040+
10261041 if (GlobalWorkOffset != NULL ) {
1027- UR_CALL (setKernelGlobalOffset (CommandBuffer->Context , Kernel-> ZeKernel ,
1028- WorkDim, GlobalWorkOffset));
1042+ UR_CALL (setKernelGlobalOffset (CommandBuffer->Context , ZeKernel, WorkDim ,
1043+ GlobalWorkOffset));
10291044 }
10301045
10311046 // If there are any pending arguments set them now.
10321047 if (!Kernel->PendingArguments .empty ()) {
1033- UR_CALL (setKernelPendingArguments (CommandBuffer, Kernel));
1048+ UR_CALL (
1049+ setKernelPendingArguments (Device, Kernel->PendingArguments , ZeKernel));
10341050 }
10351051
10361052 ze_group_count_t ZeThreadGroupDimensions{1 , 1 , 1 };
10371053 uint32_t WG[3 ];
1038- UR_CALL (calculateKernelWorkDimensions (Kernel-> ZeKernel , CommandBuffer-> Device ,
1054+ UR_CALL (calculateKernelWorkDimensions (ZeKernel, Device,
10391055 ZeThreadGroupDimensions, WG, WorkDim,
10401056 GlobalWorkSize, LocalWorkSize));
10411057
1042- ZE2UR_CALL (zeKernelSetGroupSize, (Kernel-> ZeKernel , WG[0 ], WG[1 ], WG[2 ]));
1058+ ZE2UR_CALL (zeKernelSetGroupSize, (ZeKernel, WG[0 ], WG[1 ], WG[2 ]));
10431059
10441060 CommandBuffer->KernelsList .push_back (Kernel);
10451061 for (size_t i = 0 ; i < NumKernelAlternatives; i++) {
@@ -1064,7 +1080,7 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
10641080 SyncPointWaitList, false , RetSyncPoint, ZeEventList, ZeLaunchEvent));
10651081
10661082 ZE2UR_CALL (zeCommandListAppendLaunchKernel,
1067- (CommandBuffer->ZeComputeCommandList , Kernel-> ZeKernel ,
1083+ (CommandBuffer->ZeComputeCommandList , ZeKernel,
10681084 &ZeThreadGroupDimensions, ZeLaunchEvent, ZeEventList.size (),
10691085 getPointerFromVector (ZeEventList)));
10701086
@@ -1837,6 +1853,7 @@ ur_result_t updateKernelCommand(
18371853 const auto CommandBuffer = Command->CommandBuffer ;
18381854 const void *NextDesc = nullptr ;
18391855 auto Platform = CommandBuffer->Context ->getPlatform ();
1856+ auto ZeDevice = CommandBuffer->Device ->ZeDevice ;
18401857
18411858 uint32_t Dim = CommandDesc->newWorkDim ;
18421859 size_t *NewGlobalWorkOffset = CommandDesc->pNewGlobalWorkOffset ;
@@ -1845,11 +1862,14 @@ ur_result_t updateKernelCommand(
18451862
18461863 // Kernel handle must be updated first for a given CommandId if required
18471864 ur_kernel_handle_t NewKernel = CommandDesc->hNewKernel ;
1865+
18481866 if (NewKernel && Command->Kernel != NewKernel) {
1867+ ze_kernel_handle_t ZeNewKernel{};
1868+ UR_CALL (getZeKernel (ZeDevice, NewKernel, &ZeNewKernel));
1869+
18491870 ze_kernel_handle_t ZeKernelTranslated = nullptr ;
1850- ZE2UR_CALL (
1851- zelLoaderTranslateHandle,
1852- (ZEL_HANDLE_KERNEL, NewKernel->ZeKernel , (void **)&ZeKernelTranslated));
1871+ ZE2UR_CALL (zelLoaderTranslateHandle,
1872+ (ZEL_HANDLE_KERNEL, ZeNewKernel, (void **)&ZeKernelTranslated));
18531873
18541874 ZE2UR_CALL (Platform->ZeMutableCmdListExt
18551875 .zexCommandListUpdateMutableCommandKernelsExp ,
@@ -1906,10 +1926,13 @@ ur_result_t updateKernelCommand(
19061926 // by the driver for the kernel.
19071927 bool UpdateWGSize = NewLocalWorkSize == nullptr ;
19081928
1929+ ze_kernel_handle_t ZeKernel{};
1930+ UR_CALL (getZeKernel (ZeDevice, Command->Kernel , &ZeKernel));
1931+
19091932 uint32_t WG[3 ];
1910- UR_CALL (calculateKernelWorkDimensions (
1911- Command-> Kernel -> ZeKernel , CommandBuffer-> Device ,
1912- ZeThreadGroupDimensions, WG, Dim, NewGlobalWorkSize, NewLocalWorkSize));
1933+ UR_CALL (calculateKernelWorkDimensions (ZeKernel, CommandBuffer-> Device ,
1934+ ZeThreadGroupDimensions, WG, Dim ,
1935+ NewGlobalWorkSize, NewLocalWorkSize));
19131936
19141937 auto MutableGroupCountDesc =
19151938 std::make_unique<ZeStruct<ze_mutable_group_count_exp_desc_t >>();
0 commit comments