Skip to content

Commit 4ac7074

Browse files
authored
Fix the ordering of the external stream. (#23598)
There was an issue with the previous fix that we would not retain the data properly. Signed-off-by: Andrew Woloszyn <andrew.woloszyn@gmail.com>
1 parent eb76100 commit 4ac7074

File tree

1 file changed

+30
-11
lines changed

1 file changed

+30
-11
lines changed

runtime/src/iree/hal/drivers/hip/hip_device.c

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1566,10 +1566,10 @@ static iree_status_t iree_hal_hip_device_perform_buffer_operation_now(
15661566
}
15671567
}
15681568
IREE_TRACE_ZONE_END(z3);
1569-
1569+
iree_hal_hip_dispatch_completed_data_t* external_stream_dispatch_data = NULL;
15701570
if (device->uses_external_stream) {
1571-
iree_hal_hip_set_external_stream_data_completed(
1572-
data->base.external_stream_dispatch_data);
1571+
external_stream_dispatch_data = data->base.external_stream_dispatch_data;
1572+
iree_hal_resource_retain(external_stream_dispatch_data);
15731573
}
15741574

15751575
const iree_hal_hip_dynamic_symbols_t* symbols = device->hip_symbols;
@@ -1589,6 +1589,12 @@ static iree_status_t iree_hal_hip_device_perform_buffer_operation_now(
15891589
iree_hal_hip_device_destroy_buffer_callback_data(data);
15901590
}
15911591

1592+
if (external_stream_dispatch_data) {
1593+
iree_hal_hip_set_external_stream_data_completed(
1594+
external_stream_dispatch_data);
1595+
iree_hal_resource_release(external_stream_dispatch_data);
1596+
}
1597+
15921598
IREE_TRACE_ZONE_END(z0);
15931599
return iree_status_join(
15941600
status, IREE_HIP_CALL_TO_STATUS(symbols, hipCtxPopCurrent(NULL)));
@@ -2061,6 +2067,11 @@ static iree_status_t iree_hal_hip_device_perform_queue_read_now(
20612067
device, device->devices[device_ordinal].hip_dispatch_stream,
20622068
data->base.wait_semaphore_list, device_ordinal);
20632069
}
2070+
iree_hal_hip_dispatch_completed_data_t* external_stream_dispatch_data = NULL;
2071+
if (device->uses_external_stream) {
2072+
external_stream_dispatch_data = data->base.external_stream_dispatch_data;
2073+
iree_hal_resource_retain(external_stream_dispatch_data);
2074+
}
20642075

20652076
const iree_hal_hip_dynamic_symbols_t* symbols = device->hip_symbols;
20662077
iree_device_size_t amount_left = data->length;
@@ -2152,20 +2163,22 @@ static iree_status_t iree_hal_hip_device_perform_queue_read_now(
21522163
}
21532164
}
21542165

2155-
if (device->uses_external_stream) {
2156-
iree_hal_hip_set_external_stream_data_completed(
2157-
data->base.external_stream_dispatch_data);
2158-
}
2159-
21602166
if (!iree_status_is_ok(status)) {
21612167
for (iree_host_size_t i = 0; i < data->base.signal_semaphore_list.count;
21622168
++i) {
21632169
iree_hal_semaphore_fail(data->base.signal_semaphore_list.semaphores[i],
21642170
iree_status_clone(status));
21652171
}
2172+
21662173
iree_hal_hip_device_destroy_queue_read_callback_data(data);
21672174
}
21682175

2176+
if (external_stream_dispatch_data) {
2177+
iree_hal_hip_set_external_stream_data_completed(
2178+
external_stream_dispatch_data);
2179+
iree_hal_resource_release(external_stream_dispatch_data);
2180+
}
2181+
21692182
IREE_TRACE_ZONE_END(z0);
21702183
return iree_status_join(
21712184
status, IREE_HIP_CALL_TO_STATUS(symbols, hipCtxPopCurrent(NULL)));
@@ -2494,10 +2507,10 @@ static iree_status_t iree_hal_hip_device_execute_now(void* user_data,
24942507
}
24952508

24962509
IREE_TRACE_ZONE_END(z1);
2497-
2510+
iree_hal_hip_dispatch_completed_data_t* external_stream_dispatch_data = NULL;
24982511
if (device->uses_external_stream) {
2499-
iree_hal_hip_set_external_stream_data_completed(
2500-
data->base.external_stream_dispatch_data);
2512+
external_stream_dispatch_data = data->base.external_stream_dispatch_data;
2513+
iree_hal_resource_retain(external_stream_dispatch_data);
25012514
}
25022515

25032516
// Store symbols, because the cleanup may trigger off-thread
@@ -2520,6 +2533,12 @@ static iree_status_t iree_hal_hip_device_execute_now(void* user_data,
25202533
iree_hal_hip_device_destroy_callback_data(data);
25212534
}
25222535

2536+
if (external_stream_dispatch_data) {
2537+
iree_hal_hip_set_external_stream_data_completed(
2538+
external_stream_dispatch_data);
2539+
iree_hal_resource_release(external_stream_dispatch_data);
2540+
}
2541+
25232542
IREE_TRACE_ZONE_END(z0);
25242543
return iree_status_join(
25252544
status, IREE_HIP_CALL_TO_STATUS(symbols, hipCtxPopCurrent(NULL)));

0 commit comments

Comments
 (0)