Skip to content

Commit fa3a144

Browse files
authored
Adding iree_hal_device_queue_update and improving queue DMA operations. (#19000)
As with all queue DMA operations it's best if things are batched into command buffers but it's bad to have a command buffer with a single DMA operation - this completes the set of fill/update/copy operations at the queue level to match the command buffer DMA operations. Practically this is useful when combined with reusable/indirect command buffers for uploading new parameters in queue order prior to issuing a command buffer that references them. The compiler will use this to turn push constants into uniform buffers. An emulated version is added but implementations are encouraged to do better... they currently don't. While updating the queue API I've added placeholder flags to all DMA operations in preparation for compiler updates that will provide them. `iree_hal_device_queue_execute` has needed simplification for awhile and that's done here to allow implementations to not need to worry with batched command buffer juggling. The unused-since-its-inception `iree_hal_command_buffer_discard_buffer` API has been renamed to `iree_hal_command_buffer_advise_buffer` ahead of compiler changes that will use it for multi-device cache management. No breaking changes to the compiler here - future PRs will update the HAL module and ops.
2 parents 2b1a8e7 + 1e1a6e3 commit fa3a144

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+887
-604
lines changed

experimental/web/sample_webgpu/main.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -794,8 +794,8 @@ static iree_status_t process_call_outputs(
794794
};
795795
status = iree_hal_device_queue_execute(
796796
device, IREE_HAL_QUEUE_AFFINITY_ANY, iree_hal_semaphore_list_empty(),
797-
signal_semaphores, 1, &transfer_command_buffer,
798-
/*binding_tables=*/NULL);
797+
signal_semaphores, transfer_command_buffer,
798+
iree_hal_buffer_binding_table_empty());
799799
}
800800
// TODO(scotttodd): Make this async - pass a wait source to iree_loop_wait_one
801801
// 1. create iree_hal_fence_t, iree_hal_fence_insert(fance, semaphore)

experimental/webgpu/command_buffer.c

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -575,9 +575,10 @@ static iree_status_t iree_hal_webgpu_command_buffer_wait_events(
575575
return iree_ok_status();
576576
}
577577

