@@ -409,31 +409,19 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) {
409409
410410 WriteTimestamp (num_pending_dispatches_ * 2 );
411411
412- uint32_t entry_index = 0 ;
413- std::vector<wgpu::BindGroupEntry> bind_group_entries ;
412+ std::vector<WGPUBuffer> bind_buffers ;
413+ bind_buffers. reserve (inputs. size () + outputs. size () + (uniform_buffer ? 1 : 0 )) ;
414414 for (const auto & input : inputs) {
415- bind_group_entries .push_back ({ nullptr , entry_index++, reinterpret_cast <WGPUBuffer>(const_cast <void *>(input.tensor ->DataRaw ()))} );
415+ bind_buffers .push_back (reinterpret_cast <WGPUBuffer>(const_cast <void *>(input.tensor ->DataRaw ())));
416416 }
417417 for (const auto & output : outputs) {
418- bind_group_entries .push_back ({ nullptr , entry_index++, reinterpret_cast <WGPUBuffer>(output.tensor ->MutableDataRaw ())} );
418+ bind_buffers .push_back (reinterpret_cast <WGPUBuffer>(output.tensor ->MutableDataRaw ()));
419419 }
420420 if (uniform_buffer) {
421- bind_group_entries .push_back ({ nullptr , entry_index++, uniform_buffer} );
421+ bind_buffers .push_back (uniform_buffer);
422422 }
423423
424- wgpu::BindGroupDescriptor bind_group_desc{};
425- bind_group_desc.layout = program_artifact->compute_pipeline .GetBindGroupLayout (0 );
426- bind_group_desc.entryCount = bind_group_entries.size ();
427- bind_group_desc.entries = bind_group_entries.data ();
428- bind_group_desc.label = program_artifact->name .c_str ();
429-
430- auto bind_group = Device ().CreateBindGroup (&bind_group_desc);
431-
432- // TODO support graph capture
433-
434- compute_pass_encoder.SetPipeline (program_artifact->compute_pipeline );
435- compute_pass_encoder.SetBindGroup (0 , bind_group);
436- compute_pass_encoder.DispatchWorkgroups (x, y, z);
424+ LaunchComputePipeline (compute_pass_encoder, bind_buffers, *program_artifact, x, y, z);
437425
438426 if (uniform_buffer) {
439427 buffer_mgr_->Release (uniform_buffer);
@@ -708,6 +696,35 @@ void WebGpuContext::OnRunEnd() {
708696#endif // ENABLE_PIX_FOR_WEBGPU_EP
709697}
710698
699+ void WebGpuContext::LaunchComputePipeline (const wgpu::ComputePassEncoder& compute_pass_encoder,
700+ const std::vector<WGPUBuffer>& bind_buffers,
701+ const ProgramArtifact& program_artifact,
702+ uint32_t x, uint32_t y, uint32_t z) {
703+ uint32_t entry_index = 0 ;
704+ std::vector<WGPUBindGroupEntry> bind_group_entries;
705+ for (WGPUBuffer buffer : bind_buffers) {
706+ bind_group_entries.push_back ({nullptr , entry_index++, buffer, 0 , WGPU_WHOLE_SIZE, nullptr , nullptr });
707+ }
708+
709+ WGPUBindGroupLayout bind_group_layout = program_artifact.compute_pipeline .GetBindGroupLayout (0 ).MoveToCHandle ();
710+ WGPUBindGroupDescriptor bind_group_desc{};
711+ bind_group_desc.layout = bind_group_layout;
712+ bind_group_desc.entryCount = bind_group_entries.size ();
713+ bind_group_desc.entries = bind_group_entries.data ();
714+ bind_group_desc.label = {program_artifact.name .data (), program_artifact.name .length ()};
715+
716+ auto bind_group = wgpuDeviceCreateBindGroup (Device ().Get (), &bind_group_desc);
717+
718+ // TODO support graph capture
719+
720+ compute_pass_encoder.SetPipeline (program_artifact.compute_pipeline );
721+ wgpuComputePassEncoderSetBindGroup (compute_pass_encoder.Get (), 0 , bind_group, 0 , nullptr );
722+ compute_pass_encoder.DispatchWorkgroups (x, y, z);
723+
724+ wgpuBindGroupRelease (bind_group);
725+ wgpuBindGroupLayoutRelease (bind_group_layout);
726+ }
727+
711728std::unordered_map<int32_t , WebGpuContextFactory::WebGpuContextInfo> WebGpuContextFactory::contexts_;
712729std::mutex WebGpuContextFactory::mutex_;
713730std::once_flag WebGpuContextFactory::init_default_flag_;
0 commit comments