1414#include "iree/base/internal/atomics.h"
1515#include "iree/base/internal/synchronization.h"
1616#include "iree/hal/api.h"
17+ #include "iree/hal/drivers/hip/context_util.h"
1718#include "iree/hal/drivers/hip/dynamic_symbols.h"
1819#include "iree/hal/drivers/hip/status_util.h"
1920
@@ -36,6 +37,10 @@ struct iree_hal_hip_event_t {
3637 // The event pool that owns this event. This cannot be NULL. We retain it to
3738 // make sure the event outlive the pool.
3839 iree_hal_hip_event_pool_t * pool ;
40+
41+ // The context to use to free this event, it must be the same
42+ // context as was used when allocating the event.
43+ hipCtx_t hip_context ;
3944 // The underlying hipEvent_t object.
4045 hipEvent_t hip_event ;
4146};
@@ -48,6 +53,8 @@ static inline void iree_hal_hip_event_destroy(iree_hal_hip_event_t* event) {
4853 iree_allocator_t host_allocator = event -> host_allocator ;
4954 const iree_hal_hip_dynamic_symbols_t * symbols = event -> symbols ;
5055 IREE_TRACE_ZONE_BEGIN (z0 );
56+ IREE_IGNORE_ERROR (
57+ iree_hal_hip_set_context (event -> symbols , event -> hip_context ));
5158
5259 IREE_ASSERT_REF_COUNT_ZERO (& event -> ref_count );
5360 IREE_HIP_IGNORE_ERROR (symbols , hipEventDestroy (event -> hip_event ));
@@ -58,8 +65,8 @@ static inline void iree_hal_hip_event_destroy(iree_hal_hip_event_t* event) {
5865
5966static inline iree_status_t iree_hal_hip_event_create (
6067 const iree_hal_hip_dynamic_symbols_t * symbols ,
61- iree_hal_hip_event_pool_t * pool , iree_allocator_t host_allocator ,
62- iree_hal_hip_event_t * * out_event ) {
68+ iree_hal_hip_event_pool_t * pool , hipCtx_t context ,
69+ iree_allocator_t host_allocator , iree_hal_hip_event_t * * out_event ) {
6370 IREE_ASSERT_ARGUMENT (symbols );
6471 IREE_ASSERT_ARGUMENT (pool );
6572 IREE_ASSERT_ARGUMENT (out_event );
@@ -75,6 +82,7 @@ static inline iree_status_t iree_hal_hip_event_create(
7582 event -> symbols = symbols ;
7683 event -> pool = pool ;
7784 event -> hip_event = NULL ;
85+ event -> hip_context = context ;
7886
7987 iree_status_t status = IREE_HIP_RESULT_TO_STATUS (
8088 symbols ,
@@ -122,6 +130,10 @@ struct iree_hal_hip_event_pool_t {
122130 // The symbols used to create and destroy hipEvent_t objects.
123131 const iree_hal_hip_dynamic_symbols_t * symbols ;
124132
133+ // The context for this event pool to use to allocate
134+ // events.
135+ hipCtx_t hip_context ;
136+
125137 // Guards event related fields in the pool. We don't expect a performant
126138 // program to frequently allocate events for synchronization purposes; the
127139 // traffic to this pool should be low. So it should be fine to use mutex to
@@ -142,7 +154,7 @@ struct iree_hal_hip_event_pool_t {
142154static void iree_hal_hip_event_pool_free (iree_hal_hip_event_pool_t * event_pool );
143155
144156iree_status_t iree_hal_hip_event_pool_allocate (
145- const iree_hal_hip_dynamic_symbols_t * symbols ,
157+ const iree_hal_hip_dynamic_symbols_t * symbols , hipCtx_t hip_context ,
146158 iree_host_size_t available_capacity , iree_allocator_t host_allocator ,
147159 iree_hal_hip_event_pool_t * * out_event_pool ) {
148160 IREE_ASSERT_ARGUMENT (symbols );
@@ -163,11 +175,12 @@ iree_status_t iree_hal_hip_event_pool_allocate(
163175 iree_slim_mutex_initialize (& event_pool -> event_mutex );
164176 event_pool -> available_capacity = available_capacity ;
165177 event_pool -> available_count = 0 ;
178+ event_pool -> hip_context = hip_context ;
166179
167180 iree_status_t status = iree_ok_status ();
168181 for (iree_host_size_t i = 0 ; i < available_capacity ; ++ i ) {
169182 status = iree_hal_hip_event_create (
170- symbols , event_pool , host_allocator ,
183+ symbols , event_pool , hip_context , host_allocator ,
171184 & event_pool -> available_list [event_pool -> available_count ++ ]);
172185 if (!iree_status_is_ok (status )) break ;
173186 }
@@ -240,9 +253,9 @@ iree_status_t iree_hal_hip_event_pool_acquire(
240253 IREE_TRACE_ZONE_BEGIN_NAMED (z1 , "event-pool-unpooled-acquire" );
241254 iree_status_t status = iree_ok_status ();
242255 for (iree_host_size_t i = 0 ; i < remaining_count ; ++ i ) {
243- status = iree_hal_hip_event_create (event_pool -> symbols , event_pool ,
244- event_pool -> host_allocator ,
245- & out_events [from_pool_count + i ]);
256+ status = iree_hal_hip_event_create (
257+ event_pool -> symbols , event_pool , event_pool -> hip_context ,
258+ event_pool -> host_allocator , & out_events [from_pool_count + i ]);
246259 if (!iree_status_is_ok (status )) {
247260 // Must release all events we've acquired so far.
248261 iree_hal_hip_event_pool_release_event (event_pool , from_pool_count + i ,
0 commit comments