578-
static iree_status_t iree_hal_webgpu_command_buffer_discard_buffer(
578+
static iree_status_t iree_hal_webgpu_command_buffer_advise_buffer(
579579
iree_hal_command_buffer_t* base_command_buffer,
580-
iree_hal_buffer_ref_t buffer_ref) {
580+
iree_hal_buffer_ref_t buffer_ref, iree_hal_memory_advise_flags_t flags,
581+
uint64_t arg0, uint64_t arg1) {
581582
// No-op: though maybe it'd be a useful addition to the spec as otherwise
582583
// false dependencies can creep in.
583584
return iree_ok_status();
@@ -608,7 +609,7 @@ static uint32_t iree_hal_webgpu_splat_pattern(const void* pattern,
608609
static iree_status_t iree_hal_webgpu_command_buffer_fill_buffer(
609610
iree_hal_command_buffer_t* base_command_buffer,
610611
iree_hal_buffer_ref_t target_ref, const void* pattern,
611-
iree_host_size_t pattern_length) {
612+
iree_host_size_t pattern_length, iree_hal_fill_flags_t flags) {
612613
iree_hal_webgpu_command_buffer_t* command_buffer =
613614
iree_hal_webgpu_command_buffer_cast(base_command_buffer);
614615

@@ -693,7 +694,8 @@ static iree_status_t iree_hal_webgpu_command_buffer_fill_buffer(
693694

694695
static iree_status_t iree_hal_webgpu_command_buffer_update_buffer(
695696
iree_hal_command_buffer_t* base_command_buffer, const void* source_buffer,
696-
iree_host_size_t source_offset, iree_hal_buffer_ref_t target_ref) {
697+
iree_host_size_t source_offset, iree_hal_buffer_ref_t target_ref,
698+
iree_hal_update_flags_t flags) {
697699
iree_hal_webgpu_command_buffer_t* command_buffer =
698700
iree_hal_webgpu_command_buffer_cast(base_command_buffer);
699701

@@ -734,7 +736,8 @@ static iree_status_t iree_hal_webgpu_command_buffer_update_buffer(
734736

735737
static iree_status_t iree_hal_webgpu_command_buffer_copy_buffer(
736738
iree_hal_command_buffer_t* base_command_buffer,
737-
iree_hal_buffer_ref_t source_ref, iree_hal_buffer_ref_t target_ref) {
739+
iree_hal_buffer_ref_t source_ref, iree_hal_buffer_ref_t target_ref,
740+
iree_hal_copy_flags_t flags) {
738741
iree_hal_webgpu_command_buffer_t* command_buffer =
739742
iree_hal_webgpu_command_buffer_cast(base_command_buffer);
740743

@@ -1041,7 +1044,7 @@ const iree_hal_command_buffer_vtable_t iree_hal_webgpu_command_buffer_vtable = {
10411044
.signal_event = iree_hal_webgpu_command_buffer_signal_event,
10421045
.reset_event = iree_hal_webgpu_command_buffer_reset_event,
10431046
.wait_events = iree_hal_webgpu_command_buffer_wait_events,
1044-
.discard_buffer = iree_hal_webgpu_command_buffer_discard_buffer,
1047+
.advise_buffer = iree_hal_webgpu_command_buffer_advise_buffer,
10451048
.fill_buffer = iree_hal_webgpu_command_buffer_fill_buffer,
10461049
.update_buffer = iree_hal_webgpu_command_buffer_update_buffer,
10471050
.copy_buffer = iree_hal_webgpu_command_buffer_copy_buffer,

experimental/webgpu/webgpu_device.c

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ static iree_status_t iree_hal_webgpu_device_queue_read(
354354
const iree_hal_semaphore_list_t signal_semaphore_list,
355355
iree_hal_file_t* source_file, uint64_t source_offset,
356356
iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
357-
iree_device_size_t length, uint32_t flags) {
357+
iree_device_size_t length, iree_hal_read_flags_t flags) {
358358
// TODO: expose streaming chunk count/size options.
359359
iree_status_t loop_status = iree_ok_status();
360360
iree_hal_file_transfer_options_t options = {
@@ -376,7 +376,7 @@ static iree_status_t iree_hal_webgpu_device_queue_write(
376376
const iree_hal_semaphore_list_t signal_semaphore_list,
377377
iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
378378
iree_hal_file_t* target_file, uint64_t target_offset,
379-
iree_device_size_t length, uint32_t flags) {
379+
iree_device_size_t length, iree_hal_write_flags_t flags) {
380380
// TODO: expose streaming chunk count/size options.
381381
iree_status_t loop_status = iree_ok_status();
382382
iree_hal_file_transfer_options_t options = {
@@ -396,9 +396,8 @@ static iree_status_t iree_hal_webgpu_device_queue_execute(
396396
iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
397397
const iree_hal_semaphore_list_t wait_semaphore_list,
398398
const iree_hal_semaphore_list_t signal_semaphore_list,
399-
iree_host_size_t command_buffer_count,
400-
iree_hal_command_buffer_t* const* command_buffers,
401-
iree_hal_buffer_binding_table_t const* binding_tables) {
399+
iree_hal_command_buffer_t* command_buffer,
400+
iree_hal_buffer_binding_table_t binding_table) {
402401
iree_hal_webgpu_device_t* device = iree_hal_webgpu_device_cast(base_device);
403402

404403
// TODO(benvanik): this currently assumes we are synchronizing on semaphores
@@ -410,11 +409,8 @@ static iree_status_t iree_hal_webgpu_device_queue_execute(
410409
iree_infinite_timeout()));
411410

412411
// TODO(benvanik): propagate errors to semaphores.
413-
for (iree_host_size_t i = 0; i < command_buffer_count; i++) {
414-
iree_hal_command_buffer_t* command_buffer = command_buffers[i];
415-
IREE_RETURN_IF_ERROR(
416-
iree_hal_webgpu_command_buffer_issue(command_buffer, device->queue));
417-
}
412+
IREE_RETURN_IF_ERROR(
413+
iree_hal_webgpu_command_buffer_issue(command_buffer, device->queue));
418414

419415
IREE_RETURN_IF_ERROR(iree_hal_semaphore_list_signal(signal_semaphore_list));
420416

@@ -473,6 +469,9 @@ const iree_hal_device_vtable_t iree_hal_webgpu_device_vtable = {
473469
iree_hal_webgpu_device_query_semaphore_compatibility,
474470
.queue_alloca = iree_hal_webgpu_device_queue_alloca,
475471
.queue_dealloca = iree_hal_webgpu_device_queue_dealloca,
472+
.queue_fill = iree_hal_device_queue_emulated_fill,
473+
.queue_update = iree_hal_device_queue_emulated_update,
474+
.queue_copy = iree_hal_device_queue_emulated_copy,
476475
.queue_read = iree_hal_webgpu_device_queue_read,
477476
.queue_write = iree_hal_webgpu_device_queue_write,
478477
.queue_execute = iree_hal_webgpu_device_queue_execute,

integrations/pjrt/src/iree_pjrt/common/api_impl.cc

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -590,8 +590,8 @@ iree_status_t BufferInstance::CopyToHost(void* dst, iree_host_size_t dst_size,
590590
device_.device(), IREE_HAL_QUEUE_AFFINITY_ANY,
591591
/*wait_semaphore_list=*/iree_hal_fence_semaphore_list(ready_fence_.get()),
592592
/*signal_semaphore_list=*/
593-
iree_hal_fence_semaphore_list(dst_buffer_ready_fence.get()),
594-
/*command_buffer_count=*/1, &transfer_cb, NULL));
593+
iree_hal_fence_semaphore_list(dst_buffer_ready_fence.get()), transfer_cb,
594+
iree_hal_buffer_binding_table_empty()));
595595

596596
*out_done_event = copy_done_event;
597597
return iree_ok_status();
@@ -837,7 +837,8 @@ iree_status_t DeviceInstance::HostBufferToDeviceSplat(
837837
IREE_CHECK_OK(iree_hal_command_buffer_begin(transfer_cb.get()));
838838
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_fill_buffer(
839839
transfer_cb.get(), buffer.get(), /*target_offset=*/0,
840-
/*target_size=*/byte_length, data, element_type_byte_size));
840+
/*target_size=*/byte_length, data, element_type_byte_size,
841+
IREE_HAL_FILL_FLAG_NONE));
841842
IREE_CHECK_OK(iree_hal_command_buffer_end(transfer_cb.get()));
842843

843844
// Execute the enqueued splat:
@@ -846,8 +847,8 @@ iree_status_t DeviceInstance::HostBufferToDeviceSplat(
846847
/*wait_semaphore_list=*/
847848
{1, &transfer_timeline_, &signal_alloca_complete},
848849
/*signal_semaphore_list=*/
849-
{1, &transfer_timeline_, &signal_copy_complete},
850-
/*command_buffer_count=*/1, &transfer_cb, NULL));
850+
{1, &transfer_timeline_, &signal_copy_complete}, transfer_cb,
851+
iree_hal_buffer_binding_table_empty()));
851852

852853
// Wrap in a buffer view and return:
853854
iree::vm::ref<iree_hal_buffer_view_t> result_buffer_view;
@@ -1190,8 +1191,8 @@ iree_status_t DeviceInstance::HostBufferToDevice(
11901191
/*wait_semaphore_list=*/
11911192
{1, &transfer_timeline_, &signal_alloca_complete},
11921193
/*signal_semaphore_list=*/
1193-
{1, &transfer_timeline_, &signal_copy_complete},
1194-
/*command_buffer_count=*/1, &transfer_cb, NULL));
1194+
{1, &transfer_timeline_, &signal_copy_complete}, transfer_cb,
1195+
iree_hal_buffer_binding_table_empty()));
11951196

11961197
// Wrap in a buffer view and return.
11971198
iree::vm::ref<iree_hal_buffer_view_t> result_buffer_view;

integrations/pjrt/src/iree_pjrt/common/iree_helpers.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,17 +139,16 @@ iree_status_t hal_device_queue_execute(
139139
iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity,
140140
const iree_hal_semaphore_list_t wait_semaphore_list,
141141
const iree_hal_semaphore_list_t signal_semaphore_list,
142-
iree_host_size_t command_buffer_count,
143-
iree_hal_command_buffer_t* const* command_buffers) {
142+
iree_hal_command_buffer_t* command_buffer) {
144143
if (LOGGING_ENABLED) {
145144
LogInvoke(__func__, "device=%p, wait={%s}, signal={%s}", device,
146145
SemaphoreListToString(wait_semaphore_list).c_str(),
147146
SemaphoreListToString(signal_semaphore_list).c_str());
148147
}
149148
return HandleStatus(__func__, iree_hal_device_queue_execute(
150149
device, queue_affinity, wait_semaphore_list,
151-
signal_semaphore_list, command_buffer_count,
152-
command_buffers, /*binding_tables=*/NULL));
150+
signal_semaphore_list, command_buffer,
151+
iree_hal_buffer_binding_table_empty()));
153152
}
154153

155154
iree_status_t hal_fence_create(iree_host_size_t capacity,

runtime/bindings/python/hal.cc

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ void HalDevice::QueueDealloca(HalBuffer& buffer, py::handle wait_semaphores,
499499
"deallocating memory on queue");
500500
}
501501

502-
void HalDevice::QueueExecute(py::handle command_buffers,
502+
void HalDevice::QueueExecute(py::handle command_buffer,
503503
py::handle wait_semaphores,
504504
py::handle signal_semaphores) {
505505
iree_hal_semaphore_list_t wait_list;
@@ -548,17 +548,14 @@ void HalDevice::QueueExecute(py::handle command_buffers,
548548
}
549549

550550
// Unpack command buffers.
551-
size_t cb_count = py::len(command_buffers);
552-
iree_hal_command_buffer_t** cb_list =
553-
static_cast<iree_hal_command_buffer_t**>(
554-
alloca(sizeof(iree_hal_command_buffer_t*) * cb_count));
555-
for (size_t i = 0; i < cb_count; ++i) {
556-
cb_list[i] = py::cast<HalCommandBuffer*>(command_buffers[i])->raw_ptr();
557-
}
551+
iree_hal_command_buffer_t* cb =
552+
!command_buffer.is_none()
553+
? py::cast<HalCommandBuffer*>(command_buffer)->raw_ptr()
554+
: NULL;
558555

559556
CheckApiStatus(iree_hal_device_queue_execute(
560557
raw_ptr(), IREE_HAL_QUEUE_AFFINITY_ANY, wait_list,
561-
signal_list, cb_count, cb_list, /*binding_tables=*/NULL),
558+
signal_list, cb, iree_hal_buffer_binding_table_empty()),
562559
"executing command buffers");
563560
}
564561

@@ -619,11 +616,12 @@ void HalDevice::QueueCopy(HalBuffer& source_buffer, HalBuffer& target_buffer,
619616
"Source and buffer length must be less than the target buffer length "
620617
"and it does not. Please check allocations");
621618
}
622-
CheckApiStatus(iree_hal_device_queue_copy(
623-
raw_ptr(), IREE_HAL_QUEUE_AFFINITY_ANY, wait_list,
624-
signal_list, source_buffer.raw_ptr(), 0,
625-
target_buffer.raw_ptr(), 0, source_length),
626-
"Copying buffer on queue");
619+
CheckApiStatus(
620+
iree_hal_device_queue_copy(
621+
raw_ptr(), IREE_HAL_QUEUE_AFFINITY_ANY, wait_list, signal_list,
622+
source_buffer.raw_ptr(), 0, target_buffer.raw_ptr(), 0, source_length,
623+
IREE_HAL_COPY_FLAG_NONE),
624+
"Copying buffer on queue");
627625
}
628626

