Skip to content

Commit c2c44e7

Browse files
benvanikgiacs-epic
authored andcommitted
Adding HAL semaphore support for statuses-as-failure-payloads. (iree-org#18912)
This allows an implementation to have a single atomic value for a semaphore that encodes the user payload or an error payload that optionally references an iree_status_t object. Implementations not using the status feature can ignore it but must perform a greater-than-or-equal check on `IREE_HAL_SEMAPHORE_FAILURE_VALUE` instead of equality. Signed-off-by: Giacomo Serafini <[email protected]>
1 parent 90743be commit c2c44e7

File tree

5 files changed

+69
-15
lines changed

5 files changed

+69
-15
lines changed

runtime/src/iree/hal/cts/semaphore_submission_test.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -882,7 +882,7 @@ TEST_F(SemaphoreSubmissionTest, PropagateFailSignal) {
882882
EXPECT_EQ(iree_status_code(wait_status), IREE_STATUS_ABORTED);
883883
uint64_t value = 1234;
884884
iree_status_t query_status = iree_hal_semaphore_query(semaphore2, &value);
885-
EXPECT_EQ(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE);
885+
EXPECT_GE(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE);
886886
CheckStatusContains(query_status, status);
887887

888888
signal_thread.join();

runtime/src/iree/hal/cts/semaphore_test.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ TEST_F(SemaphoreTest, FailThenWait) {
406406
EXPECT_EQ(iree_status_code(wait_status), IREE_STATUS_ABORTED);
407407
uint64_t value = 1234;
408408
iree_status_t query_status = iree_hal_semaphore_query(semaphore, &value);
409-
EXPECT_EQ(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE);
409+
EXPECT_GE(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE);
410410
CheckStatusContains(query_status, status);
411411

412412
iree_hal_semaphore_release(semaphore);
@@ -431,7 +431,7 @@ TEST_F(SemaphoreTest, WaitThenFail) {
431431
EXPECT_EQ(iree_status_code(wait_status), IREE_STATUS_ABORTED);
432432
uint64_t value = 1234;
433433
iree_status_t query_status = iree_hal_semaphore_query(semaphore, &value);
434-
EXPECT_EQ(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE);
434+
EXPECT_GE(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE);
435435
CheckStatusContains(query_status, status);
436436

437437
signal_thread.join();
@@ -467,7 +467,7 @@ TEST_F(SemaphoreTest, MultiWaitThenFail) {
467467
uint64_t value = 1234;
468468
iree_status_t semaphore1_query_status =
469469
iree_hal_semaphore_query(semaphore1, &value);
470-
EXPECT_EQ(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE);
470+
EXPECT_GE(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE);
471471
CheckStatusContains(semaphore1_query_status, status);
472472

473473
// semaphore2 must not have changed.
@@ -511,7 +511,7 @@ TEST_F(SemaphoreTest, DeviceMultiWaitThenFail) {
511511
uint64_t value = 1234;
512512
iree_status_t semaphore1_query_status =
513513
iree_hal_semaphore_query(semaphore1, &value);
514-
EXPECT_EQ(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE);
514+
EXPECT_GE(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE);
515515
CheckStatusContains(semaphore1_query_status, status);
516516

517517
// semaphore2 must not have changed.

runtime/src/iree/hal/drivers/cuda/event_semaphore.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ static iree_status_t iree_hal_cuda_semaphore_wait(
325325
}
326326

327327
iree_slim_mutex_lock(&semaphore->mutex);
328-
if (semaphore->current_value == IREE_HAL_SEMAPHORE_FAILURE_VALUE) {
328+
if (semaphore->current_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) {
329329
iree_slim_mutex_unlock(&semaphore->mutex);
330330
IREE_TRACE_ZONE_END(z0);
331331
return iree_make_status(IREE_STATUS_ABORTED);
@@ -350,7 +350,7 @@ static iree_status_t iree_hal_cuda_semaphore_wait(
350350
}
351351

352352
iree_slim_mutex_lock(&semaphore->mutex);
353-
if (semaphore->current_value == IREE_HAL_SEMAPHORE_FAILURE_VALUE) {
353+
if (semaphore->current_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) {
354354
status = iree_make_status(IREE_STATUS_ABORTED);
355355
}
356356
iree_slim_mutex_unlock(&semaphore->mutex);
@@ -444,7 +444,7 @@ iree_status_t iree_hal_cuda_semaphore_multi_wait(
444444
iree_hal_cuda_semaphore_t* semaphore =
445445
iree_hal_cuda_semaphore_cast(semaphore_list.semaphores[i]);
446446
iree_slim_mutex_lock(&semaphore->mutex);
447-
if (semaphore->current_value == IREE_HAL_SEMAPHORE_FAILURE_VALUE) {
447+
if (semaphore->current_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) {
448448
iree_slim_mutex_unlock(&semaphore->mutex);
449449
status = iree_make_status(IREE_STATUS_ABORTED);
450450
break;

runtime/src/iree/hal/drivers/hip/event_semaphore.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ static iree_status_t iree_hal_hip_semaphore_wait(
323323
}
324324

325325
iree_slim_mutex_lock(&semaphore->mutex);
326-
if (semaphore->current_value == IREE_HAL_SEMAPHORE_FAILURE_VALUE) {
326+
if (semaphore->current_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) {
327327
iree_slim_mutex_unlock(&semaphore->mutex);
328328
IREE_TRACE_ZONE_END(z0);
329329
return iree_make_status(IREE_STATUS_ABORTED);
@@ -346,7 +346,7 @@ static iree_status_t iree_hal_hip_semaphore_wait(
346346
}
347347

348348
iree_slim_mutex_lock(&semaphore->mutex);
349-
if (semaphore->current_value == IREE_HAL_SEMAPHORE_FAILURE_VALUE) {
349+
if (semaphore->current_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) {
350350
status = iree_make_status(IREE_STATUS_ABORTED);
351351
}
352352
iree_slim_mutex_unlock(&semaphore->mutex);
@@ -440,7 +440,7 @@ iree_status_t iree_hal_hip_semaphore_multi_wait(
440440
iree_hal_hip_semaphore_t* semaphore =
441441
iree_hal_hip_semaphore_cast(semaphore_list.semaphores[i]);
442442
iree_slim_mutex_lock(&semaphore->mutex);
443-
if (semaphore->current_value == IREE_HAL_SEMAPHORE_FAILURE_VALUE) {
443+
if (semaphore->current_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) {
444444
iree_slim_mutex_unlock(&semaphore->mutex);
445445
status = iree_make_status(IREE_STATUS_ABORTED);
446446
break;

runtime/src/iree/hal/semaphore.h

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,6 @@ enum iree_hal_semaphore_flag_bits_t {
3030
};
3131
typedef uint32_t iree_hal_semaphore_flags_t;
3232

33-
//===----------------------------------------------------------------------===//
34-
// iree_hal_semaphore_t
35-
//===----------------------------------------------------------------------===//
36-
3733
// The maximum valid payload value of an iree_hal_semaphore_t.
3834
// Payload values larger than this indicate that the semaphore has failed.
3935
//
@@ -56,8 +52,66 @@ typedef uint32_t iree_hal_semaphore_flags_t;
5652
// https://vulkan.gpuinfo.org/displayextensionproperty.php?name=maxTimelineSemaphoreValueDifference
5753
#define IREE_HAL_SEMAPHORE_MAX_VALUE (2147483647ull - 1)
5854

55+
// The minimum value for a semaphore that indicates failure. Any value
56+
// greater-than-or-equal-to (>=) this indicates the semaphore has failed.
57+
//
58+
// If the upper bit 63 is set then the value represents an iree_status_t.
59+
// Use iree_hal_semaphore_failure_as_status to convert a payload value to a
60+
// status. Not all implementations do (or can) support encoding statuses and may
61+
// only ever be able to set a failing semaphore to this value.
5962
#define IREE_HAL_SEMAPHORE_FAILURE_VALUE (IREE_HAL_SEMAPHORE_MAX_VALUE + 1)
6063

64+
// Bit indicating that a failing semaphore value can be interpreted as an
65+
// iree_status_t.
66+
#define IREE_HAL_SEMAPHORE_FAILURE_VALUE_STATUS_BIT 0x8000000000000000ull
67+
68+
// Returns a semaphore payload value that encodes the given |status|.
69+
// Ownership of the status is transferred to the semaphore and it must be
70+
// freed by a consumer. Not all implementations can support failure status
71+
// payloads and this should only be used by those implementations that can.
72+
static inline uint64_t iree_hal_status_as_semaphore_failure(
73+
iree_status_t status) {
74+
return IREE_HAL_SEMAPHORE_FAILURE_VALUE_STATUS_BIT |
75+
(((uint64_t)status) >> 1);
76+
}
77+
78+
// Returns OK if the |value| does not indicate an error.
79+
// Returns an error status if the semaphore payload value represents a failure.
80+
// If the payload contains an encoded iree_status_t it will be cloned and the
81+
// new copy will be returned to the caller.
82+
static inline iree_status_t iree_hal_semaphore_failure_as_status(
83+
uint64_t value) {
84+
if (value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) {
85+
if (value & IREE_HAL_SEMAPHORE_FAILURE_VALUE_STATUS_BIT) {
86+
// The top bits of a pointer are sign-extended from bit 47 so we can
87+
// restore the top bit by left-shifting the upper bits and then
88+
// right-shifting with sign extension. We only use a single bit today and
89+
// so bit 62 should still be the original value of the pointer.
90+
// Note that if the status is just a code (no allocated pointer) this
91+
// clone is a no-op and the code will be returned without an allocation.
92+
//
93+
// See:
94+
// https://en.wikipedia.org/wiki/X86-64#Canonical_form_addresses
95+
return iree_status_clone((iree_status_t)(((int64_t)value << 1) >> 1));
96+
} else {
97+
return iree_status_from_code(IREE_STATUS_INTERNAL);
98+
}
99+
} else {
100+
return iree_ok_status();
101+
}
102+
}
103+
104+
// Frees an iree_status_t encoded in a semaphore |value|, if any.
105+
static inline void iree_hal_semaphore_failure_free(uint64_t value) {
106+
if (value & IREE_HAL_SEMAPHORE_FAILURE_VALUE_STATUS_BIT) {
107+
iree_status_free((iree_status_t)(((int64_t)value << 1) >> 1));
108+
}
109+
}
110+
111+
//===----------------------------------------------------------------------===//
112+
// iree_hal_semaphore_t
113+
//===----------------------------------------------------------------------===//
114+
61115
// Synchronization mechanism for host->device, device->host, host->host,
62116
// and device->device notification. Semaphores behave like Vulkan timeline
63117
// semaphores (or D3D12 fences) and contain a monotonically increasing

0 commit comments

Comments
 (0)