1515#include " logger/ur_logger.hpp"
1616#include " queue_handle.hpp"
1717
18+ ur_result_t ur_execution_event_handle_t::assign (ur_event_handle_t hNewEvent) {
19+ assert (hNewEvent);
20+ assert (hNewEvent->getURQueueHandle ());
21+
22+ auto newQueue = hNewEvent->getURQueueHandle ();
23+ auto currentQueue = hEvent ? hEvent->getURQueueHandle () : nullptr ;
24+
25+ if (hEvent) {
26+ UR_CALL (hEvent->release ());
27+ }
28+
29+ hEvent = hNewEvent;
30+
31+ if (newQueue != currentQueue) {
32+ if (currentQueue)
33+ UR_CALL (currentQueue->queueRelease ());
34+ UR_CALL (newQueue->queueRetain ());
35+ }
36+
37+ return UR_RESULT_SUCCESS;
38+ }
39+
40+ ur_event_handle_t ur_execution_event_handle_t::get () { return hEvent; }
41+
42+ ur_result_t ur_execution_event_handle_t::release () {
43+ if (hEvent) {
44+ assert (hEvent->getURQueueHandle ());
45+
46+ auto hQueue = hEvent->getURQueueHandle ();
47+ UR_CALL_NOCHECK (hEvent->release ());
48+ UR_CALL_NOCHECK (hQueue->queueRelease ());
49+
50+ hEvent = nullptr ;
51+ }
52+ return UR_RESULT_SUCCESS;
53+ }
54+
55+ ur_execution_event_handle_t ::~ur_execution_event_handle_t () { release (); }
56+
1857namespace {
1958
2059ur_result_t getZeKernelWrapped (ur_kernel_handle_t kernel,
@@ -69,13 +108,12 @@ ur_exp_command_buffer_handle_t_::ur_exp_command_buffer_handle_t_(
69108 : eventPool(context->getEventPoolCache (PoolCacheType::Regular)
70109 .borrow(device->Id.value(),
71110 isInOrder ? v2::EVENT_FLAGS_COUNTER : 0)),
72- context(context), device(device),
111+ context(context), device(device), currentExecution( nullptr ),
73112 isUpdatable(desc ? desc->isUpdatable : false ),
74113 isInOrder(desc ? desc->isInOrder : false ),
75114 commandListManager(
76115 context, device,
77- std::forward<v2::raii::command_list_unique_handle>(commandList))
78- {}
116+ std::forward<v2::raii::command_list_unique_handle>(commandList)) {}
79117
80118ur_exp_command_buffer_sync_point_t
81119ur_exp_command_buffer_handle_t_::getSyncPoint (ur_event_handle_t event) {
@@ -146,25 +184,16 @@ ur_result_t ur_exp_command_buffer_handle_t_::finalizeCommandBuffer() {
146184 return UR_RESULT_SUCCESS;
147185}
148186ur_event_handle_t ur_exp_command_buffer_handle_t_::getExecutionEventUnlocked () {
149- return currentExecution;
187+ return currentExecution. get () ;
150188}
151189
152190ur_result_t ur_exp_command_buffer_handle_t_::registerExecutionEventUnlocked (
153191 ur_event_handle_t nextExecutionEvent) {
154- if (currentExecution) {
155- UR_CALL (currentExecution->release ());
156- currentExecution = nullptr ;
157- }
158- if (nextExecutionEvent) {
159- currentExecution = nextExecutionEvent;
160- }
192+ UR_CALL (currentExecution.assign (nextExecutionEvent));
161193 return UR_RESULT_SUCCESS;
162194}
163195
164196ur_exp_command_buffer_handle_t_::~ur_exp_command_buffer_handle_t_ () {
165- if (currentExecution) {
166- currentExecution->release ();
167- }
168197 for (auto &event : syncPoints) {
169198 event->release ();
170199 }
@@ -181,14 +210,13 @@ ur_result_t ur_exp_command_buffer_handle_t_::applyUpdateCommands(
181210 this , device, context->getPlatform ()->ZeDriverGlobalOffsetExtensionFound ,
182211 numUpdateCommands, updateCommands));
183212
184- if (currentExecution) {
213+ if (currentExecution. get () ) {
185214 // TODO: Move synchronization to command buffer enqueue
186215 // it would require to remember the update commands and perform update
187216 // before appending to the queue
188217 ZE2UR_CALL (zeEventHostSynchronize,
189- (currentExecution->getZeEvent (), UINT64_MAX));
190- currentExecution->release ();
191- currentExecution = nullptr ;
218+ (currentExecution.get ()->getZeEvent (), UINT64_MAX));
219+ UR_CALL (currentExecution.release ());
192220 }
193221
194222 device_ptr_storage_t zeHandles;
0 commit comments