@@ -139,6 +139,8 @@ struct webgpu_context_struct {
139139
140140 // Parameter buffers associated with the staged command buffers
141141 std::vector<webgpu_param_bufs> staged_param_bufs;
142+
143+ std::vector<wgpu::FutureWaitInfo> callback_futures;
142144};
143145
144146typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
@@ -221,16 +223,14 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
221223
222224/* * WebGPU Actions */
223225
226+ // Wait for the queue to finish processing all submitted work
224227static void ggml_backend_webgpu_wait_on_submission (webgpu_context & ctx) {
225- // Wait for the queue to finish processing all commands
226- ctx->instance .WaitAny (ctx->queue .OnSubmittedWorkDone (
227- wgpu::CallbackMode::AllowSpontaneous,
228- [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
229- if (status != wgpu::QueueWorkDoneStatus::Success) {
230- GGML_LOG_ERROR (" ggml_webgpu: Failed to wait on queue: %s\n " , message.data );
231- }
232- }),
233- UINT64_MAX);
228+ std::lock_guard<std::recursive_mutex> lock (ctx->mutex );
229+ if (ctx->callback_futures .empty ()) {
230+ return ;
231+ }
232+ ctx->instance .WaitAny (ctx->callback_futures .size (), ctx->callback_futures .data (), UINT64_MAX);
233+ ctx->callback_futures .clear ();
234234}
235235
236236static void ggml_backend_webgpu_submit_queue (webgpu_context & ctx) {
@@ -243,7 +243,7 @@ static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
243243 ctx->staged_command_bufs .clear ();
244244 std::vector<webgpu_param_bufs> staged_param_bufs = std::move (ctx->staged_param_bufs );
245245 // Free the staged parameter buffers once the submission completes
246- ctx->queue .OnSubmittedWorkDone (
246+ wgpu::Future f = ctx->queue .OnSubmittedWorkDone (
247247 wgpu::CallbackMode::AllowSpontaneous,
248248 [ctx, staged_param_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
249249 if (status != wgpu::QueueWorkDoneStatus::Success) {
@@ -252,6 +252,7 @@ static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
252252 // Free the staged parameter buffers
253253 ctx->param_buf_pool .free_bufs (staged_param_bufs);
254254 });
255+ ctx->callback_futures .push_back ({ f });
255256}
256257
257258static void ggml_backend_webgpu_map_buffer (webgpu_context & ctx,
@@ -311,14 +312,16 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
311312 if (submit_imm) {
312313 // Submit immediately
313314 ctx->queue .Submit (1 , &commands);
314- ctx->queue .OnSubmittedWorkDone (wgpu::CallbackMode::AllowSpontaneous,
315+ wgpu::Future f = ctx->queue .OnSubmittedWorkDone (wgpu::CallbackMode::AllowSpontaneous,
315316 [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
316317 if (status != wgpu::QueueWorkDoneStatus::Success) {
317318 GGML_LOG_ERROR (" ggml_webgpu: Failed to submit commands: %s\n " ,
318319 message.data );
319320 }
320321 ctx->param_buf_pool .free_bufs ({ params_bufs });
321322 });
323+ std::lock_guard<std::recursive_mutex> lock (ctx->mutex );
324+ ctx->callback_futures .push_back ({ f });
322325 } else {
323326 // Lock the context mutex when pushing to the staging vectors.
324327 std::lock_guard<std::recursive_mutex> lock (ctx->mutex );
0 commit comments