Skip to content

Commit 38eb992

Browse files
benvanik authored and Groverkss committed
Replacing iree_hal_command_buffer_discard_buffer with advise_buffer.
This is similar to madvise and cudaMemAdvise but is recorded within a command buffer. Stream operations in the compiler can lower to this to request cache management or to communicate buffer lifetime, allowing ASAN and other implementation features to know more about our aliased/slab buffer contents. There is likely also a need for an iree_hal_buffer_advise host operation.
1 parent 1ca7baf commit 38eb992

15 files changed

+98
-68
lines changed

experimental/webgpu/command_buffer.c

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -575,9 +575,10 @@ static iree_status_t iree_hal_webgpu_command_buffer_wait_events(
575575
return iree_ok_status();
576576
}
577577

578-
static iree_status_t iree_hal_webgpu_command_buffer_discard_buffer(
578+
static iree_status_t iree_hal_webgpu_command_buffer_advise_buffer(
579579
iree_hal_command_buffer_t* base_command_buffer,
580-
iree_hal_buffer_ref_t buffer_ref) {
580+
iree_hal_buffer_ref_t buffer_ref, iree_hal_memory_advise_flags_t flags,
581+
uint64_t arg0, uint64_t arg1) {
581582
// No-op: though maybe it'd be a useful addition to the spec as otherwise
582583
// false dependencies can creep in.
583584
return iree_ok_status();
@@ -1043,7 +1044,7 @@ const iree_hal_command_buffer_vtable_t iree_hal_webgpu_command_buffer_vtable = {
10431044
.signal_event = iree_hal_webgpu_command_buffer_signal_event,
10441045
.reset_event = iree_hal_webgpu_command_buffer_reset_event,
10451046
.wait_events = iree_hal_webgpu_command_buffer_wait_events,
1046-
.discard_buffer = iree_hal_webgpu_command_buffer_discard_buffer,
1047+
.advise_buffer = iree_hal_webgpu_command_buffer_advise_buffer,
10471048
.fill_buffer = iree_hal_webgpu_command_buffer_fill_buffer,
10481049
.update_buffer = iree_hal_webgpu_command_buffer_update_buffer,
10491050
.copy_buffer = iree_hal_webgpu_command_buffer_copy_buffer,

runtime/src/iree/hal/command_buffer.c

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -405,18 +405,19 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_wait_events(
405405
return status;
406406
}
407407

408-
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_discard_buffer(
409-
iree_hal_command_buffer_t* command_buffer,
410-
iree_hal_buffer_ref_t buffer_ref) {
408+
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_advise_buffer(
409+
iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_ref_t buffer_ref,
410+
iree_hal_memory_advise_flags_t flags, uint64_t arg0, uint64_t arg1) {
411411
IREE_ASSERT_ARGUMENT(command_buffer);
412412
IREE_TRACE_ZONE_BEGIN(z0);
413413
IF_VALIDATING(command_buffer, {
414414
IREE_RETURN_AND_END_ZONE_IF_ERROR(
415-
z0, iree_hal_command_buffer_discard_buffer_validation(
416-
command_buffer, VALIDATION_STATE(command_buffer), buffer_ref));
415+
z0, iree_hal_command_buffer_advise_buffer_validation(
416+
command_buffer, VALIDATION_STATE(command_buffer), buffer_ref,
417+
flags, arg0, arg1));
417418
});
418-
iree_status_t status = _VTABLE_DISPATCH(command_buffer, discard_buffer)(
419-
command_buffer, buffer_ref);
419+
iree_status_t status = _VTABLE_DISPATCH(command_buffer, advise_buffer)(
420+
command_buffer, buffer_ref, flags, arg0, arg1);
420421
IREE_TRACE_ZONE_END(z0);
421422
return status;
422423
}

runtime/src/iree/hal/command_buffer.h

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,16 @@ typedef struct iree_hal_buffer_barrier_t {
214214
iree_hal_buffer_ref_t buffer_ref;
215215
} iree_hal_buffer_barrier_t;
216216

217+
// Bitfield indicating advice for implementations managing a buffer.
218+
typedef uint64_t iree_hal_memory_advise_flags_t;
219+
enum iree_hal_memory_advise_flag_bits_t {
220+
IREE_HAL_MEMORY_ADVISE_FLAG_NONE = 0,
221+
// TODO(benvanik): cache control operations (invalidate/flush). arg0/arg1
222+
// could be source/target queue affinities.
223+
// TODO(benvanik): prefetch and access type hints.
224+
// TODO(benvanik): ASAN hints (protect/unprotect).
225+
};
226+
217227
// Bitfield specifying flags controlling a fill operation.
218228
typedef uint64_t iree_hal_fill_flags_t;
219229
enum iree_hal_fill_flag_bits_t {
@@ -687,13 +697,12 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_wait_events(
687697
iree_host_size_t buffer_barrier_count,
688698
const iree_hal_buffer_barrier_t* buffer_barriers);
689699

690-
// Hints to the device queue that the given buffer will not be used again.
691-
// After encoding a discard the buffer contents will be considered undefined.
692-
// This is because the discard may be used to elide write backs to host memory
693-
// or aggressively reuse the allocation for other purposes.
694-
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_discard_buffer(
695-
iree_hal_command_buffer_t* command_buffer,
696-
iree_hal_buffer_ref_t buffer_ref);
700+
// Advises the device about the usage of the given buffer.
701+
// The device may use this information to perform cache management or ignore it
702+
// entirely.
703+
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_advise_buffer(
704+
iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_ref_t buffer_ref,
705+
iree_hal_memory_advise_flags_t flags, uint64_t arg0, uint64_t arg1);
697706

698707
// Fills the target buffer with the given repeating value.
699708
// Expects that |pattern_length| is one of 1, 2, or 4 and that the offset and
@@ -892,9 +901,10 @@ typedef struct iree_hal_command_buffer_vtable_t {
892901
iree_host_size_t buffer_barrier_count,
893902
const iree_hal_buffer_barrier_t* buffer_barriers);
894903

895-
iree_status_t(IREE_API_PTR* discard_buffer)(
904+
iree_status_t(IREE_API_PTR* advise_buffer)(
896905
iree_hal_command_buffer_t* command_buffer,
897-
iree_hal_buffer_ref_t buffer_ref);
906+
iree_hal_buffer_ref_t buffer_ref, iree_hal_memory_advise_flags_t flags,
907+
uint64_t arg0, uint64_t arg1);
898908

899909
iree_status_t(IREE_API_PTR* fill_buffer)(
900910
iree_hal_command_buffer_t* command_buffer,

runtime/src/iree/hal/command_buffer_validation.c

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,10 +331,11 @@ iree_status_t iree_hal_command_buffer_wait_events_validation(
331331
return iree_ok_status();
332332
}
333333

334-
iree_status_t iree_hal_command_buffer_discard_buffer_validation(
334+
iree_status_t iree_hal_command_buffer_advise_buffer_validation(
335335
iree_hal_command_buffer_t* command_buffer,
336336
iree_hal_command_buffer_validation_state_t* validation_state,
337-
iree_hal_buffer_ref_t buffer_ref) {
337+
iree_hal_buffer_ref_t buffer_ref, iree_hal_memory_advise_flags_t flags,
338+
uint64_t arg0, uint64_t arg1) {
338339
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories(
339340
command_buffer, validation_state, IREE_HAL_COMMAND_CATEGORY_TRANSFER));
340341

runtime/src/iree/hal/command_buffer_validation.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,11 @@ iree_status_t iree_hal_command_buffer_wait_events_validation(
9999
iree_host_size_t buffer_barrier_count,
100100
const iree_hal_buffer_barrier_t* buffer_barriers);
101101

102-
iree_status_t iree_hal_command_buffer_discard_buffer_validation(
102+
iree_status_t iree_hal_command_buffer_advise_buffer_validation(
103103
iree_hal_command_buffer_t* command_buffer,
104104
iree_hal_command_buffer_validation_state_t* validation_state,
105-
iree_hal_buffer_ref_t buffer_ref);
105+
iree_hal_buffer_ref_t buffer_ref, iree_hal_memory_advise_flags_t flags,
106+
uint64_t arg0, uint64_t arg1);
106107

107108
iree_status_t iree_hal_command_buffer_fill_buffer_validation(
108109
iree_hal_command_buffer_t* command_buffer,

runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -477,9 +477,10 @@ static iree_status_t iree_hal_cuda_graph_command_buffer_wait_events(
477477
return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "event not yet supported");
478478
}
479479

480-
static iree_status_t iree_hal_cuda_graph_command_buffer_discard_buffer(
480+
static iree_status_t iree_hal_cuda_graph_command_buffer_advise_buffer(
481481
iree_hal_command_buffer_t* base_command_buffer,
482-
iree_hal_buffer_ref_t buffer_ref) {
482+
iree_hal_buffer_ref_t buffer_ref, iree_hal_memory_advise_flags_t flags,
483+
uint64_t arg0, uint64_t arg1) {
483484
// We could mark the memory as invalidated so that if this is a managed buffer
484485
// CUDA does not try to copy it back to the host.
485486
return iree_ok_status();
@@ -849,7 +850,7 @@ static const iree_hal_command_buffer_vtable_t
849850
.signal_event = iree_hal_cuda_graph_command_buffer_signal_event,
850851
.reset_event = iree_hal_cuda_graph_command_buffer_reset_event,
851852
.wait_events = iree_hal_cuda_graph_command_buffer_wait_events,
852-
.discard_buffer = iree_hal_cuda_graph_command_buffer_discard_buffer,
853+
.advise_buffer = iree_hal_cuda_graph_command_buffer_advise_buffer,
853854
.fill_buffer = iree_hal_cuda_graph_command_buffer_fill_buffer,
854855
.update_buffer = iree_hal_cuda_graph_command_buffer_update_buffer,
855856
.copy_buffer = iree_hal_cuda_graph_command_buffer_copy_buffer,

runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -308,9 +308,10 @@ static iree_status_t iree_hal_cuda_stream_command_buffer_wait_events(
308308
return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "event not yet supported");
309309
}
310310

311-
static iree_status_t iree_hal_cuda_stream_command_buffer_discard_buffer(
311+
static iree_status_t iree_hal_cuda_stream_command_buffer_advise_buffer(
312312
iree_hal_command_buffer_t* base_command_buffer,
313-
iree_hal_buffer_ref_t buffer_ref) {
313+
iree_hal_buffer_ref_t buffer_ref, iree_hal_memory_advise_flags_t flags,
314+
uint64_t arg0, uint64_t arg1) {
314315
// We could mark the memory as invalidated so that, if managed, CUDA does not
315316
// try to copy it back to the host.
316317
return iree_ok_status();
@@ -601,7 +602,7 @@ static const iree_hal_command_buffer_vtable_t
601602
.signal_event = iree_hal_cuda_stream_command_buffer_signal_event,
602603
.reset_event = iree_hal_cuda_stream_command_buffer_reset_event,
603604
.wait_events = iree_hal_cuda_stream_command_buffer_wait_events,
604-
.discard_buffer = iree_hal_cuda_stream_command_buffer_discard_buffer,
605+
.advise_buffer = iree_hal_cuda_stream_command_buffer_advise_buffer,
605606
.fill_buffer = iree_hal_cuda_stream_command_buffer_fill_buffer,
606607
.update_buffer = iree_hal_cuda_stream_command_buffer_update_buffer,
607608
.copy_buffer = iree_hal_cuda_stream_command_buffer_copy_buffer,

runtime/src/iree/hal/drivers/hip/graph_command_buffer.c

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -486,9 +486,10 @@ static iree_status_t iree_hal_hip_graph_command_buffer_wait_events(
486486
return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "event not yet supported");
487487
}
488488

489-
static iree_status_t iree_hal_hip_graph_command_buffer_discard_buffer(
489+
static iree_status_t iree_hal_hip_graph_command_buffer_advise_buffer(
490490
iree_hal_command_buffer_t* base_command_buffer,
491-
iree_hal_buffer_ref_t buffer_ref) {
491+
iree_hal_buffer_ref_t buffer_ref, iree_hal_memory_advise_flags_t flags,
492+
uint64_t arg0, uint64_t arg1) {
492493
// We could mark the memory as invalidated so that if this is a managed buffer
493494
// HIP does not try to copy it back to the host.
494495
return iree_ok_status();
@@ -858,7 +859,7 @@ static const iree_hal_command_buffer_vtable_t
858859
.signal_event = iree_hal_hip_graph_command_buffer_signal_event,
859860
.reset_event = iree_hal_hip_graph_command_buffer_reset_event,
860861
.wait_events = iree_hal_hip_graph_command_buffer_wait_events,
861-
.discard_buffer = iree_hal_hip_graph_command_buffer_discard_buffer,
862+
.advise_buffer = iree_hal_hip_graph_command_buffer_advise_buffer,
862863
.fill_buffer = iree_hal_hip_graph_command_buffer_fill_buffer,
863864
.update_buffer = iree_hal_hip_graph_command_buffer_update_buffer,
864865
.copy_buffer = iree_hal_hip_graph_command_buffer_copy_buffer,

runtime/src/iree/hal/drivers/hip/stream_command_buffer.c

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -299,9 +299,10 @@ static iree_status_t iree_hal_hip_stream_command_buffer_wait_events(
299299
return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "event not yet supported");
300300
}
301301

302-
static iree_status_t iree_hal_hip_stream_command_buffer_discard_buffer(
302+
static iree_status_t iree_hal_hip_stream_command_buffer_advise_buffer(
303303
iree_hal_command_buffer_t* base_command_buffer,
304-
iree_hal_buffer_ref_t buffer_ref) {
304+
iree_hal_buffer_ref_t buffer_ref, iree_hal_memory_advise_flags_t flags,
305+
uint64_t arg0, uint64_t arg1) {
305306
// We could mark the memory as invalidated so that, if managed, HIP does not
306307
// try to copy it back to the host.
307308
return iree_ok_status();
@@ -590,7 +591,7 @@ static const iree_hal_command_buffer_vtable_t
590591
.signal_event = iree_hal_hip_stream_command_buffer_signal_event,
591592
.reset_event = iree_hal_hip_stream_command_buffer_reset_event,
592593
.wait_events = iree_hal_hip_stream_command_buffer_wait_events,
593-
.discard_buffer = iree_hal_hip_stream_command_buffer_discard_buffer,
594+
.advise_buffer = iree_hal_hip_stream_command_buffer_advise_buffer,
594595
.fill_buffer = iree_hal_hip_stream_command_buffer_fill_buffer,
595596
.update_buffer = iree_hal_hip_stream_command_buffer_update_buffer,
596597
.copy_buffer = iree_hal_hip_stream_command_buffer_copy_buffer,

runtime/src/iree/hal/drivers/local_task/task_command_buffer.c

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -455,12 +455,13 @@ static iree_status_t iree_hal_task_command_buffer_wait_events(
455455
}
456456

457457
//===----------------------------------------------------------------------===//
458-
// iree_hal_command_buffer_discard_buffer
458+
// iree_hal_command_buffer_advise_buffer
459459
//===----------------------------------------------------------------------===//
460460

461-
static iree_status_t iree_hal_task_command_buffer_discard_buffer(
461+
static iree_status_t iree_hal_task_command_buffer_advise_buffer(
462462
iree_hal_command_buffer_t* base_command_buffer,
463-
iree_hal_buffer_ref_t buffer_ref) {
463+
iree_hal_buffer_ref_t buffer_ref, iree_hal_memory_advise_flags_t flags,
464+
uint64_t arg0, uint64_t arg1) {
464465
return iree_ok_status();
465466
}
466467

@@ -948,7 +949,7 @@ static const iree_hal_command_buffer_vtable_t
948949
.signal_event = iree_hal_task_command_buffer_signal_event,
949950
.reset_event = iree_hal_task_command_buffer_reset_event,
950951
.wait_events = iree_hal_task_command_buffer_wait_events,
951-
.discard_buffer = iree_hal_task_command_buffer_discard_buffer,
952+
.advise_buffer = iree_hal_task_command_buffer_advise_buffer,
952953
.fill_buffer = iree_hal_task_command_buffer_fill_buffer,
953954
.update_buffer = iree_hal_task_command_buffer_update_buffer,
954955
.copy_buffer = iree_hal_task_command_buffer_copy_buffer,

0 commit comments

Comments
 (0)