@@ -406,7 +406,7 @@ static ur_result_t enqueueCommandBufferMemCopyHelper(
406406 } else {
407407 // FIXME Why doesn't the event need to be host visible
408408 std::vector<ze_event_handle_t > ZeEventList;
409- ur_event_handle_t LaunchEvent;
409+ ur_event_handle_t LaunchEvent = nullptr ;
410410 UR_CALL (createSyncPoint (CommandType, CommandBuffer, NumSyncPointsInWaitList,
411411 SyncPointWaitList, RetSyncPoint, false , ZeEventList,
412412 LaunchEvent));
@@ -761,7 +761,7 @@ static ur_result_t
761761createCommandHandle (ur_exp_command_buffer_handle_t CommandBuffer,
762762 ur_kernel_handle_t Kernel, uint32_t WorkDim,
763763 const size_t *LocalWorkSize,
764- ur_exp_command_buffer_command_handle_t & Command) {
764+ ur_exp_command_buffer_command_handle_t & Command) {
765765
766766 // If command-buffer is updatable then get command id which is going to be
767767 // used if command is updated in the future. This
@@ -1371,20 +1371,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferReleaseCommandExp(
13711371 return UR_RESULT_SUCCESS;
13721372}
13731373
1374- UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp (
1374+ static ur_result_t validateCommandDesc (
13751375 ur_exp_command_buffer_command_handle_t Command,
13761376 const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDesc) {
1377- UR_ASSERT (Command->Kernel , UR_RESULT_ERROR_INVALID_NULL_HANDLE);
1378- UR_ASSERT (CommandDesc->newWorkDim <= 3 ,
1379- UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
13801377
1381- // Lock command, kernel and command buffer for update.
1382- std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Guard (
1383- Command->Mutex , Command->CommandBuffer ->Mutex , Command->Kernel ->Mutex );
1384- UR_ASSERT (Command->CommandBuffer ->IsUpdatable ,
1385- UR_RESULT_ERROR_INVALID_OPERATION);
1386- UR_ASSERT (Command->CommandBuffer ->IsFinalized ,
1387- UR_RESULT_ERROR_INVALID_OPERATION);
1378+ auto CommandBuffer = Command->CommandBuffer ;
1379+ auto SupportedFeatures =
1380+ Command->CommandBuffer ->Device ->ZeDeviceMutableCmdListsProperties
1381+ ->mutableCommandFlags ;
1382+ logger::debug (" Mutable features supported by device {}" , SupportedFeatures);
13881383
13891384 uint32_t Dim = CommandDesc->newWorkDim ;
13901385 if (Dim != 0 ) {
@@ -1409,25 +1404,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
14091404 }
14101405 }
14111406
1412- auto CommandBuffer = Command->CommandBuffer ;
1413- const void *NextDesc = nullptr ;
1414- auto SupportedFeatures =
1415- Command->CommandBuffer ->Device ->ZeDeviceMutableCmdListsProperties
1416- ->mutableCommandFlags ;
1417- logger::debug (" Mutable features supported by device {}" , SupportedFeatures);
1418-
1419- // We need the created descriptors to live till the point when
1420- // zexCommandListUpdateMutableCommandsExp is called at the end of the
1421- // function.
1422- std::vector<std::unique_ptr<ZeStruct<ze_mutable_kernel_argument_exp_desc_t >>>
1423- ArgDescs;
1424- std::vector<std::unique_ptr<ZeStruct<ze_mutable_global_offset_exp_desc_t >>>
1425- OffsetDescs;
1426- std::vector<std::unique_ptr<ZeStruct<ze_mutable_group_size_exp_desc_t >>>
1427- GroupSizeDescs;
1428- std::vector<std::unique_ptr<ZeStruct<ze_mutable_group_count_exp_desc_t >>>
1429- GroupCountDescs;
1430-
14311407 // Check if new global offset is provided.
14321408 size_t *NewGlobalWorkOffset = CommandDesc->pNewGlobalWorkOffset ;
14331409 UR_ASSERT (!NewGlobalWorkOffset ||
@@ -1439,6 +1415,56 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
14391415 logger::error (" No global offset extension found on this driver" );
14401416 return UR_RESULT_ERROR_INVALID_VALUE;
14411417 }
1418+ }
1419+
1420+ // Check if new group size is provided.
1421+ size_t *NewLocalWorkSize = CommandDesc->pNewLocalWorkSize ;
1422+ UR_ASSERT (!NewLocalWorkSize ||
1423+ (SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_SIZE),
1424+ UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
1425+
1426+ // Check if new global size is provided and we need to update group count.
1427+ size_t *NewGlobalWorkSize = CommandDesc->pNewGlobalWorkSize ;
1428+ UR_ASSERT (!NewGlobalWorkSize ||
1429+ (SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_COUNT),
1430+ UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
1431+ UR_ASSERT (!(NewGlobalWorkSize && !NewLocalWorkSize) ||
1432+ (SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_SIZE),
1433+ UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
1434+
1435+ UR_ASSERT (
1436+ (!CommandDesc->numNewMemObjArgs && !CommandDesc->numNewPointerArgs &&
1437+ !CommandDesc->numNewValueArgs ) ||
1438+ (SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_ARGUMENTS),
1439+ UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
1440+
1441+ return UR_RESULT_SUCCESS;
1442+ }
1443+
1444+ static ur_result_t updateKernelCommand (
1445+ ur_exp_command_buffer_command_handle_t Command,
1446+ const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDesc) {
1447+
1448+ // We need the created descriptors to live till the point when
1449+ // zeCommandListUpdateMutableCommandsExp is called at the end of the
1450+ // function.
1451+ std::vector<std::variant<
1452+ std::unique_ptr<ZeStruct<ze_mutable_kernel_argument_exp_desc_t >>,
1453+ std::unique_ptr<ZeStruct<ze_mutable_global_offset_exp_desc_t >>,
1454+ std::unique_ptr<ZeStruct<ze_mutable_group_size_exp_desc_t >>,
1455+ std::unique_ptr<ZeStruct<ze_mutable_group_count_exp_desc_t >>>>
1456+ Descs;
1457+
1458+ const auto CommandBuffer = Command->CommandBuffer ;
1459+ const void *NextDesc = nullptr ;
1460+
1461+ uint32_t Dim = CommandDesc->newWorkDim ;
1462+ size_t *NewGlobalWorkOffset = CommandDesc->pNewGlobalWorkOffset ;
1463+ size_t *NewLocalWorkSize = CommandDesc->pNewLocalWorkSize ;
1464+ size_t *NewGlobalWorkSize = CommandDesc->pNewGlobalWorkSize ;
1465+
1466+ // Check if a new global offset is provided.
1467+ if (NewGlobalWorkOffset && Dim > 0 ) {
14421468 auto MutableGroupOffestDesc =
14431469 std::make_unique<ZeStruct<ze_mutable_global_offset_exp_desc_t >>();
14441470 MutableGroupOffestDesc->commandId = Command->CommandId ;
@@ -1451,15 +1477,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
14511477 DEBUG_LOG (MutableGroupOffestDesc->offsetY );
14521478 MutableGroupOffestDesc->offsetZ = Dim == 3 ? NewGlobalWorkOffset[2 ] : 0 ;
14531479 DEBUG_LOG (MutableGroupOffestDesc->offsetZ );
1480+
14541481 NextDesc = MutableGroupOffestDesc.get ();
1455- OffsetDescs .push_back (std::move (MutableGroupOffestDesc));
1482+ Descs .push_back (std::move (MutableGroupOffestDesc));
14561483 }
14571484
1458- // Check if new group size is provided.
1459- size_t *NewLocalWorkSize = CommandDesc->pNewLocalWorkSize ;
1460- UR_ASSERT (!NewLocalWorkSize ||
1461- (SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_SIZE),
1462- UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
1485+ // Check if a new group size is provided.
14631486 if (NewLocalWorkSize && Dim > 0 ) {
14641487 auto MutableGroupSizeDesc =
14651488 std::make_unique<ZeStruct<ze_mutable_group_size_exp_desc_t >>();
@@ -1473,29 +1496,25 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
14731496 DEBUG_LOG (MutableGroupSizeDesc->groupSizeY );
14741497 MutableGroupSizeDesc->groupSizeZ = Dim == 3 ? NewLocalWorkSize[2 ] : 1 ;
14751498 DEBUG_LOG (MutableGroupSizeDesc->groupSizeZ );
1499+
14761500 NextDesc = MutableGroupSizeDesc.get ();
1477- GroupSizeDescs .push_back (std::move (MutableGroupSizeDesc));
1501+ Descs .push_back (std::move (MutableGroupSizeDesc));
14781502 }
14791503
1480- // Check if new global size is provided and we need to update group count.
1481- size_t *NewGlobalWorkSize = CommandDesc->pNewGlobalWorkSize ;
1482- UR_ASSERT (!NewGlobalWorkSize ||
1483- (SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_COUNT),
1484- UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
1485- UR_ASSERT (!(NewGlobalWorkSize && !NewLocalWorkSize) ||
1486- (SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_SIZE),
1487- UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
1488-
1504+ // Check if a new global size is provided and if we need to update the group
1505+ // count.
14891506 ze_group_count_t ZeThreadGroupDimensions{1 , 1 , 1 };
14901507 if (NewGlobalWorkSize && Dim > 0 ) {
1491- uint32_t WG[3 ];
1492- // If new global work size is provided but new local work size is not
1493- // provided then we still need to update local work size based on size
1494- // suggested by the driver for the kernel.
1508+ // If a new global work size is provided but a new local work size is not
1509+ // then we still need to update local work size based on the size suggested
1510+ // by the driver for the kernel.
14951511 bool UpdateWGSize = NewLocalWorkSize == nullptr ;
1512+
1513+ uint32_t WG[3 ];
14961514 UR_CALL (calculateKernelWorkDimensions (
14971515 Command->Kernel , CommandBuffer->Device , ZeThreadGroupDimensions, WG,
14981516 Dim, NewGlobalWorkSize, NewLocalWorkSize));
1517+
14991518 auto MutableGroupCountDesc =
15001519 std::make_unique<ZeStruct<ze_mutable_group_count_exp_desc_t >>();
15011520 MutableGroupCountDesc->commandId = Command->CommandId ;
@@ -1506,8 +1525,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
15061525 DEBUG_LOG (MutableGroupCountDesc->pGroupCount ->groupCountX );
15071526 DEBUG_LOG (MutableGroupCountDesc->pGroupCount ->groupCountY );
15081527 DEBUG_LOG (MutableGroupCountDesc->pGroupCount ->groupCountZ );
1528+
15091529 NextDesc = MutableGroupCountDesc.get ();
1510- GroupCountDescs .push_back (std::move (MutableGroupCountDesc));
1530+ Descs .push_back (std::move (MutableGroupCountDesc));
15111531
15121532 if (UpdateWGSize) {
15131533 auto MutableGroupSizeDesc =
@@ -1524,16 +1544,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
15241544 DEBUG_LOG (MutableGroupSizeDesc->groupSizeZ );
15251545
15261546 NextDesc = MutableGroupSizeDesc.get ();
1527- GroupSizeDescs .push_back (std::move (MutableGroupSizeDesc));
1547+ Descs .push_back (std::move (MutableGroupSizeDesc));
15281548 }
15291549 }
15301550
1531- UR_ASSERT (
1532- (!CommandDesc->numNewMemObjArgs && !CommandDesc->numNewPointerArgs &&
1533- !CommandDesc->numNewValueArgs ) ||
1534- (SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_ARGUMENTS),
1535- UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
1536-
15371551 // Check if new memory object arguments are provided.
15381552 for (uint32_t NewMemObjArgNum = CommandDesc->numNewMemObjArgs ;
15391553 NewMemObjArgNum-- > 0 ;) {
@@ -1557,6 +1571,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
15571571 return UR_RESULT_ERROR_INVALID_ARGUMENT;
15581572 }
15591573 }
1574+
15601575 ur_mem_handle_t NewMemObjArg = NewMemObjArgDesc.hNewMemObjArg ;
15611576 // The NewMemObjArg may be a NULL pointer in which case a NULL value is used
15621577 // for the kernel argument declared as a pointer to global or constant
@@ -1566,6 +1581,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
15661581 UR_CALL (NewMemObjArg->getZeHandlePtr (ZeHandlePtr, UrAccessMode,
15671582 CommandBuffer->Device ));
15681583 }
1584+
15691585 auto ZeMutableArgDesc =
15701586 std::make_unique<ZeStruct<ze_mutable_kernel_argument_exp_desc_t >>();
15711587 ZeMutableArgDesc->commandId = Command->CommandId ;
@@ -1580,14 +1596,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
15801596 DEBUG_LOG (ZeMutableArgDesc->pArgValue );
15811597
15821598 NextDesc = ZeMutableArgDesc.get ();
1583- ArgDescs .push_back (std::move (ZeMutableArgDesc));
1599+ Descs .push_back (std::move (ZeMutableArgDesc));
15841600 }
15851601
15861602 // Check if there are new pointer arguments.
15871603 for (uint32_t NewPointerArgNum = CommandDesc->numNewPointerArgs ;
15881604 NewPointerArgNum-- > 0 ;) {
15891605 ur_exp_command_buffer_update_pointer_arg_desc_t NewPointerArgDesc =
15901606 CommandDesc->pNewPointerArgList [NewPointerArgNum];
1607+
15911608 auto ZeMutableArgDesc =
15921609 std::make_unique<ZeStruct<ze_mutable_kernel_argument_exp_desc_t >>();
15931610 ZeMutableArgDesc->commandId = Command->CommandId ;
@@ -1602,14 +1619,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
16021619 DEBUG_LOG (ZeMutableArgDesc->pArgValue );
16031620
16041621 NextDesc = ZeMutableArgDesc.get ();
1605- ArgDescs .push_back (std::move (ZeMutableArgDesc));
1622+ Descs .push_back (std::move (ZeMutableArgDesc));
16061623 }
16071624
16081625 // Check if there are new value arguments.
16091626 for (uint32_t NewValueArgNum = CommandDesc->numNewValueArgs ;
16101627 NewValueArgNum-- > 0 ;) {
16111628 ur_exp_command_buffer_update_value_arg_desc_t NewValueArgDesc =
16121629 CommandDesc->pNewValueArgList [NewValueArgNum];
1630+
16131631 auto ZeMutableArgDesc =
16141632 std::make_unique<ZeStruct<ze_mutable_kernel_argument_exp_desc_t >>();
16151633 ZeMutableArgDesc->commandId = Command->CommandId ;
@@ -1634,26 +1652,52 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
16341652 }
16351653 ZeMutableArgDesc->pArgValue = ArgValuePtr;
16361654 DEBUG_LOG (ZeMutableArgDesc->pArgValue );
1655+
16371656 NextDesc = ZeMutableArgDesc.get ();
1638- ArgDescs .push_back (std::move (ZeMutableArgDesc));
1657+ Descs .push_back (std::move (ZeMutableArgDesc));
16391658 }
16401659
16411660 ZeStruct<ze_mutable_commands_exp_desc_t > MutableCommandDesc;
16421661 MutableCommandDesc.pNext = NextDesc;
16431662 MutableCommandDesc.flags = 0 ;
16441663
1645- // We must synchronize mutable command list execution before mutating.
1646- if (ze_fence_handle_t &ZeFence = CommandBuffer->ZeActiveFence ) {
1647- ZE2UR_CALL (zeFenceHostSynchronize, (ZeFence, UINT64_MAX));
1648- }
1649-
16501664 auto Plt = CommandBuffer->Context ->getPlatform ();
16511665 UR_ASSERT (Plt->ZeMutableCmdListExt .Supported ,
16521666 UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
16531667 ZE2UR_CALL (
16541668 Plt->ZeMutableCmdListExt .zexCommandListUpdateMutableCommandsExp ,
16551669 (CommandBuffer->ZeComputeCommandListTranslated , &MutableCommandDesc));
1656- ZE2UR_CALL (zeCommandListClose, (CommandBuffer->ZeComputeCommandList ));
1670+
1671+ return UR_RESULT_SUCCESS;
1672+ }
1673+
// Entry point for updating a kernel launch command recorded in a command
// buffer. The work is done in this order:
//   1. sanity-check the command's kernel handle and the requested work
//      dimension,
//   2. lock the command, its command buffer and its kernel,
//   3. check that the command buffer is both updatable and finalized,
//   4. validate the update descriptor via validateCommandDesc,
//   5. wait on the command buffer's active fence, if there is one,
//   6. apply the update via updateKernelCommand,
//   7. close the compute command list.
//
// Command     - handle of the command to update; its Kernel, Mutex and
//               CommandBuffer members are read here.
// CommandDesc - description of the requested update; only newWorkDim is read
//               directly in this function, the rest is forwarded to the
//               helpers.
//
// Returns UR_RESULT_SUCCESS when every step completes. UR_ASSERT, UR_CALL and
// ZE2UR_CALL are macros defined elsewhere; presumably each returns early with
// an error code on failure - confirm against their definitions.
//
// NOTE(review): Command and CommandDesc are dereferenced without a null check
// in this function - presumably they are validated by the caller; confirm.
1674+ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp (
1675+ ur_exp_command_buffer_command_handle_t Command,
1676+ const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDesc) {
// These two checks run before any lock is taken.
1677+ UR_ASSERT (Command->Kernel , UR_RESULT_ERROR_INVALID_NULL_HANDLE);
1678+ UR_ASSERT (CommandDesc->newWorkDim <= 3 ,
1679+ UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
1680+
1681+ // Lock command, kernel and command buffer for update.
// A single std::scoped_lock acquires all three mutexes together and holds them
// until the function returns.
1682+ std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Guard (
1683+ Command->Mutex , Command->CommandBuffer ->Mutex , Command->Kernel ->Mutex );
1684+
// The command-buffer state flags are read only after the locks are held.
1685+ UR_ASSERT (Command->CommandBuffer ->IsUpdatable ,
1686+ UR_RESULT_ERROR_INVALID_OPERATION);
1687+ UR_ASSERT (Command->CommandBuffer ->IsFinalized ,
1688+ UR_RESULT_ERROR_INVALID_OPERATION);
1689+
// Descriptor validation happens before the fence wait below and before
// updateKernelCommand is called.
1690+ UR_CALL (validateCommandDesc (Command, CommandDesc));
1691+
1692+ // We must synchronize mutable command list execution before mutating.
// The wait is skipped when ZeActiveFence is null; otherwise it uses a timeout
// of UINT64_MAX.
1693+ if (ze_fence_handle_t &ZeFence = Command->CommandBuffer ->ZeActiveFence ) {
1694+ ZE2UR_CALL (zeFenceHostSynchronize, (ZeFence, UINT64_MAX));
1695+ }
1696+
1697+ UR_CALL (updateKernelCommand (Command, CommandDesc));
1698+
// NOTE(review): if UR_CALL returns early when updateKernelCommand fails, this
// close is not reached - confirm that is the intended behaviour.
1699+ ZE2UR_CALL (zeCommandListClose,
1700+ (Command->CommandBuffer ->ZeComputeCommandList ));
16571701
16581702 return UR_RESULT_SUCCESS;
16591703}
0 commit comments