Skip to content

Commit 1ca7baf

Browse files
benvanik authored and Groverkss committed
Adding flags to fill/update/copy commands and vtabling queue fill/copy.
The flags are currently unused but may be used in the future to provide hints on caching or distribution. Vtabling the queue fill/copy operations allows implementations that can perform them more efficiently to do so.
1 parent 2b1a8e7 commit 1ca7baf

37 files changed

+346
-136
lines changed

experimental/webgpu/command_buffer.c

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,7 @@ static uint32_t iree_hal_webgpu_splat_pattern(const void* pattern,
608608
static iree_status_t iree_hal_webgpu_command_buffer_fill_buffer(
609609
iree_hal_command_buffer_t* base_command_buffer,
610610
iree_hal_buffer_ref_t target_ref, const void* pattern,
611-
iree_host_size_t pattern_length) {
611+
iree_host_size_t pattern_length, iree_hal_fill_flags_t flags) {
612612
iree_hal_webgpu_command_buffer_t* command_buffer =
613613
iree_hal_webgpu_command_buffer_cast(base_command_buffer);
614614

@@ -693,7 +693,8 @@ static iree_status_t iree_hal_webgpu_command_buffer_fill_buffer(
693693

694694
static iree_status_t iree_hal_webgpu_command_buffer_update_buffer(
695695
iree_hal_command_buffer_t* base_command_buffer, const void* source_buffer,
696-
iree_host_size_t source_offset, iree_hal_buffer_ref_t target_ref) {
696+
iree_host_size_t source_offset, iree_hal_buffer_ref_t target_ref,
697+
iree_hal_update_flags_t flags) {
697698
iree_hal_webgpu_command_buffer_t* command_buffer =
698699
iree_hal_webgpu_command_buffer_cast(base_command_buffer);
699700

@@ -734,7 +735,8 @@ static iree_status_t iree_hal_webgpu_command_buffer_update_buffer(
734735

735736
static iree_status_t iree_hal_webgpu_command_buffer_copy_buffer(
736737
iree_hal_command_buffer_t* base_command_buffer,
737-
iree_hal_buffer_ref_t source_ref, iree_hal_buffer_ref_t target_ref) {
738+
iree_hal_buffer_ref_t source_ref, iree_hal_buffer_ref_t target_ref,
739+
iree_hal_copy_flags_t flags) {
738740
iree_hal_webgpu_command_buffer_t* command_buffer =
739741
iree_hal_webgpu_command_buffer_cast(base_command_buffer);
740742

experimental/webgpu/webgpu_device.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ static iree_status_t iree_hal_webgpu_device_queue_read(
354354
const iree_hal_semaphore_list_t signal_semaphore_list,
355355
iree_hal_file_t* source_file, uint64_t source_offset,
356356
iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
357-
iree_device_size_t length, uint32_t flags) {
357+
iree_device_size_t length, iree_hal_read_flags_t flags) {
358358
// TODO: expose streaming chunk count/size options.
359359
iree_status_t loop_status = iree_ok_status();
360360
iree_hal_file_transfer_options_t options = {
@@ -376,7 +376,7 @@ static iree_status_t iree_hal_webgpu_device_queue_write(
376376
const iree_hal_semaphore_list_t signal_semaphore_list,
377377
iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
378378
iree_hal_file_t* target_file, uint64_t target_offset,
379-
iree_device_size_t length, uint32_t flags) {
379+
iree_device_size_t length, iree_hal_write_flags_t flags) {
380380
// TODO: expose streaming chunk count/size options.
381381
iree_status_t loop_status = iree_ok_status();
382382
iree_hal_file_transfer_options_t options = {
@@ -473,6 +473,8 @@ const iree_hal_device_vtable_t iree_hal_webgpu_device_vtable = {
473473
iree_hal_webgpu_device_query_semaphore_compatibility,
474474
.queue_alloca = iree_hal_webgpu_device_queue_alloca,
475475
.queue_dealloca = iree_hal_webgpu_device_queue_dealloca,
476+
.queue_fill = iree_hal_device_queue_emulated_fill,
477+
.queue_copy = iree_hal_device_queue_emulated_copy,
476478
.queue_read = iree_hal_webgpu_device_queue_read,
477479
.queue_write = iree_hal_webgpu_device_queue_write,
478480
.queue_execute = iree_hal_webgpu_device_queue_execute,

integrations/pjrt/src/iree_pjrt/common/api_impl.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -837,7 +837,8 @@ iree_status_t DeviceInstance::HostBufferToDeviceSplat(
837837
IREE_CHECK_OK(iree_hal_command_buffer_begin(transfer_cb.get()));
838838
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_fill_buffer(
839839
transfer_cb.get(), buffer.get(), /*target_offset=*/0,
840-
/*target_size=*/byte_length, data, element_type_byte_size));
840+
/*target_size=*/byte_length, data, element_type_byte_size,
841+
IREE_HAL_FILL_FLAG_NONE));
841842
IREE_CHECK_OK(iree_hal_command_buffer_end(transfer_cb.get()));
842843

843844
// Execute the enqueued splat:

runtime/bindings/python/hal.cc

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -619,11 +619,12 @@ void HalDevice::QueueCopy(HalBuffer& source_buffer, HalBuffer& target_buffer,
619619
"Source and buffer length must be less than the target buffer length "
620620
"and it does not. Please check allocations");
621621
}
622-
CheckApiStatus(iree_hal_device_queue_copy(
623-
raw_ptr(), IREE_HAL_QUEUE_AFFINITY_ANY, wait_list,
624-
signal_list, source_buffer.raw_ptr(), 0,
625-
target_buffer.raw_ptr(), 0, source_length),
626-
"Copying buffer on queue");
622+
CheckApiStatus(
623+
iree_hal_device_queue_copy(
624+
raw_ptr(), IREE_HAL_QUEUE_AFFINITY_ANY, wait_list, signal_list,
625+
source_buffer.raw_ptr(), 0, target_buffer.raw_ptr(), 0, source_length,
626+
IREE_HAL_COPY_FLAG_NONE),
627+
"Copying buffer on queue");
627628
}
628629

629630
py::object HalDevice::CreateDLPackCapsule(HalBufferView& buffer_view,
@@ -1729,7 +1730,8 @@ void SetupHalBindings(nanobind::module_ m) {
17291730
iree_hal_make_buffer_ref(source_buffer.raw_ptr(),
17301731
source_offset, resolved_length),
17311732
iree_hal_make_buffer_ref(target_buffer.raw_ptr(),
1732-
target_offset, resolved_length)),
1733+
target_offset, resolved_length),
1734+
IREE_HAL_COPY_FLAG_NONE),
17331735
"copy command");
17341736
if (end) {
17351737
CheckApiStatus(iree_hal_command_buffer_end(self.raw_ptr()),
@@ -1767,7 +1769,8 @@ void SetupHalBindings(nanobind::module_ m) {
17671769
self.raw_ptr(),
17681770
iree_hal_make_buffer_ref(target_buffer.raw_ptr(),
17691771
target_offset, resolved_length),
1770-
pattern_view.buf, pattern_view.len),
1772+
pattern_view.buf, pattern_view.len,
1773+
IREE_HAL_FILL_FLAG_NONE),
17711774
"command buffer fill");
17721775
if (end) {
17731776
CheckApiStatus(iree_hal_command_buffer_end(self.raw_ptr()),

runtime/src/iree/hal/command_buffer.c

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,8 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_discard_buffer(
423423

424424
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_fill_buffer(
425425
iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_ref_t target_ref,
426-
const void* pattern, iree_host_size_t pattern_length) {
426+
const void* pattern, iree_host_size_t pattern_length,
427+
iree_hal_fill_flags_t flags) {
427428
IREE_ASSERT_ARGUMENT(command_buffer);
428429
if (target_ref.length == 0) {
429430
// No-op fill. All other validation is skipped.
@@ -434,17 +435,18 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_fill_buffer(
434435
IREE_RETURN_AND_END_ZONE_IF_ERROR(
435436
z0, iree_hal_command_buffer_fill_buffer_validation(
436437
command_buffer, VALIDATION_STATE(command_buffer), target_ref,
437-
pattern, pattern_length));
438+
pattern, pattern_length, flags));
438439
});
439440
iree_status_t status = _VTABLE_DISPATCH(command_buffer, fill_buffer)(
440-
command_buffer, target_ref, pattern, pattern_length);
441+
command_buffer, target_ref, pattern, pattern_length, flags);
441442
IREE_TRACE_ZONE_END(z0);
442443
return status;
443444
}
444445

445446
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_update_buffer(
446447
iree_hal_command_buffer_t* command_buffer, const void* source_buffer,
447-
iree_host_size_t source_offset, iree_hal_buffer_ref_t target_ref) {
448+
iree_host_size_t source_offset, iree_hal_buffer_ref_t target_ref,
449+
iree_hal_update_flags_t flags) {
448450
IREE_ASSERT_ARGUMENT(command_buffer);
449451
IREE_ASSERT_ARGUMENT(source_buffer);
450452
if (target_ref.length == 0) {
@@ -456,17 +458,17 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_update_buffer(
456458
IREE_RETURN_AND_END_ZONE_IF_ERROR(
457459
z0, iree_hal_command_buffer_update_buffer_validation(
458460
command_buffer, VALIDATION_STATE(command_buffer), source_buffer,
459-
source_offset, target_ref));
461+
source_offset, target_ref, flags));
460462
});
461463
iree_status_t status = _VTABLE_DISPATCH(command_buffer, update_buffer)(
462-
command_buffer, source_buffer, source_offset, target_ref);
464+
command_buffer, source_buffer, source_offset, target_ref, flags);
463465
IREE_TRACE_ZONE_END(z0);
464466
return status;
465467
}
466468

467469
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_copy_buffer(
468470
iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_ref_t source_ref,
469-
iree_hal_buffer_ref_t target_ref) {
471+
iree_hal_buffer_ref_t target_ref, iree_hal_copy_flags_t flags) {
470472
IREE_ASSERT_ARGUMENT(command_buffer);
471473
if (target_ref.length == 0) {
472474
// No-op copy. All other validation is skipped.
@@ -477,10 +479,10 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_copy_buffer(
477479
IREE_RETURN_AND_END_ZONE_IF_ERROR(
478480
z0, iree_hal_command_buffer_copy_buffer_validation(
479481
command_buffer, VALIDATION_STATE(command_buffer), source_ref,
480-
target_ref));
482+
target_ref, flags));
481483
});
482484
iree_status_t status = _VTABLE_DISPATCH(command_buffer, copy_buffer)(
483-
command_buffer, source_ref, target_ref);
485+
command_buffer, source_ref, target_ref, flags);
484486
IREE_TRACE_ZONE_END(z0);
485487
return status;
486488
}
@@ -658,7 +660,7 @@ IREE_API_EXPORT iree_status_t iree_hal_create_transfer_command_buffer(
658660
transfer_command->fill.target_offset,
659661
transfer_command->fill.length),
660662
transfer_command->fill.pattern,
661-
transfer_command->fill.pattern_length);
663+
transfer_command->fill.pattern_length, IREE_HAL_FILL_FLAG_NONE);
662664
break;
663665
case IREE_HAL_TRANSFER_COMMAND_TYPE_COPY:
664666
status = iree_hal_command_buffer_copy_buffer(
@@ -668,15 +670,17 @@ IREE_API_EXPORT iree_status_t iree_hal_create_transfer_command_buffer(
668670
transfer_command->copy.length),
669671
iree_hal_make_buffer_ref(transfer_command->copy.target_buffer,
670672
transfer_command->copy.target_offset,
671-
transfer_command->copy.length));
673+
transfer_command->copy.length),
674+
IREE_HAL_COPY_FLAG_NONE);
672675
break;
673676
case IREE_HAL_TRANSFER_COMMAND_TYPE_UPDATE:
674677
status = iree_hal_command_buffer_update_buffer(
675678
command_buffer, transfer_command->update.source_buffer,
676679
transfer_command->update.source_offset,
677680
iree_hal_make_buffer_ref(transfer_command->update.target_buffer,
678681
transfer_command->update.target_offset,
679-
transfer_command->update.length));
682+
transfer_command->update.length),
683+
IREE_HAL_UPDATE_FLAG_NONE);
680684
break;
681685
default:
682686
status =

