Skip to content

Commit c773e2f

Browse files
committed
Try waiting on all futures explicitly
1 parent 69965a8 commit c773e2f

File tree

1 file changed

+14
-11
lines changed

1 file changed

+14
-11
lines changed

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@ struct webgpu_context_struct {
139139

140140
// Parameter buffers associated with the staged command buffers
141141
std::vector<webgpu_param_bufs> staged_param_bufs;
142+
143+
std::vector<wgpu::FutureWaitInfo> callback_futures;
142144
};
143145

144146
typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
@@ -221,16 +223,14 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
221223

222224
/** WebGPU Actions */
223225

226+
// Wait for the queue to finish processing all submitted work
224227
static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) {
225-
// Wait for the queue to finish processing all commands
226-
ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(
227-
wgpu::CallbackMode::AllowSpontaneous,
228-
[](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
229-
if (status != wgpu::QueueWorkDoneStatus::Success) {
230-
GGML_LOG_ERROR("ggml_webgpu: Failed to wait on queue: %s\n", message.data);
231-
}
232-
}),
233-
UINT64_MAX);
228+
std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
229+
if (ctx->callback_futures.empty()) {
230+
return;
231+
}
232+
ctx->instance.WaitAny(ctx->callback_futures.size(), ctx->callback_futures.data(), UINT64_MAX);
233+
ctx->callback_futures.clear();
234234
}
235235

236236
static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
@@ -243,7 +243,7 @@ static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
243243
ctx->staged_command_bufs.clear();
244244
std::vector<webgpu_param_bufs> staged_param_bufs = std::move(ctx->staged_param_bufs);
245245
// Free the staged parameter buffers once the submission completes
246-
ctx->queue.OnSubmittedWorkDone(
246+
wgpu::Future f = ctx->queue.OnSubmittedWorkDone(
247247
wgpu::CallbackMode::AllowSpontaneous,
248248
[ctx, staged_param_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
249249
if (status != wgpu::QueueWorkDoneStatus::Success) {
@@ -252,6 +252,7 @@ static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
252252
// Free the staged parameter buffers
253253
ctx->param_buf_pool.free_bufs(staged_param_bufs);
254254
});
255+
ctx->callback_futures.push_back({ f });
255256
}
256257

257258
static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
@@ -311,14 +312,16 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
311312
if (submit_imm) {
312313
// Submit immediately
313314
ctx->queue.Submit(1, &commands);
314-
ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous,
315+
wgpu::Future f = ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous,
315316
[ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
316317
if (status != wgpu::QueueWorkDoneStatus::Success) {
317318
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
318319
message.data);
319320
}
320321
ctx->param_buf_pool.free_bufs({ params_bufs });
321322
});
323+
std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
324+
ctx->callback_futures.push_back({ f });
322325
} else {
323326
// Lock the context mutex when pushing to the staging vectors.
324327
std::lock_guard<std::recursive_mutex> lock(ctx->mutex);

0 commit comments

Comments
 (0)