@@ -100,16 +100,12 @@ ur_result_t ur_exp_command_buffer_handle_t_::addWaitNodes(
100100 return Err;
101101}
102102
103- kernel_command_handle::kernel_command_handle (
104- ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel,
105- CUgraphNode Node, CUDA_KERNEL_NODE_PARAMS Params, uint32_t WorkDim,
103+ kernel_command_data::kernel_command_data (
104+ ur_kernel_handle_t Kernel, CUDA_KERNEL_NODE_PARAMS Params, uint32_t WorkDim,
106105 const size_t *GlobalWorkOffsetPtr, const size_t *GlobalWorkSizePtr,
107106 const size_t *LocalWorkSizePtr, uint32_t NumKernelAlternatives,
108- ur_kernel_handle_t *KernelAlternatives, CUgraphNode SignalNode,
109- const std::vector<CUgraphNode> &WaitNodes)
110- : ur_exp_command_buffer_command_handle_t_(CommandBuffer, Node, SignalNode,
111- WaitNodes),
112- Kernel(Kernel), Params(Params), WorkDim(WorkDim) {
107+ ur_kernel_handle_t *KernelAlternatives)
108+ : Kernel(Kernel), Params(Params), WorkDim(WorkDim) {
113109 const size_t CopySize = sizeof (size_t ) * WorkDim;
114110 std::memcpy (GlobalWorkOffset, GlobalWorkOffsetPtr, CopySize);
115111 std::memcpy (GlobalWorkSize, GlobalWorkSizePtr, CopySize);
@@ -191,8 +187,8 @@ static void setCopyParams(const void *SrcPtr, const CUmemorytype_enum SrcType,
191187}
192188
193189// Helper function for enqueuing memory fills. Templated on the CommandType
194- // enum class for the type of fill being created.
195- template <class T >
190+ // variant for the type of fill being created.
191+ template <CommandType CT >
196192static ur_result_t enqueueCommandBufferFillHelper (
197193 ur_exp_command_buffer_handle_t CommandBuffer, void *DstDevice,
198194 const CUmemorytype_enum DstType, const void *Pattern, size_t PatternSize,
@@ -331,8 +327,9 @@ static ur_result_t enqueueCommandBufferFillHelper(
331327
332328 std::vector<CUgraphNode> WaitNodes =
333329 NumEventsInWaitList ? std::move (DepsList) : std::vector<CUgraphNode>();
334- auto NewCommand = std::make_unique<T>(CommandBuffer, GraphNode, SignalNode,
335- WaitNodes, std::move (DecomposedNodes));
330+ auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_>(
331+ CT, CommandBuffer, GraphNode, SignalNode, WaitNodes,
332+ fill_command_data{std::move (DecomposedNodes)});
336333 if (RetCommand) {
337334 *RetCommand = NewCommand.get ();
338335 }
@@ -528,10 +525,17 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
528525
529526 std::vector<CUgraphNode> WaitNodes =
530527 numEventsInWaitList ? std::move (DepsList) : std::vector<CUgraphNode>();
531- auto NewCommand = std::make_unique<kernel_command_handle>(
532- hCommandBuffer, hKernel, GraphNode, NodeParams, workDim,
533- pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize,
534- numKernelAlternatives, phKernelAlternatives, SignalNode, WaitNodes);
528+ auto KernelData = kernel_command_data{hKernel,
529+ NodeParams,
530+ workDim,
531+ pGlobalWorkOffset,
532+ pGlobalWorkSize,
533+ pLocalWorkSize,
534+ numKernelAlternatives,
535+ phKernelAlternatives};
536+ auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_>(
537+ CommandType::Kernel, hCommandBuffer, GraphNode, SignalNode, WaitNodes,
538+ KernelData);
535539
536540 if (phCommand) {
537541 *phCommand = NewCommand.get ();
@@ -585,8 +589,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMMemcpyExp(
585589
586590 std::vector<CUgraphNode> WaitNodes =
587591 numEventsInWaitList ? std::move (DepsList) : std::vector<CUgraphNode>();
588- auto NewCommand = std::make_unique<usm_memcpy_command_handle >(
589- hCommandBuffer, GraphNode, SignalNode, WaitNodes);
592+ auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_ >(
593+ CommandType::USMMemcpy, hCommandBuffer, GraphNode, SignalNode, WaitNodes);
590594 if (phCommand) {
591595 *phCommand = NewCommand.get ();
592596 }
@@ -650,8 +654,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyExp(
650654
651655 std::vector<CUgraphNode> WaitNodes =
652656 numEventsInWaitList ? std::move (DepsList) : std::vector<CUgraphNode>();
653- auto NewCommand = std::make_unique<buffer_copy_command_handle>(
654- hCommandBuffer, GraphNode, SignalNode, WaitNodes);
657+ auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_>(
658+ CommandType::MemBufferCopy, hCommandBuffer, GraphNode, SignalNode,
659+ WaitNodes);
655660
656661 if (phCommand) {
657662 *phCommand = NewCommand.get ();
@@ -713,8 +718,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp(
713718
714719 std::vector<CUgraphNode> WaitNodes =
715720 numEventsInWaitList ? std::move (DepsList) : std::vector<CUgraphNode>();
716- auto NewCommand = std::make_unique<buffer_copy_rect_command_handle>(
717- hCommandBuffer, GraphNode, SignalNode, WaitNodes);
721+ auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_>(
722+ CommandType::MemBufferCopyRect, hCommandBuffer, GraphNode, SignalNode,
723+ WaitNodes);
718724
719725 if (phCommand) {
720726 *phCommand = NewCommand.get ();
@@ -772,8 +778,9 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteExp(
772778
773779 std::vector<CUgraphNode> WaitNodes =
774780 numEventsInWaitList ? std::move (DepsList) : std::vector<CUgraphNode>();
775- auto NewCommand = std::make_unique<buffer_write_command_handle>(
776- hCommandBuffer, GraphNode, SignalNode, WaitNodes);
781+ auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_>(
782+ CommandType::MemBufferWrite, hCommandBuffer, GraphNode, SignalNode,
783+ WaitNodes);
777784 if (phCommand) {
778785 *phCommand = NewCommand.get ();
779786 }
@@ -829,8 +836,9 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadExp(
829836
830837 std::vector<CUgraphNode> WaitNodes =
831838 numEventsInWaitList ? std::move (DepsList) : std::vector<CUgraphNode>();
832- auto NewCommand = std::make_unique<buffer_read_command_handle>(
833- hCommandBuffer, GraphNode, SignalNode, WaitNodes);
839+ auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_>(
840+ CommandType::MemBufferRead, hCommandBuffer, GraphNode, SignalNode,
841+ WaitNodes);
834842 if (phCommand) {
835843 *phCommand = NewCommand.get ();
836844 }
@@ -890,8 +898,9 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteRectExp(
890898
891899 std::vector<CUgraphNode> WaitNodes =
892900 numEventsInWaitList ? std::move (DepsList) : std::vector<CUgraphNode>();
893- auto NewCommand = std::make_unique<buffer_write_rect_command_handle>(
894- hCommandBuffer, GraphNode, SignalNode, WaitNodes);
901+ auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_>(
902+ CommandType::MemBufferWriteRect, hCommandBuffer, GraphNode, SignalNode,
903+ WaitNodes);
895904
896905 if (phCommand) {
897906 *phCommand = NewCommand.get ();
@@ -952,8 +961,9 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(
952961
953962 std::vector<CUgraphNode> WaitNodes =
954963 numEventsInWaitList ? std::move (DepsList) : std::vector<CUgraphNode>();
955- auto NewCommand = std::make_unique<buffer_read_rect_command_handle>(
956- hCommandBuffer, GraphNode, SignalNode, WaitNodes);
964+ auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_>(
965+ CommandType::MemBufferReadRect, hCommandBuffer, GraphNode, SignalNode,
966+ WaitNodes);
957967
958968 if (phCommand) {
959969 *phCommand = NewCommand.get ();
@@ -1006,8 +1016,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp(
10061016
10071017 std::vector<CUgraphNode> WaitNodes =
10081018 numEventsInWaitList ? std::move (DepsList) : std::vector<CUgraphNode>();
1009- auto NewCommand = std::make_unique<usm_prefetch_command_handle>(
1010- hCommandBuffer, GraphNode, SignalNode, WaitNodes);
1019+ auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_>(
1020+ CommandType::USMPrefetch, hCommandBuffer, GraphNode, SignalNode,
1021+ WaitNodes);
10111022
10121023 if (phCommand) {
10131024 *phCommand = NewCommand.get ();
@@ -1060,8 +1071,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
10601071
10611072 std::vector<CUgraphNode> WaitNodes =
10621073 numEventsInWaitList ? std::move (DepsList) : std::vector<CUgraphNode>();
1063- auto NewCommand = std::make_unique<usm_advise_command_handle >(
1064- hCommandBuffer, GraphNode, SignalNode, WaitNodes);
1074+ auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_ >(
1075+ CommandType::USMAdvise, hCommandBuffer, GraphNode, SignalNode, WaitNodes);
10651076
10661077 if (phCommand) {
10671078 *phCommand = NewCommand.get ();
@@ -1096,7 +1107,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp(
10961107 auto DstDevice = std::get<BufferMem>(hBuffer->Mem )
10971108 .getPtrWithOffset (hCommandBuffer->Device , offset);
10981109
1099- return enqueueCommandBufferFillHelper<buffer_fill_command_handle >(
1110+ return enqueueCommandBufferFillHelper<CommandType::MemBufferFill >(
11001111 hCommandBuffer, &DstDevice, CU_MEMORYTYPE_DEVICE, pPattern, patternSize,
11011112 size, numSyncPointsInWaitList, pSyncPointWaitList, numEventsInWaitList,
11021113 phEventWaitList, pSyncPoint, phEvent, phCommand);
@@ -1116,7 +1127,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMFillExp(
11161127 (patternSize > 0 ); // is a positive power of two
11171128
11181129 UR_ASSERT (PatternIsValid && PatternSizeIsValid, UR_RESULT_ERROR_INVALID_SIZE);
1119- return enqueueCommandBufferFillHelper<usm_fill_command_handle >(
1130+ return enqueueCommandBufferFillHelper<CommandType::USMFill >(
11201131 hCommandBuffer, pPtr, CU_MEMORYTYPE_UNIFIED, pPattern, patternSize, size,
11211132 numSyncPointsInWaitList, pSyncPointWaitList, numEventsInWaitList,
11221133 phEventWaitList, pSyncPoint, phEvent, phCommand);
@@ -1165,12 +1176,12 @@ ur_result_t
11651176validateCommandDesc (ur_exp_command_buffer_handle_t CommandBuffer,
11661177 const ur_exp_command_buffer_update_kernel_launch_desc_t
11671178 &UpdateCommandDesc) {
1168- if (UpdateCommandDesc.hCommand ->getCommandType () != CommandType::Kernel) {
1179+ if (UpdateCommandDesc.hCommand ->Type != CommandType::Kernel) {
11691180 return UR_RESULT_ERROR_INVALID_VALUE;
11701181 }
11711182
1172- auto Command =
1173- static_cast <kernel_command_handle *>(UpdateCommandDesc. hCommand );
1183+ auto * Command = UpdateCommandDesc. hCommand ;
1184+ auto &KernelData = std::get<kernel_command_data>(Command-> CommandData );
11741185 if (CommandBuffer != Command->CommandBuffer ) {
11751186 return UR_RESULT_ERROR_INVALID_COMMAND_BUFFER_COMMAND_HANDLE_EXP;
11761187 }
@@ -1180,14 +1191,14 @@ validateCommandDesc(ur_exp_command_buffer_handle_t CommandBuffer,
11801191 return UR_RESULT_ERROR_INVALID_OPERATION;
11811192 }
11821193
1183- if (UpdateCommandDesc.newWorkDim != Command-> WorkDim &&
1194+ if (UpdateCommandDesc.newWorkDim != KernelData. WorkDim &&
11841195 (!UpdateCommandDesc.pNewGlobalWorkOffset ||
11851196 !UpdateCommandDesc.pNewGlobalWorkSize )) {
11861197 return UR_RESULT_ERROR_INVALID_VALUE;
11871198 }
11881199
11891200 if (UpdateCommandDesc.hNewKernel &&
1190- !Command-> ValidKernelHandles .count (UpdateCommandDesc.hNewKernel )) {
1201+ !KernelData. ValidKernelHandles .count (UpdateCommandDesc.hNewKernel )) {
11911202 return UR_RESULT_ERROR_INVALID_VALUE;
11921203 }
11931204 return UR_RESULT_SUCCESS;
@@ -1202,9 +1213,9 @@ validateCommandDesc(ur_exp_command_buffer_handle_t CommandBuffer,
12021213ur_result_t
12031214updateKernelArguments (const ur_exp_command_buffer_update_kernel_launch_desc_t
12041215 &UpdateCommandDesc) {
1205- auto Command =
1206- static_cast <kernel_command_handle *>(UpdateCommandDesc. hCommand );
1207- ur_kernel_handle_t Kernel = Command-> Kernel ;
1216+ auto * Command = UpdateCommandDesc. hCommand ;
1217+ auto &KernelData = std::get<kernel_command_data>(Command-> CommandData );
1218+ ur_kernel_handle_t Kernel = KernelData. Kernel ;
12081219 ur_device_handle_t Device = Command->CommandBuffer ->Device ;
12091220
12101221 // Update pointer arguments to the kernel
@@ -1284,29 +1295,29 @@ updateKernelArguments(const ur_exp_command_buffer_update_kernel_launch_desc_t
12841295ur_result_t
12851296updateCommand (const ur_exp_command_buffer_update_kernel_launch_desc_t
12861297 &UpdateCommandDesc) {
1287- auto Command =
1288- static_cast <kernel_command_handle *>(UpdateCommandDesc. hCommand );
1298+ auto * Command = UpdateCommandDesc. hCommand ;
1299+ auto &KernelData = std::get<kernel_command_data>(Command-> CommandData );
12891300 if (UpdateCommandDesc.hNewKernel ) {
1290- Command-> Kernel = UpdateCommandDesc.hNewKernel ;
1301+ KernelData. Kernel = UpdateCommandDesc.hNewKernel ;
12911302 }
12921303
12931304 if (UpdateCommandDesc.newWorkDim ) {
1294- Command-> WorkDim = UpdateCommandDesc.newWorkDim ;
1305+ KernelData. WorkDim = UpdateCommandDesc.newWorkDim ;
12951306 }
12961307
12971308 if (UpdateCommandDesc.pNewGlobalWorkOffset ) {
1298- Command-> setGlobalOffset (UpdateCommandDesc.pNewGlobalWorkOffset );
1309+ KernelData. setGlobalOffset (UpdateCommandDesc.pNewGlobalWorkOffset );
12991310 }
13001311
13011312 if (UpdateCommandDesc.pNewGlobalWorkSize ) {
1302- Command-> setGlobalSize (UpdateCommandDesc.pNewGlobalWorkSize );
1313+ KernelData. setGlobalSize (UpdateCommandDesc.pNewGlobalWorkSize );
13031314 if (!UpdateCommandDesc.pNewLocalWorkSize ) {
1304- Command-> setNullLocalSize ();
1315+ KernelData. setNullLocalSize ();
13051316 }
13061317 }
13071318
13081319 if (UpdateCommandDesc.pNewLocalWorkSize ) {
1309- Command-> setLocalSize (UpdateCommandDesc.pNewLocalWorkSize );
1320+ KernelData. setLocalSize (UpdateCommandDesc.pNewLocalWorkSize );
13101321 }
13111322
13121323 return UR_RESULT_SUCCESS;
@@ -1334,27 +1345,27 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
13341345
13351346 // If no work-size is provided make sure we pass nullptr to setKernelParams
13361347 // so it can guess the local work size.
1337- auto KernelCommandHandle =
1338- static_cast <kernel_command_handle *>(UpdateCommandDesc.hCommand );
1339- const bool ProvidedLocalSize = !KernelCommandHandle->isNullLocalSize ();
1348+ auto *KernelCommandHandle = UpdateCommandDesc.hCommand ;
1349+ auto &KernelData =
1350+ std::get<kernel_command_data>(KernelCommandHandle->CommandData );
1351+ const bool ProvidedLocalSize = !KernelData.isNullLocalSize ();
13401352 size_t *LocalWorkSize =
1341- ProvidedLocalSize ? KernelCommandHandle-> LocalWorkSize : nullptr ;
1353+ ProvidedLocalSize ? KernelData. LocalWorkSize : nullptr ;
13421354
13431355 // Set the number of threads per block to the number of threads per warp
13441356 // by default unless user has provided a better number.
13451357 size_t ThreadsPerBlock[3 ] = {32u , 1u , 1u };
13461358 size_t BlocksPerGrid[3 ] = {1u , 1u , 1u };
1347- CUfunction CuFunc = KernelCommandHandle-> Kernel ->get ();
1359+ CUfunction CuFunc = KernelData. Kernel ->get ();
13481360 auto Result = setKernelParams (
1349- hCommandBuffer->Context , hCommandBuffer->Device ,
1350- KernelCommandHandle->WorkDim , KernelCommandHandle->GlobalWorkOffset ,
1351- KernelCommandHandle->GlobalWorkSize , LocalWorkSize,
1352- KernelCommandHandle->Kernel , CuFunc, ThreadsPerBlock, BlocksPerGrid);
1361+ hCommandBuffer->Context , hCommandBuffer->Device , KernelData.WorkDim ,
1362+ KernelData.GlobalWorkOffset , KernelData.GlobalWorkSize , LocalWorkSize,
1363+ KernelData.Kernel , CuFunc, ThreadsPerBlock, BlocksPerGrid);
13531364 if (Result != UR_RESULT_SUCCESS) {
13541365 return Result;
13551366 }
13561367
1357- CUDA_KERNEL_NODE_PARAMS &Params = KernelCommandHandle-> Params ;
1368+ CUDA_KERNEL_NODE_PARAMS &Params = KernelData. Params ;
13581369
13591370 Params.func = CuFunc;
13601371 Params.gridDimX = BlocksPerGrid[0 ];
@@ -1363,9 +1374,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
13631374 Params.blockDimX = ThreadsPerBlock[0 ];
13641375 Params.blockDimY = ThreadsPerBlock[1 ];
13651376 Params.blockDimZ = ThreadsPerBlock[2 ];
1366- Params.sharedMemBytes = KernelCommandHandle-> Kernel ->getLocalSize ();
1367- Params.kernelParams = const_cast < void **>(
1368- KernelCommandHandle-> Kernel ->getArgPointers ().data ());
1377+ Params.sharedMemBytes = KernelData. Kernel ->getLocalSize ();
1378+ Params.kernelParams =
1379+ const_cast < void **>(KernelData. Kernel ->getArgPointers ().data ());
13691380
13701381 CUgraphNode Node = KernelCommandHandle->Node ;
13711382 CUgraphExec CudaGraphExec = hCommandBuffer->CudaGraphExec ;
0 commit comments