@@ -255,7 +255,14 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) {
255
255
uint32_t x = program.DispatchGroupSizeX ();
256
256
uint32_t y = program.DispatchGroupSizeY ();
257
257
uint32_t z = program.DispatchGroupSizeZ ();
258
- ORT_RETURN_IF_ERROR (program_mgr_->NormalizeDispatchGroupSize (x, y, z));
258
+
259
+ // Skip normalization for indirect dispatch since dimensions are determined by the indirect buffer
260
+ if (program.IndirectDispatchTensor () == nullptr ) {
261
+ ORT_RETURN_IF_ERROR (program_mgr_->NormalizeDispatchGroupSize (x, y, z));
262
+ } else {
263
+ ORT_ENFORCE (x == 0 && y == 0 && z == 0 ,
264
+ " Only one of SetIndirectDispatchTensor and SetDispatchGroupSize should be called for program" , program.Name ());
265
+ }
259
266
260
267
bool is_1d_dispatch = (y == 1 && z == 1 );
261
268
@@ -442,7 +449,7 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) {
442
449
bind_buffers.push_back (uniform_buffer);
443
450
}
444
451
445
- LaunchComputePipeline (compute_pass_encoder, bind_buffers, *program_artifact, x, y, z);
452
+ LaunchComputePipeline (compute_pass_encoder, bind_buffers, *program_artifact, x, y, z, program. IndirectDispatchTensor () );
446
453
if (uniform_buffer) {
447
454
buffer_mgr.Release (uniform_buffer);
448
455
}
@@ -722,7 +729,8 @@ void WebGpuContext::OnRunEnd() {
722
729
void WebGpuContext::LaunchComputePipeline (const wgpu::ComputePassEncoder& compute_pass_encoder,
723
730
const std::vector<WGPUBuffer>& bind_buffers,
724
731
const ProgramArtifact& program_artifact,
725
- uint32_t x, uint32_t y, uint32_t z) {
732
+ uint32_t x, uint32_t y, uint32_t z,
733
+ const Tensor* indirect_dispatch_tensor) {
726
734
uint32_t entry_index = 0 ;
727
735
std::vector<WGPUBindGroupEntry> bind_group_entries;
728
736
for (WGPUBuffer buffer : bind_buffers) {
@@ -738,14 +746,27 @@ void WebGpuContext::LaunchComputePipeline(const wgpu::ComputePassEncoder& comput
738
746
739
747
auto bind_group = wgpuDeviceCreateBindGroup (Device ().Get (), &bind_group_desc);
740
748
if (graph_capture_state_ == GraphCaptureState::Capturing) {
749
+ WGPUBuffer indirect_buffer = nullptr ;
750
+ if (indirect_dispatch_tensor != nullptr ) {
751
+ indirect_buffer = reinterpret_cast <WGPUBuffer>(const_cast <void *>(indirect_dispatch_tensor->DataRaw ()));
752
+ }
741
753
external_captured_commands_->push_back ({program_artifact.compute_pipeline ,
742
754
bind_group,
743
755
bind_group_layout,
744
- {x, y, z}});
756
+ {x, y, z},
757
+ indirect_buffer});
745
758
} else {
746
759
compute_pass_encoder.SetPipeline (program_artifact.compute_pipeline );
747
760
wgpuComputePassEncoderSetBindGroup (compute_pass_encoder.Get (), 0 , bind_group, 0 , nullptr );
748
- compute_pass_encoder.DispatchWorkgroups (x, y, z);
761
+
762
+ if (indirect_dispatch_tensor != nullptr ) {
763
+ // Use indirect dispatch
764
+ WGPUBuffer indirect_buffer = reinterpret_cast <WGPUBuffer>(const_cast <void *>(indirect_dispatch_tensor->DataRaw ()));
765
+ compute_pass_encoder.DispatchWorkgroupsIndirect (indirect_buffer, 0 );
766
+ } else {
767
+ // Use direct dispatch
768
+ compute_pass_encoder.DispatchWorkgroups (x, y, z);
769
+ }
749
770
750
771
wgpuBindGroupRelease (bind_group);
751
772
wgpuBindGroupLayoutRelease (bind_group_layout);
@@ -781,7 +802,15 @@ void WebGpuContext::Replay(const std::vector<webgpu::CapturedCommandInfo>& captu
781
802
WriteTimestamp (num_pending_dispatches_ * 2 );
782
803
compute_pass_encoder.SetPipeline (command.compute_pipeline );
783
804
wgpuComputePassEncoderSetBindGroup (compute_pass_encoder.Get (), 0 , command.bind_group , 0 , nullptr );
784
- compute_pass_encoder.DispatchWorkgroups (command.dispatch_group [0 ], command.dispatch_group [1 ], command.dispatch_group [2 ]);
805
+
806
+ if (command.indirect_buffer != nullptr ) {
807
+ // Use indirect dispatch
808
+ compute_pass_encoder.DispatchWorkgroupsIndirect (command.indirect_buffer , 0 );
809
+ } else {
810
+ // Use direct dispatch
811
+ compute_pass_encoder.DispatchWorkgroups (command.dispatch_group [0 ], command.dispatch_group [1 ], command.dispatch_group [2 ]);
812
+ }
813
+
785
814
WriteTimestamp (num_pending_dispatches_ * 2 + 1 );
786
815
++num_pending_dispatches_;
787
816
if (num_pending_dispatches_ >= max_num_pending_dispatches_ ||
0 commit comments