629627
py::object HalDevice::CreateDLPackCapsule(HalBufferView& buffer_view,
@@ -1729,7 +1727,8 @@ void SetupHalBindings(nanobind::module_ m) {
17291727
iree_hal_make_buffer_ref(source_buffer.raw_ptr(),
17301728
source_offset, resolved_length),
17311729
iree_hal_make_buffer_ref(target_buffer.raw_ptr(),
1732-
target_offset, resolved_length)),
1730+
target_offset, resolved_length),
1731+
IREE_HAL_COPY_FLAG_NONE),
17331732
"copy command");
17341733
if (end) {
17351734
CheckApiStatus(iree_hal_command_buffer_end(self.raw_ptr()),
@@ -1767,7 +1766,8 @@ void SetupHalBindings(nanobind::module_ m) {
17671766
self.raw_ptr(),
17681767
iree_hal_make_buffer_ref(target_buffer.raw_ptr(),
17691768
target_offset, resolved_length),
1770-
pattern_view.buf, pattern_view.len),
1769+
pattern_view.buf, pattern_view.len,
1770+
IREE_HAL_FILL_FLAG_NONE),
17711771
"command buffer fill");
17721772
if (end) {
17731773
CheckApiStatus(iree_hal_command_buffer_end(self.raw_ptr()),

runtime/bindings/python/iree/runtime/_binding.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ class HalDevice:
185185
) -> None: ...
186186
def queue_execute(
187187
self,
188-
command_buffers: Sequence[HalCommandBuffer],
188+
command_buffer: HalCommandBuffer,
189189
wait_semaphores: HalSemaphoreList,
190190
signal_semaphores: HalSemaphoreList,
191191
) -> None: ...

runtime/bindings/python/tests/hal_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ def testCommandBufferExecute(self):
463463

464464
sem = self.device.create_semaphore(0)
465465
self.device.queue_execute(
466-
[cb], wait_semaphores=[(sem, 0)], signal_semaphores=[(sem, 1)]
466+
cb, wait_semaphores=[(sem, 0)], signal_semaphores=[(sem, 1)]
467467
)
468468
iree.runtime.HalFence.create_at(sem, 1).wait()
469469

@@ -479,7 +479,7 @@ def testCommandBufferExecuteAcceptsFence(self):
479479

480480
sem = self.device.create_semaphore(0)
481481
self.device.queue_execute(
482-
[cb],
482+
cb,
483483
wait_semaphores=iree.runtime.HalFence.create_at(sem, 0),
484484
signal_semaphores=iree.runtime.HalFence.create_at(sem, 1),
485485
)

runtime/src/iree/hal/buffer_transfer.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ static iree_status_t iree_hal_device_transfer_and_wait(
7878
};
7979
status = iree_hal_device_queue_execute(
8080
device, IREE_HAL_QUEUE_AFFINITY_ANY, wait_semaphores, signal_semaphores,
81-
1, &command_buffer, /*binding_tables=*/NULL);
81+
command_buffer, iree_hal_buffer_binding_table_empty());
8282
}
8383
if (iree_status_is_ok(status)) {
8484
status = iree_hal_semaphore_wait(fence_semaphore, signal_value, timeout);

0 commit comments

Comments
 (0)