Skip to content

Commit a1f8f6f

Browse files
benvanikGroverkss
authored andcommitted
Making iree_hal_device_queue_execute take zero or one command buffer.
The practical reason to take multiple is multi-threaded encoding or mixing those with different lifetimes (per-execution prefix/suffix command buffers around a reusable command buffer, etc) but we don't do this in practice and supporting only one makes it easier to implement against all APIs but Vulkan.
1 parent 38eb992 commit a1f8f6f

File tree

30 files changed

+270
-381
lines changed

30 files changed

+270
-381
lines changed

experimental/web/sample_webgpu/main.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -794,8 +794,8 @@ static iree_status_t process_call_outputs(
794794
};
795795
status = iree_hal_device_queue_execute(
796796
device, IREE_HAL_QUEUE_AFFINITY_ANY, iree_hal_semaphore_list_empty(),
797-
signal_semaphores, 1, &transfer_command_buffer,
798-
/*binding_tables=*/NULL);
797+
signal_semaphores, transfer_command_buffer,
798+
iree_hal_buffer_binding_table_empty());
799799
}
800800
// TODO(scotttodd): Make this async - pass a wait source to iree_loop_wait_one
801801
// 1. create iree_hal_fence_t, iree_hal_fence_insert(fance, semaphore)

experimental/webgpu/webgpu_device.c

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -396,9 +396,8 @@ static iree_status_t iree_hal_webgpu_device_queue_execute(
396396
iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
397397
const iree_hal_semaphore_list_t wait_semaphore_list,
398398
const iree_hal_semaphore_list_t signal_semaphore_list,
399-
iree_host_size_t command_buffer_count,
400-
iree_hal_command_buffer_t* const* command_buffers,
401-
iree_hal_buffer_binding_table_t const* binding_tables) {
399+
iree_hal_command_buffer_t* command_buffer,
400+
iree_hal_buffer_binding_table_t binding_table) {
402401
iree_hal_webgpu_device_t* device = iree_hal_webgpu_device_cast(base_device);
403402

404403
// TODO(benvanik): this currently assumes we are synchronizing on semaphores
@@ -410,11 +409,8 @@ static iree_status_t iree_hal_webgpu_device_queue_execute(
410409
iree_infinite_timeout()));
411410

412411
// TODO(benvanik): propagate errors to semaphores.
413-
for (iree_host_size_t i = 0; i < command_buffer_count; i++) {
414-
iree_hal_command_buffer_t* command_buffer = command_buffers[i];
415-
IREE_RETURN_IF_ERROR(
416-
iree_hal_webgpu_command_buffer_issue(command_buffer, device->queue));
417-
}
412+
IREE_RETURN_IF_ERROR(
413+
iree_hal_webgpu_command_buffer_issue(command_buffer, device->queue));
418414

419415
IREE_RETURN_IF_ERROR(iree_hal_semaphore_list_signal(signal_semaphore_list));
420416

integrations/pjrt/src/iree_pjrt/common/api_impl.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -590,8 +590,8 @@ iree_status_t BufferInstance::CopyToHost(void* dst, iree_host_size_t dst_size,
590590
device_.device(), IREE_HAL_QUEUE_AFFINITY_ANY,
591591
/*wait_semaphore_list=*/iree_hal_fence_semaphore_list(ready_fence_.get()),
592592
/*signal_semaphore_list=*/
593-
iree_hal_fence_semaphore_list(dst_buffer_ready_fence.get()),
594-
/*command_buffer_count=*/1, &transfer_cb, NULL));
593+
iree_hal_fence_semaphore_list(dst_buffer_ready_fence.get()), transfer_cb,
594+
iree_hal_buffer_binding_table_empty()));
595595

596596
*out_done_event = copy_done_event;
597597
return iree_ok_status();
@@ -847,8 +847,8 @@ iree_status_t DeviceInstance::HostBufferToDeviceSplat(
847847
/*wait_semaphore_list=*/
848848
{1, &transfer_timeline_, &signal_alloca_complete},
849849
/*signal_semaphore_list=*/
850-
{1, &transfer_timeline_, &signal_copy_complete},
851-
/*command_buffer_count=*/1, &transfer_cb, NULL));
850+
{1, &transfer_timeline_, &signal_copy_complete}, transfer_cb,
851+
iree_hal_buffer_binding_table_empty()));
852852