runtime/src/iree/hal/command_buffer.h

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,24 @@ typedef struct iree_hal_buffer_barrier_t {
214214
iree_hal_buffer_ref_t buffer_ref;
215215
} iree_hal_buffer_barrier_t;
216216

217+
// Bitfield specifying flags controlling a fill operation.
218+
typedef uint64_t iree_hal_fill_flags_t;
219+
enum iree_hal_fill_flag_bits_t {
220+
IREE_HAL_FILL_FLAG_NONE = 0,
221+
};
222+
223+
// Bitfield specifying flags controlling an update operation.
224+
typedef uint64_t iree_hal_update_flags_t;
225+
enum iree_hal_update_flag_bits_t {
226+
IREE_HAL_UPDATE_FLAG_NONE = 0,
227+
};
228+
229+
// Bitfield specifying flags controlling a copy operation.
230+
typedef uint64_t iree_hal_copy_flags_t;
231+
enum iree_hal_copy_flag_bits_t {
232+
IREE_HAL_COPY_FLAG_NONE = 0,
233+
};
234+
217235
// Specifies the type of collective operation.
218236
enum iree_hal_collective_kind_e {
219237
// Gathers N*|element_count| elements of the specified type in |recv_binding|
@@ -391,10 +409,10 @@ IREE_API_EXPORT iree_device_size_t iree_hal_collective_element_byte_count(
391409
iree_hal_collective_element_type_t element_type);
392410

393411
// Bitfield specifying flags controlling a dispatch operation.
412+
typedef uint64_t iree_hal_dispatch_flags_t;
394413
enum iree_hal_dispatch_flag_bits_t {
395414
IREE_HAL_DISPATCH_FLAG_NONE = 0,
396415
};
397-
typedef uint64_t iree_hal_dispatch_flags_t;
398416

399417
// An RGBA color.
400418
typedef struct iree_hal_label_color_t {
@@ -684,7 +702,8 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_discard_buffer(
684702
// device queue and be allocated with IREE_HAL_BUFFER_USAGE_TRANSFER.
685703
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_fill_buffer(
686704
iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_ref_t target_ref,
687-
const void* pattern, iree_host_size_t pattern_length);
705+
const void* pattern, iree_host_size_t pattern_length,
706+
iree_hal_fill_flags_t flags);
688707

689708
// Updates a range of the given target buffer from the source host memory.
690709
// The source host memory is copied immediately into the command buffer and
@@ -697,7 +716,8 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_fill_buffer(
697716
// device queue and be allocated with IREE_HAL_BUFFER_USAGE_TRANSFER.
698717
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_update_buffer(
699718
iree_hal_command_buffer_t* command_buffer, const void* source_buffer,
700-
iree_host_size_t source_offset, iree_hal_buffer_ref_t target_ref);
719+
iree_host_size_t source_offset, iree_hal_buffer_ref_t target_ref,
720+
iree_hal_update_flags_t flags);
701721

702722
// Copies a range of one buffer to another.
703723
// Both buffers must be compatible with the devices owned by this device
@@ -709,7 +729,7 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_update_buffer(
709729
// copies.
710730
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_copy_buffer(
711731
iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_ref_t source_ref,
712-
iree_hal_buffer_ref_t target_ref);
732+
iree_hal_buffer_ref_t target_ref, iree_hal_copy_flags_t flags);
713733

714734
// Dispatches a collective operation defined by |op| using the given buffers.
715735
// |param| must be specified for operations that require a root/peer rank
@@ -879,15 +899,17 @@ typedef struct iree_hal_command_buffer_vtable_t {
879899
iree_status_t(IREE_API_PTR* fill_buffer)(
880900
iree_hal_command_buffer_t* command_buffer,
881901
iree_hal_buffer_ref_t target_ref, const void* pattern,
882-
iree_host_size_t pattern_length);
902+
iree_host_size_t pattern_length, iree_hal_fill_flags_t flags);
883903

884904
iree_status_t(IREE_API_PTR* update_buffer)(
885905
iree_hal_command_buffer_t* command_buffer, const void* source_buffer,
886-
iree_host_size_t source_offset, iree_hal_buffer_ref_t target_ref);
906+
iree_host_size_t source_offset, iree_hal_buffer_ref_t target_ref,
907+
iree_hal_update_flags_t flags);
887908

888909
iree_status_t(IREE_API_PTR* copy_buffer)(
889910
iree_hal_command_buffer_t* command_buffer,
890-
iree_hal_buffer_ref_t source_ref, iree_hal_buffer_ref_t target_ref);
911+
iree_hal_buffer_ref_t source_ref, iree_hal_buffer_ref_t target_ref,
912+
iree_hal_copy_flags_t flags);
891913

