@@ -894,28 +894,31 @@ urCommandBufferFinalizeExp(ur_exp_command_buffer_handle_t CommandBuffer) {
894894/* *
895895 * Sets the kernel arguments for a kernel command that will be appended to the
896896 * command buffer.
897- * @param[in] CommandBuffer The CommandBuffer where the command will be
897+ * @param[in] Device The Device associated with the command-buffer where the
898+ * kernel command will be appended.
899+ * @param[in,out] PendingArguments Arguments stored in the ur_kernel_handle_t
900+ * object to be set on the \p ZeKernel object; cleared once they have been set.
901+ * @param[in] ZeKernel The handle to the Level-Zero kernel that will be
898902 * appended.
899- * @param[in] Kernel The handle to the kernel that will be appended.
900903 * @return UR_RESULT_SUCCESS or an error code on failure
901904 */
902- ur_result_t
903- setKernelPendingArguments ( ur_exp_command_buffer_handle_t CommandBuffer ,
904- ur_kernel_handle_t Kernel) {
905-
905+ ur_result_t setKernelPendingArguments (
906+ ur_device_handle_t Device ,
907+ std::vector<ur_kernel_handle_t_::ArgumentInfo> &PendingArguments,
908+ ze_kernel_handle_t ZeKernel) {
906909 // If there are any pending arguments set them now.
907- for (auto &Arg : Kernel-> PendingArguments ) {
910+ for (auto &Arg : PendingArguments) {
908911 // The ArgValue may be a NULL pointer in which case a NULL value is used for
909912 // the kernel argument declared as a pointer to global or constant memory.
910913 char **ZeHandlePtr = nullptr ;
911914 if (Arg.Value ) {
912- UR_CALL (Arg.Value ->getZeHandlePtr (ZeHandlePtr, Arg.AccessMode ,
913- CommandBuffer-> Device , nullptr , 0u ));
915+ UR_CALL (Arg.Value ->getZeHandlePtr (ZeHandlePtr, Arg.AccessMode , Device,
916+ nullptr , 0u ));
914917 }
915918 ZE2UR_CALL (zeKernelSetArgumentValue,
916- (Kernel-> ZeKernel , Arg.Index , Arg.Size , ZeHandlePtr));
919+ (ZeKernel, Arg.Index , Arg.Size , ZeHandlePtr));
917920 }
918- Kernel-> PendingArguments .clear ();
921+ PendingArguments.clear ();
919922
920923 return UR_RESULT_SUCCESS;
921924}
@@ -951,21 +954,29 @@ createCommandHandle(ur_exp_command_buffer_handle_t CommandBuffer,
951954 ZE_MUTABLE_COMMAND_EXP_FLAG_GLOBAL_OFFSET;
952955
953956 auto Platform = CommandBuffer->Context ->getPlatform ();
957+ auto ZeDevice = CommandBuffer->Device ->ZeDevice ;
958+
954959 if (NumKernelAlternatives > 0 ) {
955960 ZeMutableCommandDesc.flags |=
956961 ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION;
957962
958963 std::vector<ze_kernel_handle_t > TranslatedKernelHandles (
959964 NumKernelAlternatives + 1 , nullptr );
960965
966+ ze_kernel_handle_t ZeMainKernel{};
967+ UR_CALL (getZeKernel (ZeDevice, Kernel, &ZeMainKernel));
968+
961969 // Translate main kernel first
962970 ZE2UR_CALL (zelLoaderTranslateHandle,
963- (ZEL_HANDLE_KERNEL, Kernel-> ZeKernel ,
971+ (ZEL_HANDLE_KERNEL, ZeMainKernel ,
964972 (void **)&TranslatedKernelHandles[0 ]));
965973
966974 for (size_t i = 0 ; i < NumKernelAlternatives; i++) {
975+ ze_kernel_handle_t ZeAltKernel{};
976+ UR_CALL (getZeKernel (ZeDevice, KernelAlternatives[i], &ZeAltKernel));
977+
967978 ZE2UR_CALL (zelLoaderTranslateHandle,
968- (ZEL_HANDLE_KERNEL, KernelAlternatives[i]-> ZeKernel ,
979+ (ZEL_HANDLE_KERNEL, ZeAltKernel ,
969980 (void **)&TranslatedKernelHandles[i + 1 ]));
970981 }
971982
@@ -1022,23 +1033,28 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
10221033 std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock (
10231034 Kernel->Mutex , Kernel->Program ->Mutex , CommandBuffer->Mutex );
10241035
1036+ auto Device = CommandBuffer->Device ;
1037+ ze_kernel_handle_t ZeKernel{};
1038+ UR_CALL (getZeKernel (Device->ZeDevice , Kernel, &ZeKernel));
1039+
10251040 if (GlobalWorkOffset != NULL ) {
1026- UR_CALL (setKernelGlobalOffset (CommandBuffer->Context , Kernel-> ZeKernel ,
1027- WorkDim, GlobalWorkOffset));
1041+ UR_CALL (setKernelGlobalOffset (CommandBuffer->Context , ZeKernel, WorkDim ,
1042+ GlobalWorkOffset));
10281043 }
10291044
10301045 // If there are any pending arguments set them now.
10311046 if (!Kernel->PendingArguments .empty ()) {
1032- UR_CALL (setKernelPendingArguments (CommandBuffer, Kernel));
1047+ UR_CALL (
1048+ setKernelPendingArguments (Device, Kernel->PendingArguments , ZeKernel));
10331049 }
10341050
10351051 ze_group_count_t ZeThreadGroupDimensions{1 , 1 , 1 };
10361052 uint32_t WG[3 ];
1037- UR_CALL (calculateKernelWorkDimensions (Kernel-> ZeKernel , CommandBuffer-> Device ,
1053+ UR_CALL (calculateKernelWorkDimensions (ZeKernel, Device,
10381054 ZeThreadGroupDimensions, WG, WorkDim,
10391055 GlobalWorkSize, LocalWorkSize));
10401056
1041- ZE2UR_CALL (zeKernelSetGroupSize, (Kernel-> ZeKernel , WG[0 ], WG[1 ], WG[2 ]));
1057+ ZE2UR_CALL (zeKernelSetGroupSize, (ZeKernel, WG[0 ], WG[1 ], WG[2 ]));
10421058
10431059 CommandBuffer->KernelsList .push_back (Kernel);
10441060 for (size_t i = 0 ; i < NumKernelAlternatives; i++) {
@@ -1063,7 +1079,7 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
10631079 SyncPointWaitList, false , RetSyncPoint, ZeEventList, ZeLaunchEvent));
10641080
10651081 ZE2UR_CALL (zeCommandListAppendLaunchKernel,
1066- (CommandBuffer->ZeComputeCommandList , Kernel-> ZeKernel ,
1082+ (CommandBuffer->ZeComputeCommandList , ZeKernel,
10671083 &ZeThreadGroupDimensions, ZeLaunchEvent, ZeEventList.size (),
10681084 getPointerFromVector (ZeEventList)));
10691085
@@ -1836,6 +1852,7 @@ ur_result_t updateKernelCommand(
18361852 const auto CommandBuffer = Command->CommandBuffer ;
18371853 const void *NextDesc = nullptr ;
18381854 auto Platform = CommandBuffer->Context ->getPlatform ();
1855+ auto ZeDevice = CommandBuffer->Device ->ZeDevice ;
18391856
18401857 uint32_t Dim = CommandDesc->newWorkDim ;
18411858 size_t *NewGlobalWorkOffset = CommandDesc->pNewGlobalWorkOffset ;
@@ -1844,11 +1861,14 @@ ur_result_t updateKernelCommand(
18441861
18451862 // Kernel handle must be updated first for a given CommandId if required
18461863 ur_kernel_handle_t NewKernel = CommandDesc->hNewKernel ;
1864+
18471865 if (NewKernel && Command->Kernel != NewKernel) {
1866+ ze_kernel_handle_t ZeNewKernel{};
1867+ UR_CALL (getZeKernel (ZeDevice, NewKernel, &ZeNewKernel));
1868+
18481869 ze_kernel_handle_t ZeKernelTranslated = nullptr ;
1849- ZE2UR_CALL (
1850- zelLoaderTranslateHandle,
1851- (ZEL_HANDLE_KERNEL, NewKernel->ZeKernel , (void **)&ZeKernelTranslated));
1870+ ZE2UR_CALL (zelLoaderTranslateHandle,
1871+ (ZEL_HANDLE_KERNEL, ZeNewKernel, (void **)&ZeKernelTranslated));
18521872
18531873 ZE2UR_CALL (Platform->ZeMutableCmdListExt
18541874 .zexCommandListUpdateMutableCommandKernelsExp ,
@@ -1905,10 +1925,13 @@ ur_result_t updateKernelCommand(
19051925 // by the driver for the kernel.
19061926 bool UpdateWGSize = NewLocalWorkSize == nullptr ;
19071927
1928+ ze_kernel_handle_t ZeKernel{};
1929+ UR_CALL (getZeKernel (ZeDevice, Command->Kernel , &ZeKernel));
1930+
19081931 uint32_t WG[3 ];
1909- UR_CALL (calculateKernelWorkDimensions (
1910- Command-> Kernel -> ZeKernel , CommandBuffer-> Device ,
1911- ZeThreadGroupDimensions, WG, Dim, NewGlobalWorkSize, NewLocalWorkSize));
1932+ UR_CALL (calculateKernelWorkDimensions (ZeKernel, CommandBuffer-> Device ,
1933+ ZeThreadGroupDimensions, WG, Dim ,
1934+ NewGlobalWorkSize, NewLocalWorkSize));
19121935
19131936 auto MutableGroupCountDesc =
19141937 std::make_unique<ZeStruct<ze_mutable_group_count_exp_desc_t >>();
0 commit comments