853853
// Wrap in a buffer view and return:
854854
iree::vm::ref<iree_hal_buffer_view_t> result_buffer_view;
@@ -1191,8 +1191,8 @@ iree_status_t DeviceInstance::HostBufferToDevice(
11911191
/*wait_semaphore_list=*/
11921192
{1, &transfer_timeline_, &signal_alloca_complete},
11931193
/*signal_semaphore_list=*/
1194-
{1, &transfer_timeline_, &signal_copy_complete},
1195-
/*command_buffer_count=*/1, &transfer_cb, NULL));
1194+
{1, &transfer_timeline_, &signal_copy_complete}, transfer_cb,
1195+
iree_hal_buffer_binding_table_empty()));
11961196

11971197
// Wrap in a buffer view and return.
11981198
iree::vm::ref<iree_hal_buffer_view_t> result_buffer_view;

integrations/pjrt/src/iree_pjrt/common/iree_helpers.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,17 +139,16 @@ iree_status_t hal_device_queue_execute(
139139
iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity,
140140
const iree_hal_semaphore_list_t wait_semaphore_list,
141141
const iree_hal_semaphore_list_t signal_semaphore_list,
142-
iree_host_size_t command_buffer_count,
143-
iree_hal_command_buffer_t* const* command_buffers) {
142+
iree_hal_command_buffer_t* command_buffer) {
144143
if (LOGGING_ENABLED) {
145144
LogInvoke(__func__, "device=%p, wait={%s}, signal={%s}", device,
146145
SemaphoreListToString(wait_semaphore_list).c_str(),
147146
SemaphoreListToString(signal_semaphore_list).c_str());
148147
}
149148
return HandleStatus(__func__, iree_hal_device_queue_execute(
150149
device, queue_affinity, wait_semaphore_list,
151-
signal_semaphore_list, command_buffer_count,
152-
command_buffers, /*binding_tables=*/NULL));
150+
signal_semaphore_list, command_buffer,
151+
iree_hal_buffer_binding_table_empty()));
153152
}
154153

155154
iree_status_t hal_fence_create(iree_host_size_t capacity,

runtime/bindings/python/hal.cc

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ void HalDevice::QueueDealloca(HalBuffer& buffer, py::handle wait_semaphores,
499499
"deallocating memory on queue");
500500
}
501501

502-
void HalDevice::QueueExecute(py::handle command_buffers,
502+
void HalDevice::QueueExecute(py::handle command_buffer,
503503
py::handle wait_semaphores,
504504
py::handle signal_semaphores) {
505505
iree_hal_semaphore_list_t wait_list;
@@ -548,17 +548,14 @@ void HalDevice::QueueExecute(py::handle command_buffers,
548548
}
549549

550550
// Unpack command buffers.
551-
size_t cb_count = py::len(command_buffers);
552-
iree_hal_command_buffer_t** cb_list =
553-
static_cast<iree_hal_command_buffer_t**>(
554-
alloca(sizeof(iree_hal_command_buffer_t*) * cb_count));
555-
for (size_t i = 0; i < cb_count; ++i) {
556-
cb_list[i] = py::cast<HalCommandBuffer*>(command_buffers[i])->raw_ptr();
557-
}
551+
iree_hal_command_buffer_t* cb =
552+
!command_buffer.is_none()
553+
? py::cast<HalCommandBuffer*>(command_buffer)->raw_ptr()
554+
: NULL;
558555

559556
CheckApiStatus(iree_hal_device_queue_execute(
560557
raw_ptr(), IREE_HAL_QUEUE_AFFINITY_ANY, wait_list,
561-
signal_list, cb_count, cb_list, /*binding_tables=*/NULL),
558+
signal_list, cb, iree_hal_buffer_binding_table_empty()),
562559
"executing command buffers");
563560
}
564561

