@@ -949,41 +949,53 @@ createCommandHandle(ur_exp_command_buffer_handle_t CommandBuffer,
949949
950950 auto Platform = CommandBuffer->Context ->getPlatform ();
951951 auto ZeDevice = CommandBuffer->Device ->ZeDevice ;
952+ ze_command_list_handle_t ZeCommandList =
953+ CommandBuffer->ZeComputeCommandListTranslated ;
954+ if (Platform->ZeMutableCmdListExt .LoaderExtension ) {
955+ ZeCommandList = CommandBuffer->ZeComputeCommandList ;
956+ }
952957
953958 if (NumKernelAlternatives > 0 ) {
954959 ZeMutableCommandDesc.flags |=
955960 ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION;
956961
957- std::vector<ze_kernel_handle_t > TranslatedKernelHandles (
958- NumKernelAlternatives + 1 , nullptr );
962+ std::vector<ze_kernel_handle_t > KernelHandles (NumKernelAlternatives + 1 ,
963+ nullptr );
959964
960965 ze_kernel_handle_t ZeMainKernel{};
961966 UR_CALL (getZeKernel (ZeDevice, Kernel, &ZeMainKernel));
962967
963- // Translate main kernel first
964- ZE2UR_CALL (zelLoaderTranslateHandle,
965- (ZEL_HANDLE_KERNEL, ZeMainKernel,
966- (void **)&TranslatedKernelHandles[0 ]));
968+ if (Platform->ZeMutableCmdListExt .LoaderExtension ) {
969+ KernelHandles[0 ] = ZeMainKernel;
970+ } else {
971+ // If the L0 loader is not aware of the MCL extension, the main kernel
972+ // handle needs to be translated.
973+ ZE2UR_CALL (zelLoaderTranslateHandle,
974+ (ZEL_HANDLE_KERNEL, ZeMainKernel, (void **)&KernelHandles[0 ]));
975+ }
967976
968977 for (size_t i = 0 ; i < NumKernelAlternatives; i++) {
969978 ze_kernel_handle_t ZeAltKernel{};
970979 UR_CALL (getZeKernel (ZeDevice, KernelAlternatives[i], &ZeAltKernel));
971980
972- ZE2UR_CALL (zelLoaderTranslateHandle,
973- (ZEL_HANDLE_KERNEL, ZeAltKernel,
974- (void **)&TranslatedKernelHandles[i + 1 ]));
981+ if (Platform->ZeMutableCmdListExt .LoaderExtension ) {
982+ KernelHandles[i + 1 ] = ZeAltKernel;
983+ } else {
984+ // If the L0 loader is not aware of the MCL extension, the kernel
985+ // alternatives need to be translated.
986+ ZE2UR_CALL (zelLoaderTranslateHandle, (ZEL_HANDLE_KERNEL, ZeAltKernel,
987+ (void **)&KernelHandles[i + 1 ]));
988+ }
975989 }
976990
977991 ZE2UR_CALL (Platform->ZeMutableCmdListExt
978992 .zexCommandListGetNextCommandIdWithKernelsExp ,
979- (CommandBuffer->ZeComputeCommandListTranslated ,
980- &ZeMutableCommandDesc, NumKernelAlternatives + 1 ,
981- TranslatedKernelHandles.data (), &CommandId));
993+ (ZeCommandList, &ZeMutableCommandDesc, NumKernelAlternatives + 1 ,
994+ KernelHandles.data (), &CommandId));
982995
983996 } else {
984997 ZE2UR_CALL (Platform->ZeMutableCmdListExt .zexCommandListGetNextCommandIdExp ,
985- (CommandBuffer->ZeComputeCommandListTranslated ,
986- &ZeMutableCommandDesc, &CommandId));
998+ (ZeCommandList, &ZeMutableCommandDesc, &CommandId));
987999 }
9881000 DEBUG_LOG (CommandId);
9891001
@@ -1863,17 +1875,22 @@ ur_result_t updateKernelCommand(
18631875 ur_kernel_handle_t NewKernel = CommandDesc->hNewKernel ;
18641876
18651877 if (NewKernel && Command->Kernel != NewKernel) {
1878+ ze_kernel_handle_t KernelHandle{};
18661879 ze_kernel_handle_t ZeNewKernel{};
18671880 UR_CALL (getZeKernel (ZeDevice, NewKernel, &ZeNewKernel));
18681881
1869- ze_kernel_handle_t ZeKernelTranslated = nullptr ;
1870- ZE2UR_CALL (zelLoaderTranslateHandle,
1871- (ZEL_HANDLE_KERNEL, ZeNewKernel, (void **)&ZeKernelTranslated));
1882+ ze_command_list_handle_t ZeCommandList =
1883+ CommandBuffer->ZeComputeCommandList ;
1884+ KernelHandle = ZeNewKernel;
1885+ if (!Platform->ZeMutableCmdListExt .LoaderExtension ) {
1886+ ZeCommandList = CommandBuffer->ZeComputeCommandListTranslated ;
1887+ ZE2UR_CALL (zelLoaderTranslateHandle,
1888+ (ZEL_HANDLE_KERNEL, ZeNewKernel, (void **)&KernelHandle));
1889+ }
18721890
18731891 ZE2UR_CALL (Platform->ZeMutableCmdListExt
18741892 .zexCommandListUpdateMutableCommandKernelsExp ,
1875- (CommandBuffer->ZeComputeCommandListTranslated , 1 ,
1876- &Command->CommandId , &ZeKernelTranslated));
1893+ (ZeCommandList, 1 , &Command->CommandId , &KernelHandle));
18771894 // Set current kernel to be the new kernel
18781895 Command->Kernel = NewKernel;
18791896 }
@@ -2079,9 +2096,15 @@ ur_result_t updateKernelCommand(
20792096 MutableCommandDesc.pNext = NextDesc;
20802097 MutableCommandDesc.flags = 0 ;
20812098
2099+ ze_command_list_handle_t ZeCommandList =
2100+ CommandBuffer->ZeComputeCommandListTranslated ;
2101+ if (Platform->ZeMutableCmdListExt .LoaderExtension ) {
2102+ ZeCommandList = CommandBuffer->ZeComputeCommandList ;
2103+ }
2104+
20822105 ZE2UR_CALL (
20832106 Platform->ZeMutableCmdListExt .zexCommandListUpdateMutableCommandsExp ,
2084- (CommandBuffer-> ZeComputeCommandListTranslated , &MutableCommandDesc));
2107+ (ZeCommandList , &MutableCommandDesc));
20852108
20862109 return UR_RESULT_SUCCESS;
20872110}
0 commit comments