Skip to content

Commit e10231c

Browse files
authored
[hip] Set the current device before calls into Hip. (iree-org#19103)
This is a bit of a brute-force way to temporarily solve our main HIP multi-device problems until the more complete fix is in place. For the single-device case this has negligible performance implications, as `hipCtxSetCurrent` is a no-op in that case. For the multi-device case this could cause more significant performance problems if the user program swaps between devices within a thread. --------- Signed-off-by: Andrew Woloszyn <[email protected]>
1 parent 1a28f8d commit e10231c

18 files changed

+215
-39
lines changed

runtime/src/iree/hal/drivers/hip/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ iree_cc_library(
2121
"api.h"
2222
SRCS
2323
"api.h"
24+
"context_util.h"
2425
"event_pool.c"
2526
"event_pool.h"
2627
"event_semaphore.c"
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright 2024 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#ifndef IREE_HAL_DRIVERS_HIP_CONTEXT_UTIL_H_
8+
#define IREE_HAL_DRIVERS_HIP_CONTEXT_UTIL_H_
9+
10+
#include "iree/base/api.h"
11+
#include "iree/hal/drivers/hip/dynamic_symbols.h"
12+
#include "iree/hal/drivers/hip/status_util.h"
13+
14+
// Makes |hip_context| the current HIP context by calling hipCtxSetCurrent
// through the dynamically loaded symbol table |syms|.
//
// Returns OK immediately, without calling into HIP, when |hip_context| is
// NULL. Otherwise returns the result of hipCtxSetCurrent converted to an
// iree_status_t.
//
// NOTE(review): the IREE_TRACE block is presumably only compiled in tracing
// builds - confirm against iree/base/tracing.h. Within it the current context
// is queried first so that an actual context change gets its own named trace
// zone; when the context already matches, control falls through to the
// unconditional hipCtxSetCurrent call at the end of the function.
static inline iree_status_t iree_hal_hip_set_context(
15+
const iree_hal_hip_dynamic_symbols_t* syms, hipCtx_t hip_context) {
16+
if (!hip_context) {
17+
return iree_ok_status();
18+
}
19+
IREE_TRACE({
20+
hipCtx_t current_context = NULL;
21+
IREE_HIP_RETURN_IF_ERROR(syms, hipCtxGetCurrent(&current_context),
22+
"hipCtxGetCurrent");
23+
if (current_context != hip_context) {
24+
// Only a real switch is wrapped in the named zone below.
IREE_TRACE_ZONE_BEGIN_NAMED(z0, "iree_hal_hip_set_context_switch");
25+
iree_status_t status =
26+
IREE_HIP_RESULT_TO_STATUS(syms, hipCtxSetCurrent(hip_context));
27+
IREE_TRACE_ZONE_END(z0);
28+
return status;
29+
}
30+
});
31+
return IREE_HIP_RESULT_TO_STATUS(syms, hipCtxSetCurrent(hip_context));
32+
}
33+
34+
#endif // IREE_HAL_DRIVERS_HIP_CONTEXT_UTIL_H_

runtime/src/iree/hal/drivers/hip/dynamic_symbol_tables.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
// HIP symbols
99
//===----------------------------------------------------------------------===//
1010

11+
IREE_HAL_HIP_REQUIRED_PFN_DECL(hipCtxGetCurrent, hipCtx_t *)
1112
IREE_HAL_HIP_REQUIRED_PFN_DECL(hipCtxSetCurrent, hipCtx_t)
1213
IREE_HAL_HIP_REQUIRED_PFN_DECL(hipDeviceGet, hipDevice_t *, int)
1314
IREE_HAL_HIP_REQUIRED_PFN_DECL(hipDeviceGetAttribute, int *,

runtime/src/iree/hal/drivers/hip/event_pool.c

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "iree/base/internal/atomics.h"
1515
#include "iree/base/internal/synchronization.h"
1616
#include "iree/hal/api.h"
17+
#include "iree/hal/drivers/hip/context_util.h"
1718
#include "iree/hal/drivers/hip/dynamic_symbols.h"
1819
#include "iree/hal/drivers/hip/status_util.h"
1920

@@ -36,6 +37,10 @@ struct iree_hal_hip_event_t {
3637
// The event pool that owns this event. This cannot be NULL. We retain it to
3738
// make sure the event outlive the pool.
3839
iree_hal_hip_event_pool_t* pool;
40+
41+
// The context to use to free this event; it must be the same
42+
// context as was used when allocating the event.
43+
hipCtx_t hip_context;
3944
// The underlying hipEvent_t object.
4045
hipEvent_t hip_event;
4146
};
@@ -48,6 +53,8 @@ static inline void iree_hal_hip_event_destroy(iree_hal_hip_event_t* event) {
4853
iree_allocator_t host_allocator = event->host_allocator;
4954
const iree_hal_hip_dynamic_symbols_t* symbols = event->symbols;
5055
IREE_TRACE_ZONE_BEGIN(z0);
56+
IREE_IGNORE_ERROR(
57+
iree_hal_hip_set_context(event->symbols, event->hip_context));
5158

5259
IREE_ASSERT_REF_COUNT_ZERO(&event->ref_count);
5360
IREE_HIP_IGNORE_ERROR(symbols, hipEventDestroy(event->hip_event));
@@ -58,8 +65,8 @@ static inline void iree_hal_hip_event_destroy(iree_hal_hip_event_t* event) {
5865

5966
static inline iree_status_t iree_hal_hip_event_create(
6067
const iree_hal_hip_dynamic_symbols_t* symbols,
61-
iree_hal_hip_event_pool_t* pool, iree_allocator_t host_allocator,
62-
iree_hal_hip_event_t** out_event) {
68+
iree_hal_hip_event_pool_t* pool, hipCtx_t context,
69+
iree_allocator_t host_allocator, iree_hal_hip_event_t** out_event) {
6370
IREE_ASSERT_ARGUMENT(symbols);
6471
IREE_ASSERT_ARGUMENT(pool);
6572
IREE_ASSERT_ARGUMENT(out_event);
@@ -75,6 +82,7 @@ static inline iree_status_t iree_hal_hip_event_create(
7582
event->symbols = symbols;
7683
event->pool = pool;
7784
event->hip_event = NULL;
85+
event->hip_context = context;
7886

7987
iree_status_t status = IREE_HIP_RESULT_TO_STATUS(
8088
symbols,
@@ -122,6 +130,10 @@ struct iree_hal_hip_event_pool_t {
122130
// The symbols used to create and destroy hipEvent_t objects.
123131
const iree_hal_hip_dynamic_symbols_t* symbols;
124132

133+
// The context for this event pool to use to allocate
134+
// events.
135+
hipCtx_t hip_context;
136+
125137
// Guards event related fields in the pool. We don't expect a performant
126138
// program to frequently allocate events for synchronization purposes; the
127139
// traffic to this pool should be low. So it should be fine to use mutex to
@@ -142,7 +154,7 @@ struct iree_hal_hip_event_pool_t {
142154
static void iree_hal_hip_event_pool_free(iree_hal_hip_event_pool_t* event_pool);
143155

144156
iree_status_t iree_hal_hip_event_pool_allocate(
145-
const iree_hal_hip_dynamic_symbols_t* symbols,
157+
const iree_hal_hip_dynamic_symbols_t* symbols, hipCtx_t hip_context,
146158
iree_host_size_t available_capacity, iree_allocator_t host_allocator,
147159
iree_hal_hip_event_pool_t** out_event_pool) {
148160
IREE_ASSERT_ARGUMENT(symbols);
@@ -163,11 +175,12 @@ iree_status_t iree_hal_hip_event_pool_allocate(
163175
iree_slim_mutex_initialize(&event_pool->event_mutex);
164176
event_pool->available_capacity = available_capacity;
165177
event_pool->available_count = 0;
178+
event_pool->hip_context = hip_context;
166179

167180
iree_status_t status = iree_ok_status();
168181
for (iree_host_size_t i = 0; i < available_capacity; ++i) {
169182
status = iree_hal_hip_event_create(
170-
symbols, event_pool, host_allocator,
183+
symbols, event_pool, hip_context, host_allocator,
171184
&event_pool->available_list[event_pool->available_count++]);
172185
if (!iree_status_is_ok(status)) break;
173186
}
@@ -240,9 +253,9 @@ iree_status_t iree_hal_hip_event_pool_acquire(
240253
IREE_TRACE_ZONE_BEGIN_NAMED(z1, "event-pool-unpooled-acquire");
241254
iree_status_t status = iree_ok_status();
242255
for (iree_host_size_t i = 0; i < remaining_count; ++i) {
243-
status = iree_hal_hip_event_create(event_pool->symbols, event_pool,
244-
event_pool->host_allocator,
245-
&out_events[from_pool_count + i]);
256+
status = iree_hal_hip_event_create(
257+
event_pool->symbols, event_pool, event_pool->hip_context,
258+
event_pool->host_allocator, &out_events[from_pool_count + i]);
246259
if (!iree_status_is_ok(status)) {
247260
// Must release all events we've acquired so far.
248261
iree_hal_hip_event_pool_release_event(event_pool, from_pool_count + i,

runtime/src/iree/hal/drivers/hip/event_pool.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ typedef struct iree_hal_hip_event_pool_t iree_hal_hip_event_pool_t;
5252
// Extra events requested beyond the capability are directly created and
5353
// destroyed without pooling.
5454
iree_status_t iree_hal_hip_event_pool_allocate(
55-
const iree_hal_hip_dynamic_symbols_t* symbols,
55+
const iree_hal_hip_dynamic_symbols_t* symbols, hipCtx_t hip_context,
5656
iree_host_size_t available_capacity, iree_allocator_t host_allocator,
5757
iree_hal_hip_event_pool_t** out_event_pool);
5858

runtime/src/iree/hal/drivers/hip/event_semaphore.c

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
#include "iree/base/internal/synchronization.h"
1010
#include "iree/base/internal/wait_handle.h"
1111
#include "iree/base/status.h"
12+
#include "iree/hal/drivers/hip/context_util.h"
1213
#include "iree/hal/drivers/hip/dynamic_symbols.h"
14+
#include "iree/hal/drivers/hip/status_util.h"
1315
#include "iree/hal/drivers/hip/timepoint_pool.h"
1416
#include "iree/hal/utils/semaphore_base.h"
1517

@@ -30,6 +32,8 @@ typedef struct iree_hal_hip_semaphore_t {
3032
// new signaled values.
3133
iree_hal_deferred_work_queue_t* work_queue;
3234

35+
hipCtx_t hip_context;
36+
3337
// Guards value and status. We expect low contention on semaphores and since
3438
// iree_slim_mutex_t is (effectively) just a CAS this keeps things simpler
3539
// than trying to make the entire structure lock-free.
@@ -56,7 +60,7 @@ static iree_hal_hip_semaphore_t* iree_hal_hip_semaphore_cast(
5660

5761
iree_status_t iree_hal_hip_event_semaphore_create(
5862
uint64_t initial_value, const iree_hal_hip_dynamic_symbols_t* symbols,
59-
iree_hal_hip_timepoint_pool_t* timepoint_pool,
63+
hipCtx_t hip_context, iree_hal_hip_timepoint_pool_t* timepoint_pool,
6064
iree_hal_deferred_work_queue_t* work_queue, iree_allocator_t host_allocator,
6165
iree_hal_semaphore_t** out_semaphore) {
6266
IREE_ASSERT_ARGUMENT(symbols);
@@ -65,6 +69,8 @@ iree_status_t iree_hal_hip_event_semaphore_create(
6569
IREE_ASSERT_ARGUMENT(out_semaphore);
6670
IREE_TRACE_ZONE_BEGIN(z0);
6771

72+
IREE_RETURN_AND_END_ZONE_IF_ERROR(
73+
z0, iree_hal_hip_set_context(symbols, hip_context));
6874
iree_hal_hip_semaphore_t* semaphore = NULL;
6975
IREE_RETURN_AND_END_ZONE_IF_ERROR(
7076
z0, iree_allocator_malloc(host_allocator, sizeof(*semaphore),
@@ -79,6 +85,7 @@ iree_status_t iree_hal_hip_event_semaphore_create(
7985
iree_slim_mutex_initialize(&semaphore->mutex);
8086
semaphore->current_value = initial_value;
8187
semaphore->failure_status = iree_ok_status();
88+
semaphore->hip_context = hip_context;
8289

8390
*out_semaphore = &semaphore->base;
8491

@@ -92,6 +99,8 @@ static void iree_hal_hip_semaphore_destroy(
9299
iree_hal_hip_semaphore_cast(base_semaphore);
93100
iree_allocator_t host_allocator = semaphore->host_allocator;
94101
IREE_TRACE_ZONE_BEGIN(z0);
102+
IREE_IGNORE_ERROR(
103+
iree_hal_hip_set_context(semaphore->symbols, semaphore->hip_context));
95104

96105
iree_status_ignore(semaphore->failure_status);
97106
iree_slim_mutex_deinitialize(&semaphore->mutex);

runtime/src/iree/hal/drivers/hip/event_semaphore.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ extern "C" {
3131
// Thread-safe; multiple threads may signal/wait values on the same semaphore.
3232
iree_status_t iree_hal_hip_event_semaphore_create(
3333
uint64_t initial_value, const iree_hal_hip_dynamic_symbols_t* symbols,
34-
iree_hal_hip_timepoint_pool_t* timepoint_pool,
34+
hipCtx_t hip_context, iree_hal_hip_timepoint_pool_t* timepoint_pool,
3535
iree_hal_deferred_work_queue_t* work_queue, iree_allocator_t host_allocator,
3636
iree_hal_semaphore_t** out_semaphore);
3737

runtime/src/iree/hal/drivers/hip/hip_allocator.c

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "iree/base/api.h"
1212
#include "iree/base/tracing.h"
13+
#include "iree/hal/drivers/hip/context_util.h"
1314
#include "iree/hal/drivers/hip/dynamic_symbols.h"
1415
#include "iree/hal/drivers/hip/hip_buffer.h"
1516
#include "iree/hal/drivers/hip/status_util.h"
@@ -29,6 +30,8 @@ typedef struct iree_hal_hip_allocator_t {
2930
// The HIP stream that allocations should be used in.
3031
hipStream_t stream;
3132

33+
hipCtx_t hip_context;
34+
3235
// NOTE: optional depending on device support.
3336
iree_hal_hip_memory_pools_t* pools;
3437

@@ -54,11 +57,14 @@ static iree_hal_hip_allocator_t* iree_hal_hip_allocator_cast(
5457

5558
iree_status_t iree_hal_hip_allocator_create(
5659
const iree_hal_hip_dynamic_symbols_t* hip_symbols, hipDevice_t device,
57-
hipStream_t stream, iree_hal_hip_memory_pools_t* pools,
58-
iree_allocator_t host_allocator, iree_hal_allocator_t** out_allocator) {
60+
hipCtx_t hip_context, hipStream_t stream,
61+
iree_hal_hip_memory_pools_t* pools, iree_allocator_t host_allocator,
62+
iree_hal_allocator_t** out_allocator) {
5963
IREE_ASSERT_ARGUMENT(hip_symbols);
6064
IREE_ASSERT_ARGUMENT(out_allocator);
6165
IREE_TRACE_ZONE_BEGIN(z0);
66+
IREE_RETURN_AND_END_ZONE_IF_ERROR(
67+
z0, iree_hal_hip_set_context(hip_symbols, hip_context));
6268

6369
// To support device-local + host-visible memory we need concurrent managed
6470
// access indicating that the host and devices can concurrently access the
@@ -94,6 +100,7 @@ iree_status_t iree_hal_hip_allocator_create(
94100
allocator->host_allocator = host_allocator;
95101
allocator->supports_concurrent_managed_access =
96102
supports_concurrent_managed_access != 0;
103+
allocator->hip_context = hip_context;
97104
*out_allocator = (iree_hal_allocator_t*)allocator;
98105

99106
IREE_TRACE_ZONE_END(z0);
@@ -352,6 +359,9 @@ static iree_status_t iree_hal_hip_allocator_allocate_buffer(
352359
void* host_ptr = NULL;
353360
hipDeviceptr_t device_ptr = NULL;
354361
IREE_TRACE_ZONE_BEGIN_NAMED(z0, "iree_hal_hip_buffer_allocate");
362+
IREE_RETURN_AND_END_ZONE_IF_ERROR(
363+
z0, iree_hal_hip_set_context(allocator->symbols, allocator->hip_context));
364+
355365
IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, allocation_size);
356366
if (iree_all_bits_set(compat_params.type,
357367
IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL)) {
@@ -431,6 +441,9 @@ static void iree_hal_hip_allocator_deallocate_buffer(
431441
iree_hal_hip_allocator_t* allocator =
432442
iree_hal_hip_allocator_cast(base_allocator);
433443

444+
IREE_IGNORE_ERROR(
445+
iree_hal_hip_set_context(allocator->symbols, allocator->hip_context));
446+
434447
const iree_hal_hip_buffer_type_t buffer_type =
435448
iree_hal_hip_buffer_type(base_buffer);
436449

@@ -466,6 +479,9 @@ static iree_status_t iree_hal_hip_allocator_import_buffer(
466479
iree_hal_hip_allocator_t* allocator =
467480
iree_hal_hip_allocator_cast(base_allocator);
468481

482+
IREE_RETURN_IF_ERROR(
483+
iree_hal_hip_set_context(allocator->symbols, allocator->hip_context));
484+
469485
// Coerce options into those required by the current device.
470486
iree_hal_buffer_params_t compat_params = *params;
471487
iree_device_size_t allocation_size = external_buffer->size;
@@ -600,6 +616,9 @@ iree_status_t iree_hal_hip_allocator_alloc_async(
600616
iree_hal_hip_allocator_t* allocator =
601617
iree_hal_hip_allocator_cast(base_allocator);
602618

619+
IREE_RETURN_IF_ERROR(
620+
iree_hal_hip_set_context(allocator->symbols, allocator->hip_context));
621+
603622
hipDeviceptr_t ptr = NULL;
604623
iree_status_t status = IREE_HIP_RESULT_TO_STATUS(
605624
allocator->symbols,
@@ -625,6 +644,9 @@ iree_status_t iree_hal_hip_allocator_free_async(
625644
iree_hal_buffer_t* buffer) {
626645
iree_hal_hip_allocator_t* allocator =
627646
iree_hal_hip_allocator_cast(base_allocator);
647+
IREE_RETURN_IF_ERROR(
648+
iree_hal_hip_set_context(allocator->symbols, allocator->hip_context));
649+
628650
hipDeviceptr_t device_ptr = iree_hal_hip_buffer_device_pointer(buffer);
629651
if (!device_ptr) {
630652
return iree_ok_status();

runtime/src/iree/hal/drivers/hip/hip_allocator.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@ extern "C" {
1717
#endif // __cplusplus
1818

1919
// Creates a HIP memory allocator.
20-
// |device| and |stream| will be used for management operations.
20+
// |device|, |hip_context|, and |stream| will be used for management operations.
2121
// |pools| provides memory pools that may be shared across multiple allocators
2222
// and the pointer must remain valid for the lifetime of the allocator.
2323
iree_status_t iree_hal_hip_allocator_create(
2424
const iree_hal_hip_dynamic_symbols_t* hip_symbols, hipDevice_t device,
25-
hipStream_t stream, iree_hal_hip_memory_pools_t* pools,
26-
iree_allocator_t host_allocator, iree_hal_allocator_t** out_allocator);
25+
hipCtx_t hip_context, hipStream_t stream,
26+
iree_hal_hip_memory_pools_t* pools, iree_allocator_t host_allocator,
27+
iree_hal_allocator_t** out_allocator);
2728

2829
bool iree_hal_hip_allocator_isa(iree_hal_allocator_t* base_value);
2930

0 commit comments

Comments
 (0)