runtime/bindings/python/iree/runtime/_binding.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ class HalDevice:
185185
) -> None: ...
186186
def queue_execute(
187187
self,
188-
command_buffers: Sequence[HalCommandBuffer],
188+
command_buffer: HalCommandBuffer,
189189
wait_semaphores: HalSemaphoreList,
190190
signal_semaphores: HalSemaphoreList,
191191
) -> None: ...

runtime/bindings/python/tests/hal_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ def testCommandBufferExecute(self):
463463

464464
sem = self.device.create_semaphore(0)
465465
self.device.queue_execute(
466-
[cb], wait_semaphores=[(sem, 0)], signal_semaphores=[(sem, 1)]
466+
cb, wait_semaphores=[(sem, 0)], signal_semaphores=[(sem, 1)]
467467
)
468468
iree.runtime.HalFence.create_at(sem, 1).wait()
469469

@@ -479,7 +479,7 @@ def testCommandBufferExecuteAcceptsFence(self):
479479

480480
sem = self.device.create_semaphore(0)
481481
self.device.queue_execute(
482-
[cb],
482+
cb,
483483
wait_semaphores=iree.runtime.HalFence.create_at(sem, 0),
484484
signal_semaphores=iree.runtime.HalFence.create_at(sem, 1),
485485
)

runtime/src/iree/hal/buffer_transfer.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ static iree_status_t iree_hal_device_transfer_and_wait(
7878
};
7979
status = iree_hal_device_queue_execute(
8080
device, IREE_HAL_QUEUE_AFFINITY_ANY, wait_semaphores, signal_semaphores,
81-
1, &command_buffer, /*binding_tables=*/NULL);
81+
command_buffer, iree_hal_buffer_binding_table_empty());
8282
}
8383
if (iree_status_is_ok(status)) {
8484
status = iree_hal_semaphore_wait(fence_semaphore, signal_value, timeout);

runtime/src/iree/hal/command_buffer.c

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,7 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch_indirect(
592592

593593
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_validate_submission(
594594
iree_hal_command_buffer_t* command_buffer,
595-
const iree_hal_buffer_binding_table_t* binding_table) {
595+
iree_hal_buffer_binding_table_t binding_table) {
596596
IREE_ASSERT_ARGUMENT(command_buffer);
597597

598598
// Validate the command buffer has been recorded properly.
@@ -607,25 +607,24 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_validate_submission(
607607
// the command buffer was allocated with.
608608
if (command_buffer->binding_count == 0) {
609609
return iree_ok_status();
610-
} else if (!binding_table) {
610+
} else if (binding_table.count == 0) {
611611
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
612612
"indirect command buffer requires at least %u "
613613
"bindings but no binding table was provided",
614614
command_buffer->binding_count);
615-
} else if (binding_table->count < command_buffer->binding_count) {
615+
} else if (binding_table.count < command_buffer->binding_count) {
616616
return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
617617
"indirect command buffer requires at least %u "
618618
"bindings but only %" PRIhsz " were provided ",
619-
command_buffer->binding_count,
620-
binding_table->count);
619+
command_buffer->binding_count, binding_table.count);
621620
}
622621

623622
// Validate the binding table against the commands consuming them.
624623
// This is O(binding_count) so something we only do if validation is
625624
// requested on the command buffer.
626625
IF_VALIDATING(command_buffer, {
627626
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_binding_table_validation(
628-
command_buffer, VALIDATION_STATE(command_buffer), *binding_table));
627+
command_buffer, VALIDATION_STATE(command_buffer), binding_table));
629628
});
630629

631630
return iree_ok_status();

runtime/src/iree/hal/command_buffer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -792,7 +792,7 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch_indirect(
792792
// are used by the command buffer are provided they will be ignored.
793793
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_validate_submission(
794794
iree_hal_command_buffer_t* command_buffer,
795-
const iree_hal_buffer_binding_table_t* binding_table);
795+
iree_hal_buffer_binding_table_t binding_table);
796796

797797
//===----------------------------------------------------------------------===//
798798
// Utilities for command buffer creation

0 commit comments

Comments
 (0)