@@ -118,8 +118,6 @@ struct webgpu_context_struct {
118
118
wgpu::Limits limits;
119
119
120
120
std::recursive_mutex mutex;
121
- std::mutex get_tensor_mutex;
122
- std::mutex init_mutex;
123
121
124
122
bool device_init = false ;
125
123
@@ -139,6 +137,8 @@ struct webgpu_context_struct {
139
137
140
138
// Parameter buffers associated with the staged command buffers
141
139
std::vector<webgpu_param_bufs> staged_param_bufs;
140
+
141
+ std::vector<wgpu::FutureWaitInfo> callback_futures;
142
142
};
143
143
144
144
typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
@@ -221,25 +221,39 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
221
221
222
222
/* * WebGPU Actions */
223
223
224
+ // Wait for the queue to finish processing all submitted work
224
225
static void ggml_backend_webgpu_wait_on_submission (webgpu_context & ctx) {
225
- // Wait for the queue to finish processing all commands
226
- ctx->instance .WaitAny (ctx->queue .OnSubmittedWorkDone (
227
- wgpu::CallbackMode::AllowSpontaneous,
228
- [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
229
- if (status != wgpu::QueueWorkDoneStatus::Success) {
230
- GGML_LOG_ERROR (" ggml_webgpu: Failed to wait on queue: %s\n " , message.data );
231
- }
232
- }),
233
- UINT64_MAX);
226
+ std::lock_guard<std::recursive_mutex> lock (ctx->mutex );
227
+ if (ctx->callback_futures .empty ()) {
228
+ // no existing callbacks, wait on queue submission
229
+ ctx->instance .WaitAny (ctx->queue .OnSubmittedWorkDone (
230
+ wgpu::CallbackMode::AllowSpontaneous,
231
+ [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
232
+ if (status != wgpu::QueueWorkDoneStatus::Success) {
233
+ GGML_LOG_ERROR (" ggml_webgpu: Failed to submit commands: %s\n " , message.data );
234
+ }
235
+ }),
236
+ UINT64_MAX);
237
+ } else {
238
+ // existing callbacks, wait on them
239
+ ctx->instance .WaitAny (ctx->callback_futures .size (), ctx->callback_futures .data (), UINT64_MAX);
240
+ ctx->callback_futures .clear ();
241
+ }
234
242
}
235
243
236
244
static void ggml_backend_webgpu_submit_queue (webgpu_context & ctx) {
237
245
std::lock_guard<std::recursive_mutex> lock (ctx->mutex );
246
+ WEBGPU_LOG_DEBUG (" ggml_backend_webgpu_submit_queue()" );
247
+ if (ctx->staged_command_bufs .empty ()) {
248
+ // Nothing to submit
249
+ return ;
250
+ }
238
251
ctx->queue .Submit (ctx->staged_command_bufs .size (), ctx->staged_command_bufs .data ());
239
252
ctx->staged_command_bufs .clear ();
240
253
std::vector<webgpu_param_bufs> staged_param_bufs = std::move (ctx->staged_param_bufs );
254
+
241
255
// Free the staged parameter buffers once the submission completes
242
- ctx->queue .OnSubmittedWorkDone (
256
+ wgpu::Future f = ctx->queue .OnSubmittedWorkDone (
243
257
wgpu::CallbackMode::AllowSpontaneous,
244
258
[ctx, staged_param_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
245
259
if (status != wgpu::QueueWorkDoneStatus::Success) {
@@ -248,6 +262,7 @@ static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
248
262
// Free the staged parameter buffers
249
263
ctx->param_buf_pool .free_bufs (staged_param_bufs);
250
264
});
265
+ ctx->callback_futures .push_back ({ f });
251
266
}
252
267
253
268
static void ggml_backend_webgpu_map_buffer (webgpu_context & ctx,
@@ -273,7 +288,7 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
273
288
std::vector<uint32_t > params,
274
289
std::vector<wgpu::BindGroupEntry> bind_group_entries,
275
290
uint32_t wg_x,
276
- bool submit_imm = false ) {
291
+ bool submit_and_wait = false ) {
277
292
webgpu_param_bufs params_bufs = ctx->param_buf_pool .alloc_bufs ();
278
293
279
294
ggml_backend_webgpu_map_buffer (ctx, params_bufs.host_buf , wgpu::MapMode::Write, 0 , params_bufs.host_buf .GetSize ());
@@ -304,17 +319,18 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
304
319
pass.DispatchWorkgroups (wg_x, 1 , 1 );
305
320
pass.End ();
306
321
wgpu::CommandBuffer commands = encoder.Finish ();
307
- if (submit_imm ) {
308
- // Submit immediately
322
+ if (submit_and_wait ) {
323
+ // Submit and wait immediately
309
324
ctx->queue .Submit (1 , &commands);
310
- ctx->queue .OnSubmittedWorkDone (wgpu::CallbackMode::AllowSpontaneous,
311
- [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
312
- if (status != wgpu::QueueWorkDoneStatus::Success) {
313
- GGML_LOG_ERROR (" ggml_webgpu: Failed to submit commands: %s\n " ,
314
- message.data );
315
- }
316
- ctx->param_buf_pool .free_bufs ({ params_bufs });
317
- });
325
+ ctx->instance .WaitAny (ctx->queue .OnSubmittedWorkDone (
326
+ wgpu::CallbackMode::AllowSpontaneous,
327
+ [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
328
+ if (status != wgpu::QueueWorkDoneStatus::Success) {
329
+ GGML_LOG_ERROR (" ggml_webgpu: Failed to submit commands: %s\n " , message.data );
330
+ }
331
+ ctx->param_buf_pool .free_bufs ({ params_bufs });
332
+ }),
333
+ UINT64_MAX);
318
334
} else {
319
335
// Lock the context mutex when pushing to the staging vectors.
320
336
std::lock_guard<std::recursive_mutex> lock (ctx->mutex );
@@ -579,6 +595,9 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
579
595
// memset the remaining bytes
580
596
ggml_backend_webgpu_buffer_memset (
581
597
webgpu_ctx, buf_ctx->buffer , val32, total_offset + (size - remaining_size), remaining_size);
598
+ } else {
599
+ // wait for WriteBuffer to complete
600
+ ggml_backend_webgpu_wait_on_submission (webgpu_ctx);
582
601
}
583
602
}
584
603
@@ -602,7 +621,7 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
602
621
final_size = size + (4 - (size % 4 ));
603
622
}
604
623
605
- std::lock_guard<std::mutex > lock (webgpu_ctx->get_tensor_mutex );
624
+ std::lock_guard<std::recursive_mutex > lock (webgpu_ctx->mutex );
606
625
607
626
if (webgpu_ctx->get_tensor_staging_buf == nullptr || webgpu_ctx->get_tensor_staging_buf .GetSize () < final_size) {
608
627
// Create a new staging buffer if it doesn't exist or is too small
@@ -768,10 +787,11 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co
768
787
webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx ;
769
788
770
789
// Multiple threads may try to initialize the device
771
- std::lock_guard<std::mutex > lock (webgpu_ctx->init_mutex );
790
+ std::lock_guard<std::recursive_mutex > lock (webgpu_ctx->mutex );
772
791
if (!webgpu_ctx->device_init ) {
773
792
// Initialize device
774
- std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16, wgpu::FeatureName::ImplicitDeviceSynchronization };
793
+ std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16,
794
+ wgpu::FeatureName::ImplicitDeviceSynchronization };
775
795
wgpu::DeviceDescriptor dev_desc;
776
796
dev_desc.requiredLimits = &webgpu_ctx->limits ;
777
797
dev_desc.requiredFeatures = required_features.data ();
0 commit comments