@@ -806,123 +806,6 @@ static iree_status_t iree_hal_webgpu_command_buffer_push_descriptor_set(
806806 return iree_ok_status ();
807807}
808808
809- static iree_status_t iree_hal_webgpu_command_buffer_prepare_dispatch (
810- iree_hal_webgpu_command_buffer_t * command_buffer ,
811- iree_hal_executable_t * executable , uint32_t ordinal ,
812- iree_const_byte_span_t constants , iree_hal_buffer_ref_list_t bindings ,
813- iree_hal_dispatch_flags_t flags , WGPUComputePassEncoder * out_compute_pass ) {
814- const iree_hal_webgpu_entry_point_t * entry_point =
815- iree_hal_webgpu_executable_lookup_entry_point (executable , ordinal );
816-
817- // Upload push constant data - this may incur a segment flush if the staging
818- // buffer is exhausted.
819- iree_host_size_t constant_count =
820- iree_hal_webgpu_pipeline_layout_constant_count (entry_point -> layout );
821- iree_const_byte_span_t constant_data = iree_make_const_byte_span (
822- command_buffer -> state .constants ,
823- constant_count * sizeof (command_buffer -> state .constants [0 ]));
824- uint32_t params_offset = 0 ;
825- IREE_RETURN_IF_ERROR (iree_hal_webgpu_command_buffer_append_parameters (
826- command_buffer , constant_data , & params_offset ));
827-
828- // Acquire the compute pass we'll encode the dispatch into - this may be
829- // fresh or reused from prior commands.
830- WGPUComputePassEncoder compute_pass = NULL ;
831- IREE_RETURN_IF_ERROR (iree_hal_webgpu_command_buffer_acquire_compute_pass (
832- command_buffer , & compute_pass ));
833- wgpuComputePassEncoderSetPipeline (compute_pass , entry_point -> pipeline );
834-
835- if (constant_count > 0 ) {
836- // Bind the push constant emulation bind group at the staging buffer
837- // relative offset for this dispatch.
838- wgpuComputePassEncoderSetBindGroup (
839- compute_pass , IREE_HAL_WEBGPU_PARAMS_BIND_GROUP_INDEX ,
840- command_buffer -> staging_buffer -> bind_group , 1 , & params_offset );
841- }
842-
843- // Set all bindings.
844- const iree_hal_webgpu_set_binding_info_t * binding_info =
845- iree_hal_webgpu_pipeline_layout_set_binding_info (entry_point -> layout );
846- for (iree_host_size_t i = 0 ; i < binding_info -> set_count ; ++ i ) {
847- // If there are no bindings in this set we can skip it.
848- if (binding_info -> set_masks [i ] == 0 ) continue ;
849-
850- // If there is a bind group handle then it means we've done the lookup and
851- // set the bind group on the device already - we can skip.
852- if (command_buffer -> state .bind_groups [i ].handle ) continue ;
853-
854- // Acquire the bind group to use for the current descriptor set.
855- WGPUBindGroup handle = iree_hal_webgpu_bind_group_cache_acquire (
856- command_buffer -> bind_group_cache , binding_info -> set_layouts [i ],
857- command_buffer -> state .bind_groups [i ].bindings ,
858- binding_info -> set_masks [i ]);
859-
860- // NOTE: today we don't support dynamic offsets for push descriptor sets.
861- // This will be a larger change we'll need to handle in the compiler. If we
862- // wanted to improve caching we could make all the bindings dynamic and then
863- // always cache the base offsets, however
864- // maxDynamicStorageBuffersPerPipelineLayout is minimally 4 and that's not
865- // a lot of bindings.
866- wgpuComputePassEncoderSetBindGroup (compute_pass , (uint32_t )i , handle , 0 ,
867- NULL );
868- command_buffer -> state .bind_groups [i ].handle = handle ;
869- command_buffer -> state .bind_groups_empty &= ~(1ull << i );
870- }
871-
872- if (constant_count > 0 ) {
873- // Pad up to IREE_HAL_WEBGPU_PARAMS_BIND_GROUP_INDEX with empty bind groups.
874- WGPUBindGroup empty_handle =
875- command_buffer -> staging_buffer -> empty_bind_group ;
876- for (iree_host_size_t i = binding_info -> set_count ;
877- i < IREE_HAL_WEBGPU_PARAMS_BIND_GROUP_INDEX ; ++ i ) {
878- // Skip if an empty group is already set at this index.
879- if ((command_buffer -> state .bind_groups_empty >> i ) & 1ull ) continue ;
880-
881- wgpuComputePassEncoderSetBindGroup (compute_pass , (uint32_t )i ,
882- empty_handle , 0 , NULL );
883- command_buffer -> state .bind_groups [i ].handle = empty_handle ;
884- command_buffer -> state .bind_groups_empty |= 1ull << i ;
885- }
886- }
887-
888- * out_compute_pass = compute_pass ;
889- return iree_ok_status ();
890- }
891-
892- static iree_status_t iree_hal_webgpu_command_buffer_dispatch (
893- iree_hal_command_buffer_t * base_command_buffer ,
894- iree_hal_executable_t * executable , int32_t entry_point ,
895- uint32_t workgroup_x , uint32_t workgroup_y , uint32_t workgroup_z ,
896- iree_hal_dispatch_flags_t flags ) {
897- iree_hal_webgpu_command_buffer_t * command_buffer =
898- iree_hal_webgpu_command_buffer_cast (base_command_buffer );
899-
900- WGPUComputePassEncoder compute_pass = NULL ;
901- IREE_RETURN_IF_ERROR (iree_hal_webgpu_command_buffer_prepare_dispatch (
902- command_buffer , executable , entry_point , & compute_pass ));
903- wgpuComputePassEncoderDispatchWorkgroups (compute_pass , workgroup_x ,
904- workgroup_y , workgroup_z );
905-
906- return iree_ok_status ();
907- }
908-
909- static iree_status_t iree_hal_webgpu_command_buffer_dispatch_indirect (
910- iree_hal_command_buffer_t * base_command_buffer ,
911- iree_hal_executable_t * executable , int32_t entry_point ,
912- iree_hal_buffer_ref_t workgroups_ref , iree_hal_dispatch_flags_t flags ) {
913- iree_hal_webgpu_command_buffer_t * command_buffer =
914- iree_hal_webgpu_command_buffer_cast (base_command_buffer );
915-
916- WGPUComputePassEncoder compute_pass = NULL ;
917- IREE_RETURN_IF_ERROR (iree_hal_webgpu_command_buffer_prepare_dispatch (
918- command_buffer , executable , entry_point , & compute_pass ));
919- wgpuComputePassEncoderDispatchWorkgroupsIndirect (
920- compute_pass , iree_hal_webgpu_buffer_handle (workgroups_ref .buffer ),
921- workgroups_ref .offset );
922-
923- return iree_ok_status ();
924- }
925-
926809static iree_status_t iree_hal_webgpu_command_buffer_prepare_dispatch (
927810 iree_hal_webgpu_command_buffer_t * command_buffer ,
928811 iree_hal_executable_t * executable , uint32_t ordinal ,
@@ -968,15 +851,16 @@ static iree_status_t iree_hal_webgpu_command_buffer_prepare_dispatch(
968851 binding_mask |= 1u << i ;
969852 group_bindings [i ].type = WGPUBufferBindingType_Storage ;
970853 group_bindings [i ].buffer =
971- bindings [i ].buffer ? iree_hal_webgpu_buffer_handle (bindings [i ].buffer )
972- : NULL ;
973- group_bindings [i ] offset = bindings [i ].offset ;
974- group_bindings [i ] length = bindings [i ].length ;
854+ bindings .values [i ].buffer
855+ ? iree_hal_webgpu_buffer_handle (bindings .values [i ].buffer )
856+ : NULL ;
857+ group_bindings [i ].offset = bindings .values [i ].offset ;
858+ group_bindings [i ].length = bindings .values [i ].length ;
975859 }
976860
977861 // Acquire the bind group to use for the current descriptor set.
978862 WGPUBindGroup handle = iree_hal_webgpu_bind_group_cache_acquire (
979- command_buffer -> bind_group_cache , binding_info -> set_layout ,
863+ command_buffer -> bind_group_cache , binding_info -> set_layouts [ 0 ] ,
980864 group_bindings , binding_mask );
981865
982866 // NOTE: today we don't support dynamic offsets for push descriptor sets.
@@ -994,36 +878,33 @@ static iree_status_t iree_hal_webgpu_command_buffer_prepare_dispatch(
994878static iree_status_t iree_hal_webgpu_command_buffer_dispatch (
995879 iree_hal_command_buffer_t * base_command_buffer ,
996880 iree_hal_executable_t * executable , int32_t entry_point ,
997- const uint32_t workgroup_count [3 ], iree_const_byte_span_t constants ,
998- iree_hal_buffer_ref_list_t bindings , iree_hal_dispatch_flags_t flags ) {
881+ const iree_hal_dispatch_config_t config , iree_const_byte_span_t constants ,
882+ const iree_hal_buffer_ref_list_t bindings ,
883+ iree_hal_dispatch_flags_t flags ) {
999884 iree_hal_webgpu_command_buffer_t * command_buffer =
1000885 iree_hal_webgpu_command_buffer_cast (base_command_buffer );
1001886
1002- WGPUComputePassEncoder compute_pass = NULL ;
1003- IREE_RETURN_IF_ERROR (iree_hal_webgpu_command_buffer_prepare_dispatch (
1004- command_buffer , executable , entry_point , constants , bindings , flags ,
1005- & compute_pass ));
1006- wgpuComputePassEncoderDispatchWorkgroups (
1007- compute_pass , workgroup_count [0 ], workgroup_count [1 ], workgroup_count [2 ]);
1008-
1009- return iree_ok_status ();
1010- }
1011-
1012- static iree_status_t iree_hal_webgpu_command_buffer_dispatch_indirect (
1013- iree_hal_command_buffer_t * base_command_buffer ,
1014- iree_hal_executable_t * executable , int32_t entry_point ,
1015- iree_hal_buffer_ref_t workgroups_ref , iree_const_byte_span_t constants ,
1016- iree_hal_buffer_ref_list_t bindings , iree_hal_dispatch_flags_t flags ) {
1017- iree_hal_webgpu_command_buffer_t * command_buffer =
1018- iree_hal_webgpu_command_buffer_cast (base_command_buffer );
887+ if (iree_hal_dispatch_uses_custom_arguments (flags )) {
888+ return iree_make_status (
889+ IREE_STATUS_UNIMPLEMENTED ,
890+ "direct/indirect arguments are not supported in WebGPU" );
891+ }
1019892
1020893 WGPUComputePassEncoder compute_pass = NULL ;
1021894 IREE_RETURN_IF_ERROR (iree_hal_webgpu_command_buffer_prepare_dispatch (
1022895 command_buffer , executable , entry_point , constants , bindings , flags ,
1023896 & compute_pass ));
1024- wgpuComputePassEncoderDispatchWorkgroupsIndirect (
1025- compute_pass , iree_hal_webgpu_buffer_handle (workgroups_ref .buffer ),
1026- workgroups_ref .offset );
897+
898+ if (iree_hal_dispatch_uses_indirect_parameters (flags )) {
899+ wgpuComputePassEncoderDispatchWorkgroupsIndirect (
900+ compute_pass ,
901+ iree_hal_webgpu_buffer_handle (config .workgroup_count_ref .buffer ),
902+ config .workgroup_count_ref .offset );
903+ } else {
904+ wgpuComputePassEncoderDispatchWorkgroups (
905+ compute_pass , config .workgroup_count [0 ], config .workgroup_count [1 ],
906+ config .workgroup_count [2 ]);
907+ }
1027908
1028909 return iree_ok_status ();
1029910}
@@ -1045,7 +926,4 @@ const iree_hal_command_buffer_vtable_t iree_hal_webgpu_command_buffer_vtable = {
1045926 .constants = iree_hal_webgpu_command_buffer_constants ,
1046927 .push_descriptor_set = iree_hal_webgpu_command_buffer_push_descriptor_set ,
1047928 .dispatch = iree_hal_webgpu_command_buffer_dispatch ,
1048- .dispatch_indirect = iree_hal_webgpu_command_buffer_dispatch_indirect ,
1049- .dispatch = iree_hal_webgpu_command_buffer_dispatch ,
1050- .dispatch_indirect = iree_hal_webgpu_command_buffer_dispatch_indirect ,
1051929};
0 commit comments