
Commit 9515c61

ggml: WebGPU disable SET_ROWS for now (#15078)
* Add parameter buffer pool, batching of submissions, refactor command building/submission
* Add header for Linux builds
* Free staged parameter buffers at once
* Format with clang-format
* Fix thread-safe implementation
* Use device implicit synchronization
* Update workflow to use custom release
* Remove testing branch workflow
* Disable set_rows until it's implemented
* Fix potential issue around empty queue submission
* Try synchronous submission
* Try waiting on all futures explicitly
* Add debug
* Add more debug messages
* Work on getting ssh access for debugging
* Debug on failure
* Disable other tests
* Remove extra if
* Try more locking
* maybe passes?
* test
* Some cleanups
* Restore build file
* Remove extra testing branch CI
1 parent fd1234c commit 9515c61

File tree

2 files changed: +48 -26 lines changed

.github/workflows/build.yml

Lines changed: 2 additions & 0 deletions
@@ -179,6 +179,7 @@ jobs:
       - name: Test
         id: cmake_test
         run: |
+          export LLAMA_SET_ROWS=0
           cd build
           ctest -L main --verbose --timeout 900

@@ -437,6 +438,7 @@ jobs:
       - name: Test
         id: cmake_test
         run: |
+          export LLAMA_SET_ROWS=0
           cd build
           # This is using llvmpipe and runs slower than other backends
           ctest -L main --verbose --timeout 3600
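Both CI test steps now export LLAMA_SET_ROWS=0 before running ctest, steering llama.cpp away from the SET_ROWS path that this commit disables in the WebGPU backend. As a rough sketch of how such an environment toggle can be consumed (illustrative only; the helper name and default are assumptions, not the actual llama.cpp code):

#include <cstdlib>
#include <cstring>

// Hypothetical helper: SET_ROWS stays enabled unless LLAMA_SET_ROWS=0 is
// set in the environment (name and default are assumptions for illustration).
static bool set_rows_enabled() {
    const char * v = std::getenv("LLAMA_SET_ROWS");
    return v == nullptr || std::strcmp(v, "0") != 0;
}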

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 46 additions & 26 deletions
@@ -118,8 +118,6 @@ struct webgpu_context_struct {
     wgpu::Limits limits;

     std::recursive_mutex mutex;
-    std::mutex get_tensor_mutex;
-    std::mutex init_mutex;

     bool device_init = false;

@@ -139,6 +137,8 @@ struct webgpu_context_struct {

     // Parameter buffers associated with the staged command buffers
     std::vector<webgpu_param_bufs> staged_param_bufs;
+
+    std::vector<wgpu::FutureWaitInfo> callback_futures;
 };

 typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
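The struct changes above fold the two special-purpose mutexes into the single std::recursive_mutex and add callback_futures to track pending completion callbacks. The recursive variant matters because the backend's locking functions now call one another while holding the same lock; a standalone sketch of the property being relied on (not backend code):

#include <mutex>

std::recursive_mutex ctx_mutex;  // stands in for webgpu_context_struct::mutex

void submit() {
    // A recursive_mutex may be re-locked by the thread that already owns it;
    // a plain std::mutex would deadlock here.
    std::lock_guard<std::recursive_mutex> lock(ctx_mutex);
    // ... submit staged command buffers ...
}

void wait_on_submission() {
    std::lock_guard<std::recursive_mutex> lock(ctx_mutex);
    submit();  // nested acquisition while the outer guard is still held
}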
@@ -221,25 +221,39 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,

 /** WebGPU Actions */

+// Wait for the queue to finish processing all submitted work
 static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) {
-    // Wait for the queue to finish processing all commands
-    ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(
-        wgpu::CallbackMode::AllowSpontaneous,
-        [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
-            if (status != wgpu::QueueWorkDoneStatus::Success) {
-                GGML_LOG_ERROR("ggml_webgpu: Failed to wait on queue: %s\n", message.data);
-            }
-        }),
-        UINT64_MAX);
+    std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
+    if (ctx->callback_futures.empty()) {
+        // no existing callbacks, wait on queue submission
+        ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(
+                                  wgpu::CallbackMode::AllowSpontaneous,
+                                  [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
+                                      if (status != wgpu::QueueWorkDoneStatus::Success) {
+                                          GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data);
+                                      }
+                                  }),
+                              UINT64_MAX);
+    } else {
+        // existing callbacks, wait on them
+        ctx->instance.WaitAny(ctx->callback_futures.size(), ctx->callback_futures.data(), UINT64_MAX);
+        ctx->callback_futures.clear();
+    }
 }

 static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
     std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
+    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_submit_queue()");
+    if (ctx->staged_command_bufs.empty()) {
+        // Nothing to submit
+        return;
+    }
     ctx->queue.Submit(ctx->staged_command_bufs.size(), ctx->staged_command_bufs.data());
     ctx->staged_command_bufs.clear();
     std::vector<webgpu_param_bufs> staged_param_bufs = std::move(ctx->staged_param_bufs);
+
     // Free the staged parameter buffers once the submission completes
-    ctx->queue.OnSubmittedWorkDone(
+    wgpu::Future f = ctx->queue.OnSubmittedWorkDone(
         wgpu::CallbackMode::AllowSpontaneous,
         [ctx, staged_param_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
             if (status != wgpu::QueueWorkDoneStatus::Success) {
@@ -248,6 +262,7 @@ static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
             // Free the staged parameter buffers
             ctx->param_buf_pool.free_bufs(staged_param_bufs);
         });
+    ctx->callback_futures.push_back({ f });
 }

 static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
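Taken together, the two hunks above change submit_queue from fire-and-forget to tracked completion: each submission's OnSubmittedWorkDone future is stored in callback_futures, and wait_on_submission drains them with WaitAny. The core pattern, reduced to a self-contained sketch against Dawn's webgpu_cpp.h API (error handling elided):

#include <cstdint>
#include <vector>
#include <webgpu/webgpu_cpp.h>

// Sketch: batch submissions, collect completion futures, wait on them later.
std::vector<wgpu::FutureWaitInfo> futures;

void submit_batch(wgpu::Queue & queue, std::vector<wgpu::CommandBuffer> & staged) {
    queue.Submit(staged.size(), staged.data());
    staged.clear();
    wgpu::Future f = queue.OnSubmittedWorkDone(
        wgpu::CallbackMode::AllowSpontaneous,
        [](wgpu::QueueWorkDoneStatus, wgpu::StringView) { /* free per-submit resources */ });
    futures.push_back({ f });  // aggregate-initializes a wgpu::FutureWaitInfo
}

void drain(wgpu::Instance & instance) {
    if (!futures.empty()) {
        instance.WaitAny(futures.size(), futures.data(), UINT64_MAX);
        futures.clear();
    }
}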
@@ -273,7 +288,7 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
                                                   std::vector<uint32_t> params,
                                                   std::vector<wgpu::BindGroupEntry> bind_group_entries,
                                                   uint32_t wg_x,
-                                                  bool submit_imm = false) {
+                                                  bool submit_and_wait = false) {
     webgpu_param_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();

     ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
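alloc_bufs() draws from the parameter-buffer pool this commit introduces. The pool's implementation is outside this diff, so the following is only a guessed shape (the type and method names come from the code above; the internals are assumptions for illustration): a thread-safe free list of pre-allocated host/device buffer pairs.

#include <condition_variable>
#include <mutex>
#include <vector>
#include <webgpu/webgpu_cpp.h>

// Hypothetical pool shape, for illustration only.
struct webgpu_param_bufs {
    wgpu::Buffer host_buf;  // mappable upload buffer for parameters
    wgpu::Buffer dev_buf;   // device-local buffer bound to the shader
};

struct webgpu_buf_pool {
    std::mutex                     mutex;
    std::condition_variable        cv;
    std::vector<webgpu_param_bufs> free_list;

    webgpu_param_bufs alloc_bufs() {
        std::unique_lock<std::mutex> lock(mutex);
        cv.wait(lock, [&] { return !free_list.empty(); });  // block until a pair is free
        webgpu_param_bufs bufs = free_list.back();
        free_list.pop_back();
        return bufs;
    }

    void free_bufs(const std::vector<webgpu_param_bufs> & bufs) {
        std::lock_guard<std::mutex> lock(mutex);
        free_list.insert(free_list.end(), bufs.begin(), bufs.end());
        cv.notify_all();
    }
};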
@@ -304,17 +319,18 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
     pass.DispatchWorkgroups(wg_x, 1, 1);
     pass.End();
     wgpu::CommandBuffer commands = encoder.Finish();
-    if (submit_imm) {
-        // Submit immediately
+    if (submit_and_wait) {
+        // Submit and wait immediately
         ctx->queue.Submit(1, &commands);
-        ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous,
-                                       [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
-                                           if (status != wgpu::QueueWorkDoneStatus::Success) {
-                                               GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
-                                                              message.data);
-                                           }
-                                           ctx->param_buf_pool.free_bufs({ params_bufs });
-                                       });
+        ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(
+                                  wgpu::CallbackMode::AllowSpontaneous,
+                                  [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
+                                      if (status != wgpu::QueueWorkDoneStatus::Success) {
+                                          GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data);
+                                      }
+                                      ctx->param_buf_pool.free_bufs({ params_bufs });
+                                  }),
+                              UINT64_MAX);
     } else {
         // Lock the context mutex when pushing to the staging vectors.
         std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
@@ -579,6 +595,9 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
         // memset the remaining bytes
         ggml_backend_webgpu_buffer_memset(
             webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size);
+    } else {
+        // wait for WriteBuffer to complete
+        ggml_backend_webgpu_wait_on_submission(webgpu_ctx);
     }
 }
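The new else branch makes buffer_set_tensor block until the queued WriteBuffer has completed, so the call is synchronous whether or not a trailing memset (which performs its own submit-and-wait) runs. Schematically (a reconstruction under assumptions, not the backend's exact code; WebGPU requires WriteBuffer sizes to be 4-byte aligned, which is why the tail bytes take a side path):

#include <cstddef>
#include <cstdint>
#include <webgpu/webgpu_cpp.h>

void set_tensor_sketch(wgpu::Queue & queue, wgpu::Buffer & buf,
                       uint64_t total_offset, const void * data, size_t size) {
    size_t remaining_size = size % 4;
    // WriteBuffer sizes must be multiples of 4
    queue.WriteBuffer(buf, total_offset, data, size - remaining_size);
    if (remaining_size > 0) {
        // the final 1-3 bytes are widened to a u32 and written via the
        // memset shader, which submits and waits internally
    } else {
        // nothing else flushes the queue, so wait here (the added branch):
        // ggml_backend_webgpu_wait_on_submission(webgpu_ctx);
    }
}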

@@ -602,7 +621,7 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
         final_size = size + (4 - (size % 4));
     }

-    std::lock_guard<std::mutex> lock(webgpu_ctx->get_tensor_mutex);
+    std::lock_guard<std::recursive_mutex> lock(webgpu_ctx->mutex);

     if (webgpu_ctx->get_tensor_staging_buf == nullptr || webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) {
         // Create a new staging buffer if it doesn't exist or is too small
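buffer_get_tensor now takes the shared recursive mutex while it manages the lazily grown staging buffer. The staging buffer exists because WebGPU storage buffers aren't host-mappable; readback follows the standard copy-to-MapRead pattern, sketched here under assumed parameter names (not the literal backend code):

#include <cstdint>
#include <webgpu/webgpu_cpp.h>

void readback_sketch(wgpu::Device & device, wgpu::Queue & queue,
                     wgpu::Buffer & src_buf, uint64_t src_offset, uint64_t final_size) {
    wgpu::BufferDescriptor desc = {};
    desc.size  = final_size;
    desc.usage = wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead;
    wgpu::Buffer staging = device.CreateBuffer(&desc);

    // copy the region into the mappable staging buffer and submit
    wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
    encoder.CopyBufferToBuffer(src_buf, src_offset, staging, 0, final_size);
    wgpu::CommandBuffer commands = encoder.Finish();
    queue.Submit(1, &commands);

    // then: staging.MapAsync(...) + instance.WaitAny(...), memcpy out of
    // staging.GetConstMappedRange(), and staging.Unmap()
}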
@@ -768,10 +787,11 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co
     webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx;

     // Multiple threads may try to initialize the device
-    std::lock_guard<std::mutex> lock(webgpu_ctx->init_mutex);
+    std::lock_guard<std::recursive_mutex> lock(webgpu_ctx->mutex);
     if (!webgpu_ctx->device_init) {
         // Initialize device
-        std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16, wgpu::FeatureName::ImplicitDeviceSynchronization };
+        std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16,
+                                                             wgpu::FeatureName::ImplicitDeviceSynchronization };
         wgpu::DeviceDescriptor dev_desc;
         dev_desc.requiredLimits = &webgpu_ctx->limits;
         dev_desc.requiredFeatures = required_features.data();
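ShaderF16 is a standard optional WebGPU feature, while ImplicitDeviceSynchronization is a Dawn extension that serializes API access internally so a single wgpu::Device can be shared across threads without external locking, which is what lets this backend drop its extra mutexes. The backend requests both unconditionally; a more defensive variant might probe the adapter first (a sketch, not the actual code):

#include <vector>
#include <webgpu/webgpu_cpp.h>

void fill_required_features(wgpu::Adapter & adapter,
                            std::vector<wgpu::FeatureName> & required_features) {
    // request each feature only if the adapter reports support for it
    for (wgpu::FeatureName f : { wgpu::FeatureName::ShaderF16,
                                 wgpu::FeatureName::ImplicitDeviceSynchronization }) {
        if (adapter.HasFeature(f)) {
            required_features.push_back(f);
        }
    }
}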
