1010#include " command_buffer.hpp"
1111#include " helpers/kernel_helpers.hpp"
1212#include " logger/ur_logger.hpp"
13+ #include " ur_api.h"
1314#include " ur_interface_loader.hpp"
1415#include " ur_level_zero.hpp"
1516
@@ -170,6 +171,67 @@ ur_result_t getEventsFromSyncPoints(
170171 return UR_RESULT_SUCCESS;
171172}
172173
174+ /* *
175+ * If necessary, it creates a signal event and appends it to the previous
176+ * command list (copy or compute), to indicate when it's finished executing.
177+ * @param[in] CommandBuffer The CommandBuffer where the command is appended.
178+ * @param[in] ZeCommandList the CommandList that's currently in use.
179+ * @param[out] WaitEventList The list of event for the future command list to
180+ * wait on before execution.
181+ * @return UR_RESULT_SUCCESS or an error code on failure
182+ */
183+ ur_result_t createSyncPointBetweenCopyAndCompute (
184+ ur_exp_command_buffer_handle_t CommandBuffer,
185+ ze_command_list_handle_t ZeCommandList,
186+ std::vector<ze_event_handle_t > &WaitEventList) {
187+
188+ if (!CommandBuffer->ZeCopyCommandList ) {
189+ return UR_RESULT_SUCCESS;
190+ }
191+
192+ bool IsCopy{ZeCommandList == CommandBuffer->ZeCopyCommandList };
193+
194+ // Skip synchronization for the first node in a graph or if the current
195+ // command list matches the previous one.
196+ if (!CommandBuffer->MWasPrevCopyCommandList .has_value ()) {
197+ CommandBuffer->MWasPrevCopyCommandList = IsCopy;
198+ return UR_RESULT_SUCCESS;
199+ } else if (IsCopy == CommandBuffer->MWasPrevCopyCommandList ) {
200+ return UR_RESULT_SUCCESS;
201+ }
202+
203+ /*
204+ * If the current CommandList differs from the previously used one, we must
205+ * append a signal event to the previous CommandList to track when
206+ * its execution is complete.
207+ */
208+ ur_event_handle_t SignalPrevCommandEvent = nullptr ;
209+ UR_CALL (EventCreate (CommandBuffer->Context , nullptr /* Queue*/ ,
210+ false /* IsMultiDevice*/ , false , &SignalPrevCommandEvent,
211+ false /* CounterBasedEventEnabled*/ ,
212+ !CommandBuffer->IsProfilingEnabled ,
213+ false /* InterruptBasedEventEnabled*/ ));
214+
215+ // Determine which command list to signal.
216+ auto CommandListToSignal = (!IsCopy && CommandBuffer->MWasPrevCopyCommandList )
217+ ? CommandBuffer->ZeCopyCommandList
218+ : CommandBuffer->ZeComputeCommandList ;
219+ CommandBuffer->MWasPrevCopyCommandList = IsCopy;
220+
221+ ZE2UR_CALL (zeCommandListAppendSignalEvent,
222+ (CommandListToSignal, SignalPrevCommandEvent->ZeEvent ));
223+
224+ // Add the event to the dependencies for future command list to wait on.
225+ WaitEventList.push_back (SignalPrevCommandEvent->ZeEvent );
226+
227+ // Get sync point and register the event with it.
228+ ur_exp_command_buffer_sync_point_t SyncPoint =
229+ CommandBuffer->getNextSyncPoint ();
230+ CommandBuffer->registerSyncPoint (SyncPoint, SignalPrevCommandEvent);
231+
232+ return UR_RESULT_SUCCESS;
233+ }
234+
173235/* *
174236 * If needed, creates a sync point for a given command and returns the L0
175237 * events associated with the sync point.
@@ -190,7 +252,7 @@ ur_result_t getEventsFromSyncPoints(
190252 */
191253ur_result_t createSyncPointAndGetZeEvents (
192254 ur_command_t CommandType, ur_exp_command_buffer_handle_t CommandBuffer,
193- uint32_t NumSyncPointsInWaitList,
255+ ze_command_list_handle_t ZeCommandList, uint32_t NumSyncPointsInWaitList,
194256 const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
195257 bool HostVisible, ur_exp_command_buffer_sync_point_t *RetSyncPoint,
196258 std::vector<ze_event_handle_t > &ZeEventList,
@@ -199,6 +261,11 @@ ur_result_t createSyncPointAndGetZeEvents(
199261 ZeLaunchEvent = nullptr ;
200262
201263 if (CommandBuffer->IsInOrderCmdList ) {
264+ UR_CALL (createSyncPointBetweenCopyAndCompute (CommandBuffer, ZeCommandList,
265+ ZeEventList));
266+ if (!ZeEventList.empty ()) {
267+ NumSyncPointsInWaitList = ZeEventList.size ();
268+ }
202269 return UR_RESULT_SUCCESS;
203270 }
204271
@@ -225,24 +292,24 @@ ur_result_t createSyncPointAndGetZeEvents(
225292 return UR_RESULT_SUCCESS;
226293}
227294
228- // Shared by all memory read/write/copy PI interfaces.
229- // Helper function for common code when enqueuing memory operations to a command
230- // buffer.
295+ // Shared by all memory read/write/copy UR interfaces.
296+ // Helper function for common code when enqueuing memory operations to a
297+ // command buffer.
231298ur_result_t enqueueCommandBufferMemCopyHelper (
232299 ur_command_t CommandType, ur_exp_command_buffer_handle_t CommandBuffer,
233300 void *Dst, const void *Src, size_t Size, bool PreferCopyEngine,
234301 uint32_t NumSyncPointsInWaitList,
235302 const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
236303 ur_exp_command_buffer_sync_point_t *RetSyncPoint) {
237304
305+ ze_command_list_handle_t ZeCommandList =
306+ CommandBuffer->chooseCommandList (PreferCopyEngine);
307+
238308 std::vector<ze_event_handle_t > ZeEventList;
239309 ze_event_handle_t ZeLaunchEvent = nullptr ;
240310 UR_CALL (createSyncPointAndGetZeEvents (
241- CommandType, CommandBuffer, NumSyncPointsInWaitList, SyncPointWaitList,
242- false , RetSyncPoint, ZeEventList, ZeLaunchEvent));
243-
244- ze_command_list_handle_t ZeCommandList =
245- CommandBuffer->chooseCommandList (PreferCopyEngine);
311+ CommandType, CommandBuffer, ZeCommandList, NumSyncPointsInWaitList,
312+ SyncPointWaitList, false , RetSyncPoint, ZeEventList, ZeLaunchEvent));
246313
247314 ZE2UR_CALL (zeCommandListAppendMemoryCopy,
248315 (ZeCommandList, Dst, Src, Size, ZeLaunchEvent, ZeEventList.size (),
@@ -293,14 +360,14 @@ ur_result_t enqueueCommandBufferMemCopyRectHelper(
293360 const ze_copy_region_t ZeDstRegion = {DstOriginX, DstOriginY, DstOriginZ,
294361 Width, Height, Depth};
295362
363+ ze_command_list_handle_t ZeCommandList =
364+ CommandBuffer->chooseCommandList (PreferCopyEngine);
365+
296366 std::vector<ze_event_handle_t > ZeEventList;
297367 ze_event_handle_t ZeLaunchEvent = nullptr ;
298368 UR_CALL (createSyncPointAndGetZeEvents (
299- CommandType, CommandBuffer, NumSyncPointsInWaitList, SyncPointWaitList,
300- false , RetSyncPoint, ZeEventList, ZeLaunchEvent));
301-
302- ze_command_list_handle_t ZeCommandList =
303- CommandBuffer->chooseCommandList (PreferCopyEngine);
369+ CommandType, CommandBuffer, ZeCommandList, NumSyncPointsInWaitList,
370+ SyncPointWaitList, false , RetSyncPoint, ZeEventList, ZeLaunchEvent));
304371
305372 ZE2UR_CALL (zeCommandListAppendMemoryCopyRegion,
306373 (ZeCommandList, Dst, &ZeDstRegion, DstPitch, DstSlicePitch, Src,
@@ -321,19 +388,19 @@ ur_result_t enqueueCommandBufferFillHelper(
321388 UR_ASSERT ((PatternSize > 0 ) && ((PatternSize & (PatternSize - 1 )) == 0 ),
322389 UR_RESULT_ERROR_INVALID_VALUE);
323390
324- std::vector<ze_event_handle_t > ZeEventList;
325- ze_event_handle_t ZeLaunchEvent = nullptr ;
326- UR_CALL (createSyncPointAndGetZeEvents (
327- CommandType, CommandBuffer, NumSyncPointsInWaitList, SyncPointWaitList,
328- true , RetSyncPoint, ZeEventList, ZeLaunchEvent));
329-
330391 bool PreferCopyEngine;
331392 UR_CALL (
332393 preferCopyEngineForFill (CommandBuffer, PatternSize, PreferCopyEngine));
333394
334395 ze_command_list_handle_t ZeCommandList =
335396 CommandBuffer->chooseCommandList (PreferCopyEngine);
336397
398+ std::vector<ze_event_handle_t > ZeEventList;
399+ ze_event_handle_t ZeLaunchEvent = nullptr ;
400+ UR_CALL (createSyncPointAndGetZeEvents (
401+ CommandType, CommandBuffer, ZeCommandList, NumSyncPointsInWaitList,
402+ SyncPointWaitList, true , RetSyncPoint, ZeEventList, ZeLaunchEvent));
403+
337404 ZE2UR_CALL (zeCommandListAppendMemoryFill,
338405 (ZeCommandList, Ptr, Pattern, PatternSize, Size, ZeLaunchEvent,
339406 ZeEventList.size (), getPointerFromVector (ZeEventList)));
@@ -477,12 +544,12 @@ void ur_exp_command_buffer_handle_t_::registerSyncPoint(
477544
478545ze_command_list_handle_t
479546ur_exp_command_buffer_handle_t_::chooseCommandList (bool PreferCopyEngine) {
480- if (PreferCopyEngine && this -> useCopyEngine () && ! this -> IsInOrderCmdList ) {
547+ if (PreferCopyEngine && useCopyEngine ()) {
481548 // We indicate that ZeCopyCommandList contains commands to be submitted.
482- this -> MCopyCommandListEmpty = false ;
483- return this -> ZeCopyCommandList ;
549+ MCopyCommandListEmpty = false ;
550+ return ZeCopyCommandList;
484551 }
485- return this -> ZeComputeCommandList ;
552+ return ZeComputeCommandList;
486553}
487554
488555ur_result_t ur_exp_command_buffer_handle_t_::getFenceForQueue (
@@ -646,7 +713,7 @@ urCommandBufferCreateExp(ur_context_handle_t Context, ur_device_handle_t Device,
646713 // the current implementation only uses the main copy engine and does not use
647714 // the link engine even if available.
648715 if (Device->hasMainCopyEngine ()) {
649- UR_CALL (createMainCommandList (Context, Device, false , false , true ,
716+ UR_CALL (createMainCommandList (Context, Device, IsInOrder , false , true ,
650717 ZeCopyCommandList));
651718 }
652719
@@ -812,18 +879,18 @@ finalizeWaitEventPath(ur_exp_command_buffer_handle_t CommandBuffer) {
812879 (CommandBuffer->ZeCommandListResetEvents ,
813880 CommandBuffer->ExecutionFinishedEvent ->ZeEvent ));
814881
882+ // Reset the L0 events we use for command-buffer sync-points to the
883+ // non-signaled state. This is required for multiple submissions.
884+ for (auto &Event : CommandBuffer->ZeEventsList ) {
885+ ZE2UR_CALL (zeCommandListAppendEventReset,
886+ (CommandBuffer->ZeCommandListResetEvents , Event));
887+ }
888+
815889 if (CommandBuffer->IsInOrderCmdList ) {
816890 ZE2UR_CALL (zeCommandListAppendSignalEvent,
817891 (CommandBuffer->ZeComputeCommandList ,
818892 CommandBuffer->ExecutionFinishedEvent ->ZeEvent ));
819893 } else {
820- // Reset the L0 events we use for command-buffer sync-points to the
821- // non-signaled state. This is required for multiple submissions.
822- for (auto &Event : CommandBuffer->ZeEventsList ) {
823- ZE2UR_CALL (zeCommandListAppendEventReset,
824- (CommandBuffer->ZeCommandListResetEvents , Event));
825- }
826-
827894 // Wait for all the user added commands to complete, and signal the
828895 // command-buffer signal-event when they are done.
829896 ZE2UR_CALL (zeCommandListAppendBarrier,
@@ -1073,7 +1140,8 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
10731140 std::vector<ze_event_handle_t > ZeEventList;
10741141 ze_event_handle_t ZeLaunchEvent = nullptr ;
10751142 UR_CALL (createSyncPointAndGetZeEvents (
1076- UR_COMMAND_KERNEL_LAUNCH, CommandBuffer, NumSyncPointsInWaitList,
1143+ UR_COMMAND_KERNEL_LAUNCH, CommandBuffer,
1144+ CommandBuffer->ZeComputeCommandList , NumSyncPointsInWaitList,
10771145 SyncPointWaitList, false , RetSyncPoint, ZeEventList, ZeLaunchEvent));
10781146
10791147 ZE2UR_CALL (zeCommandListAppendLaunchKernel,
@@ -1306,29 +1374,25 @@ ur_result_t urCommandBufferAppendUSMPrefetchExp(
13061374 std::ignore = Command;
13071375 std::ignore = Flags;
13081376
1309- if (CommandBuffer->IsInOrderCmdList ) {
1310- // Add the prefetch command to the command-buffer.
1311- // Note that L0 does not handle migration flags.
1312- ZE2UR_CALL (zeCommandListAppendMemoryPrefetch,
1313- (CommandBuffer->ZeComputeCommandList , Mem, Size));
1314- } else {
1315- std::vector<ze_event_handle_t > ZeEventList;
1316- ze_event_handle_t ZeLaunchEvent = nullptr ;
1317- UR_CALL (createSyncPointAndGetZeEvents (
1318- UR_COMMAND_USM_PREFETCH, CommandBuffer, NumSyncPointsInWaitList,
1319- SyncPointWaitList, true , RetSyncPoint, ZeEventList, ZeLaunchEvent));
1320-
1321- if (NumSyncPointsInWaitList) {
1322- ZE2UR_CALL (zeCommandListAppendWaitOnEvents,
1323- (CommandBuffer->ZeComputeCommandList , NumSyncPointsInWaitList,
1324- ZeEventList.data ()));
1325- }
1377+ std::vector<ze_event_handle_t > ZeEventList;
1378+ ze_event_handle_t ZeLaunchEvent = nullptr ;
1379+ UR_CALL (createSyncPointAndGetZeEvents (
1380+ UR_COMMAND_USM_PREFETCH, CommandBuffer,
1381+ CommandBuffer->ZeComputeCommandList , NumSyncPointsInWaitList,
1382+ SyncPointWaitList, true , RetSyncPoint, ZeEventList, ZeLaunchEvent));
1383+
1384+ if (NumSyncPointsInWaitList) {
1385+ ZE2UR_CALL (zeCommandListAppendWaitOnEvents,
1386+ (CommandBuffer->ZeComputeCommandList , NumSyncPointsInWaitList,
1387+ ZeEventList.data ()));
1388+ }
13261389
1327- // Add the prefetch command to the command-buffer.
1328- // Note that L0 does not handle migration flags.
1329- ZE2UR_CALL (zeCommandListAppendMemoryPrefetch,
1330- (CommandBuffer->ZeComputeCommandList , Mem, Size));
1390+ // Add the prefetch command to the command-buffer.
1391+ // Note that L0 does not handle migration flags.
1392+ ZE2UR_CALL (zeCommandListAppendMemoryPrefetch,
1393+ (CommandBuffer->ZeComputeCommandList , Mem, Size));
13311394
1395+ if (!CommandBuffer->IsInOrderCmdList ) {
13321396 // Level Zero does not have a completion "event" with the prefetch API,
13331397 // so manually add command to signal our event.
13341398 ZE2UR_CALL (zeCommandListAppendSignalEvent,
@@ -1376,27 +1440,24 @@ ur_result_t urCommandBufferAppendUSMAdviseExp(
13761440
13771441 ze_memory_advice_t ZeAdvice = static_cast <ze_memory_advice_t >(Value);
13781442
1379- if (CommandBuffer->IsInOrderCmdList ) {
1380- ZE2UR_CALL (zeCommandListAppendMemAdvise,
1381- (CommandBuffer->ZeComputeCommandList ,
1382- CommandBuffer->Device ->ZeDevice , Mem, Size, ZeAdvice));
1383- } else {
1384- std::vector<ze_event_handle_t > ZeEventList;
1385- ze_event_handle_t ZeLaunchEvent = nullptr ;
1386- UR_CALL (createSyncPointAndGetZeEvents (
1387- UR_COMMAND_USM_ADVISE, CommandBuffer, NumSyncPointsInWaitList,
1388- SyncPointWaitList, true , RetSyncPoint, ZeEventList, ZeLaunchEvent));
1389-
1390- if (NumSyncPointsInWaitList) {
1391- ZE2UR_CALL (zeCommandListAppendWaitOnEvents,
1392- (CommandBuffer->ZeComputeCommandList , NumSyncPointsInWaitList,
1393- ZeEventList.data ()));
1394- }
1443+ std::vector<ze_event_handle_t > ZeEventList;
1444+ ze_event_handle_t ZeLaunchEvent = nullptr ;
1445+ UR_CALL (createSyncPointAndGetZeEvents (
1446+ UR_COMMAND_USM_ADVISE, CommandBuffer, CommandBuffer->ZeComputeCommandList ,
1447+ NumSyncPointsInWaitList, SyncPointWaitList, true , RetSyncPoint,
1448+ ZeEventList, ZeLaunchEvent));
13951449
1396- ZE2UR_CALL (zeCommandListAppendMemAdvise,
1397- (CommandBuffer->ZeComputeCommandList ,
1398- CommandBuffer->Device ->ZeDevice , Mem, Size, ZeAdvice));
1450+ if (NumSyncPointsInWaitList) {
1451+ ZE2UR_CALL (zeCommandListAppendWaitOnEvents,
1452+ (CommandBuffer->ZeComputeCommandList , NumSyncPointsInWaitList,
1453+ ZeEventList.data ()));
1454+ }
1455+
1456+ ZE2UR_CALL (zeCommandListAppendMemAdvise,
1457+ (CommandBuffer->ZeComputeCommandList ,
1458+ CommandBuffer->Device ->ZeDevice , Mem, Size, ZeAdvice));
13991459
1460+ if (!CommandBuffer->IsInOrderCmdList ) {
14001461 // Level Zero does not have a completion "event" with the advise API,
14011462 // so manually add command to signal our event.
14021463 ZE2UR_CALL (zeCommandListAppendSignalEvent,
0 commit comments