Skip to content

Commit f6b405c

Browse files
[webgpu] Add dispatchWorkgroupsIndirect support (#25934)
### Description This PR adds the dispatchWorkgroupsIndirect capability for the program. It's part of the work to enable graph capture in phi4 #25868 --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent fd35afb commit f6b405c

File tree

5 files changed

+51
-8
lines changed

5 files changed

+51
-8
lines changed

onnxruntime/core/providers/webgpu/allocator.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ void* GpuBufferAllocator::Alloc(size_t size) {
2727
stats_.num_allocs++;
2828

2929
wgpu::BufferUsage usage = mapped_at_creation_ ? wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapWrite
30-
: wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst;
30+
: wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Indirect;
3131

3232
return buffer_manager_.Create(size, usage);
3333
}

onnxruntime/core/providers/webgpu/program.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ ProgramBase::ProgramBase(std::string_view name, ProgramMetadata&& metadata)
319319
dispatch_group_size_x_{0},
320320
dispatch_group_size_y_{0},
321321
dispatch_group_size_z_{0},
322+
indirect_dispatch_tensor_{nullptr},
322323
workgroup_size_x_{0},
323324
workgroup_size_y_{0},
324325
workgroup_size_z_{0} {
@@ -359,6 +360,11 @@ ProgramBase& ProgramBase::SetDispatchGroupSize(uint32_t x, uint32_t y, uint32_t
359360
return *this;
360361
}
361362

363+
ProgramBase& ProgramBase::SetIndirectDispatchTensor(const Tensor* indirect_dispatch_tensor) {
364+
indirect_dispatch_tensor_ = indirect_dispatch_tensor;
365+
return *this;
366+
}
367+
362368
ProgramBase& ProgramBase::SetWorkgroupSize(uint32_t x) {
363369
return SetWorkgroupSize(x, 1, 1);
364370
}

onnxruntime/core/providers/webgpu/program.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,9 @@ class ProgramBase {
305305
// set the size of dispatch groups.
306306
ProgramBase& SetDispatchGroupSize(uint32_t x, uint32_t y, uint32_t z);
307307

308+
// set indirect dispatch tensor for indirect dispatch
309+
ProgramBase& SetIndirectDispatchTensor(const Tensor* indirect_dispatch_tensor);
310+
308311
// set the size of a workgroup grid. Y and Z are 1 if not specified.
309312
ProgramBase& SetWorkgroupSize(uint32_t x);
310313
// set the size of a workgroup grid. Z is 1 if not specified.
@@ -348,6 +351,7 @@ class ProgramBase {
348351
inline uint32_t DispatchGroupSizeX() const { return dispatch_group_size_x_; }
349352
inline uint32_t DispatchGroupSizeY() const { return dispatch_group_size_y_; }
350353
inline uint32_t DispatchGroupSizeZ() const { return dispatch_group_size_z_; }
354+
inline const Tensor* IndirectDispatchTensor() const { return indirect_dispatch_tensor_; }
351355
inline uint32_t WorkgroupSizeX() const { return workgroup_size_x_; }
352356
inline uint32_t WorkgroupSizeY() const { return workgroup_size_y_; }
353357
inline uint32_t WorkgroupSizeZ() const { return workgroup_size_z_; }
@@ -374,6 +378,8 @@ class ProgramBase {
374378
uint32_t dispatch_group_size_y_;
375379
uint32_t dispatch_group_size_z_;
376380

381+
const Tensor* indirect_dispatch_tensor_;
382+
377383
uint32_t workgroup_size_x_;
378384
uint32_t workgroup_size_y_;
379385
uint32_t workgroup_size_z_;

onnxruntime/core/providers/webgpu/webgpu_context.cc

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,14 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) {
255255
uint32_t x = program.DispatchGroupSizeX();
256256
uint32_t y = program.DispatchGroupSizeY();
257257
uint32_t z = program.DispatchGroupSizeZ();
258-
ORT_RETURN_IF_ERROR(program_mgr_->NormalizeDispatchGroupSize(x, y, z));
258+
259+
// Skip normalization for indirect dispatch since dimensions are determined by the indirect buffer
260+
if (program.IndirectDispatchTensor() == nullptr) {
261+
ORT_RETURN_IF_ERROR(program_mgr_->NormalizeDispatchGroupSize(x, y, z));
262+
} else {
263+
ORT_ENFORCE(x == 0 && y == 0 && z == 0,
264+
"Only one of SetIndirectDispatchTensor and SetDispatchGroupSize should be called for program", program.Name());
265+
}
259266

260267
bool is_1d_dispatch = (y == 1 && z == 1);
261268

@@ -442,7 +449,7 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) {
442449
bind_buffers.push_back(uniform_buffer);
443450
}
444451

445-
LaunchComputePipeline(compute_pass_encoder, bind_buffers, *program_artifact, x, y, z);
452+
LaunchComputePipeline(compute_pass_encoder, bind_buffers, *program_artifact, x, y, z, program.IndirectDispatchTensor());
446453
if (uniform_buffer) {
447454
buffer_mgr.Release(uniform_buffer);
448455
}
@@ -722,7 +729,8 @@ void WebGpuContext::OnRunEnd() {
722729
void WebGpuContext::LaunchComputePipeline(const wgpu::ComputePassEncoder& compute_pass_encoder,
723730
const std::vector<WGPUBuffer>& bind_buffers,
724731
const ProgramArtifact& program_artifact,
725-
uint32_t x, uint32_t y, uint32_t z) {
732+
uint32_t x, uint32_t y, uint32_t z,
733+
const Tensor* indirect_dispatch_tensor) {
726734
uint32_t entry_index = 0;
727735
std::vector<WGPUBindGroupEntry> bind_group_entries;
728736
for (WGPUBuffer buffer : bind_buffers) {
@@ -738,14 +746,27 @@ void WebGpuContext::LaunchComputePipeline(const wgpu::ComputePassEncoder& comput
738746

739747
auto bind_group = wgpuDeviceCreateBindGroup(Device().Get(), &bind_group_desc);
740748
if (graph_capture_state_ == GraphCaptureState::Capturing) {
749+
WGPUBuffer indirect_buffer = nullptr;
750+
if (indirect_dispatch_tensor != nullptr) {
751+
indirect_buffer = reinterpret_cast<WGPUBuffer>(const_cast<void*>(indirect_dispatch_tensor->DataRaw()));
752+
}
741753
external_captured_commands_->push_back({program_artifact.compute_pipeline,
742754
bind_group,
743755
bind_group_layout,
744-
{x, y, z}});
756+
{x, y, z},
757+
indirect_buffer});
745758
} else {
746759
compute_pass_encoder.SetPipeline(program_artifact.compute_pipeline);
747760
wgpuComputePassEncoderSetBindGroup(compute_pass_encoder.Get(), 0, bind_group, 0, nullptr);
748-
compute_pass_encoder.DispatchWorkgroups(x, y, z);
761+
762+
if (indirect_dispatch_tensor != nullptr) {
763+
// Use indirect dispatch
764+
WGPUBuffer indirect_buffer = reinterpret_cast<WGPUBuffer>(const_cast<void*>(indirect_dispatch_tensor->DataRaw()));
765+
compute_pass_encoder.DispatchWorkgroupsIndirect(indirect_buffer, 0);
766+
} else {
767+
// Use direct dispatch
768+
compute_pass_encoder.DispatchWorkgroups(x, y, z);
769+
}
749770

750771
wgpuBindGroupRelease(bind_group);
751772
wgpuBindGroupLayoutRelease(bind_group_layout);
@@ -781,7 +802,15 @@ void WebGpuContext::Replay(const std::vector<webgpu::CapturedCommandInfo>& captu
781802
WriteTimestamp(num_pending_dispatches_ * 2);
782803
compute_pass_encoder.SetPipeline(command.compute_pipeline);
783804
wgpuComputePassEncoderSetBindGroup(compute_pass_encoder.Get(), 0, command.bind_group, 0, nullptr);
784-
compute_pass_encoder.DispatchWorkgroups(command.dispatch_group[0], command.dispatch_group[1], command.dispatch_group[2]);
805+
806+
if (command.indirect_buffer != nullptr) {
807+
// Use indirect dispatch
808+
compute_pass_encoder.DispatchWorkgroupsIndirect(command.indirect_buffer, 0);
809+
} else {
810+
// Use direct dispatch
811+
compute_pass_encoder.DispatchWorkgroups(command.dispatch_group[0], command.dispatch_group[1], command.dispatch_group[2]);
812+
}
813+
785814
WriteTimestamp(num_pending_dispatches_ * 2 + 1);
786815
++num_pending_dispatches_;
787816
if (num_pending_dispatches_ >= max_num_pending_dispatches_ ||

onnxruntime/core/providers/webgpu/webgpu_context.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ struct CapturedCommandInfo {
3030
WGPUBindGroup bind_group;
3131
WGPUBindGroupLayout bind_group_layout;
3232
std::array<uint32_t, 3> dispatch_group;
33+
WGPUBuffer indirect_buffer; // WGPUBuffer for indirect dispatch, nullptr if not using indirect dispatch
3334
};
3435

3536
struct WebGpuContextConfig {
@@ -182,7 +183,8 @@ class WebGpuContext final {
182183
void LaunchComputePipeline(const wgpu::ComputePassEncoder& compute_pass_encoder,
183184
const std::vector<WGPUBuffer>& bind_buffers,
184185
const ProgramArtifact& program_artifact,
185-
uint32_t x, uint32_t y, uint32_t z);
186+
uint32_t x, uint32_t y, uint32_t z,
187+
const Tensor* indirect_dispatch_tensor = nullptr);
186188

187189
std::vector<const char*> GetEnabledAdapterToggles() const;
188190
std::vector<const char*> GetEnabledDeviceToggles() const;

0 commit comments

Comments
 (0)