892914
iree_status_t(IREE_API_PTR* collective)(
893915
iree_hal_command_buffer_t* command_buffer, iree_hal_channel_t* channel,

runtime/src/iree/hal/command_buffer_validation.c

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ iree_status_t iree_hal_command_buffer_fill_buffer_validation(
352352
iree_hal_command_buffer_t* command_buffer,
353353
iree_hal_command_buffer_validation_state_t* validation_state,
354354
iree_hal_buffer_ref_t target_ref, const void* pattern,
355-
iree_host_size_t pattern_length) {
355+
iree_host_size_t pattern_length, iree_hal_fill_flags_t flags) {
356356
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories(
357357
command_buffer, validation_state, IREE_HAL_COMMAND_CATEGORY_TRANSFER));
358358

@@ -392,7 +392,7 @@ iree_status_t iree_hal_command_buffer_update_buffer_validation(
392392
iree_hal_command_buffer_t* command_buffer,
393393
iree_hal_command_buffer_validation_state_t* validation_state,
394394
const void* source_buffer, iree_host_size_t source_offset,
395-
iree_hal_buffer_ref_t target_ref) {
395+
iree_hal_buffer_ref_t target_ref, iree_hal_update_flags_t flags) {
396396
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories(
397397
command_buffer, validation_state, IREE_HAL_COMMAND_CATEGORY_TRANSFER));
398398

@@ -412,7 +412,8 @@ iree_status_t iree_hal_command_buffer_update_buffer_validation(
412412
iree_status_t iree_hal_command_buffer_copy_buffer_validation(
413413
iree_hal_command_buffer_t* command_buffer,
414414
iree_hal_command_buffer_validation_state_t* validation_state,
415-
iree_hal_buffer_ref_t source_ref, iree_hal_buffer_ref_t target_ref) {
415+
iree_hal_buffer_ref_t source_ref, iree_hal_buffer_ref_t target_ref,
416+
iree_hal_copy_flags_t flags) {
416417
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories(
417418
command_buffer, validation_state, IREE_HAL_COMMAND_CATEGORY_TRANSFER));
418419

runtime/src/iree/hal/command_buffer_validation.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,18 +108,19 @@ iree_status_t iree_hal_command_buffer_fill_buffer_validation(
108108
iree_hal_command_buffer_t* command_buffer,
109109
iree_hal_command_buffer_validation_state_t* validation_state,
110110
iree_hal_buffer_ref_t target_ref, const void* pattern,
111-
iree_host_size_t pattern_length);
111+
iree_host_size_t pattern_length, iree_hal_fill_flags_t flags);
112112

113113
iree_status_t iree_hal_command_buffer_update_buffer_validation(
114114
iree_hal_command_buffer_t* command_buffer,
115115
iree_hal_command_buffer_validation_state_t* validation_state,
116116
const void* source_buffer, iree_host_size_t source_offset,
117-
iree_hal_buffer_ref_t target_ref);
117+
iree_hal_buffer_ref_t target_ref, iree_hal_update_flags_t flags);
118118

119119
iree_status_t iree_hal_command_buffer_copy_buffer_validation(
120120
iree_hal_command_buffer_t* command_buffer,
121121
iree_hal_command_buffer_validation_state_t* validation_state,
122-
iree_hal_buffer_ref_t source_ref, iree_hal_buffer_ref_t target_ref);
122+
iree_hal_buffer_ref_t source_ref, iree_hal_buffer_ref_t target_ref,
123+
iree_hal_copy_flags_t flags);
123124

124125
iree_status_t iree_hal_command_buffer_collective_validation(
125126
iree_hal_command_buffer_t* command_buffer,

0 commit comments

Comments
 (0)