2 changes: 2 additions & 0 deletions .github/workflows/build.yml
@@ -179,6 +179,7 @@ jobs:
- name: Test
id: cmake_test
run: |
export LLAMA_SET_ROWS=0
cd build
ctest -L main --verbose --timeout 900

@@ -437,6 +438,7 @@ jobs:
- name: Test
id: cmake_test
run: |
export LLAMA_SET_ROWS=0
cd build
# This is using llvmpipe and runs slower than other backends
ctest -L main --verbose --timeout 3600
72 changes: 46 additions & 26 deletions ggml/src/ggml-webgpu/ggml-webgpu.cpp
@@ -118,8 +118,6 @@ struct webgpu_context_struct {
wgpu::Limits limits;

std::recursive_mutex mutex;
std::mutex get_tensor_mutex;
std::mutex init_mutex;

bool device_init = false;

@@ -139,6 +137,8 @@ struct webgpu_context_struct {

// Parameter buffers associated with the staged command buffers
std::vector<webgpu_param_bufs> staged_param_bufs;

std::vector<wgpu::FutureWaitInfo> callback_futures;
};

typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
@@ -221,25 +221,39 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,

/** WebGPU Actions */

// Wait for the queue to finish processing all submitted work
static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) {
// Wait for the queue to finish processing all commands
ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(
wgpu::CallbackMode::AllowSpontaneous,
[](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
if (status != wgpu::QueueWorkDoneStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to wait on queue: %s\n", message.data);
}
}),
UINT64_MAX);
std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
if (ctx->callback_futures.empty()) {
// no existing callbacks, wait on queue submission
ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(
wgpu::CallbackMode::AllowSpontaneous,
[](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
if (status != wgpu::QueueWorkDoneStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data);
}
}),
UINT64_MAX);
} else {
// existing callbacks, wait on them
ctx->instance.WaitAny(ctx->callback_futures.size(), ctx->callback_futures.data(), UINT64_MAX);
ctx->callback_futures.clear();
}
}

static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_submit_queue()");
if (ctx->staged_command_bufs.empty()) {
// Nothing to submit
return;
}
ctx->queue.Submit(ctx->staged_command_bufs.size(), ctx->staged_command_bufs.data());
ctx->staged_command_bufs.clear();
std::vector<webgpu_param_bufs> staged_param_bufs = std::move(ctx->staged_param_bufs);

// Free the staged parameter buffers once the submission completes
ctx->queue.OnSubmittedWorkDone(
wgpu::Future f = ctx->queue.OnSubmittedWorkDone(
wgpu::CallbackMode::AllowSpontaneous,
[ctx, staged_param_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
if (status != wgpu::QueueWorkDoneStatus::Success) {
@@ -248,6 +262,7 @@ static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
// Free the staged parameter buffers
ctx->param_buf_pool.free_bufs(staged_param_bufs);
});
ctx->callback_futures.push_back({ f });
}

static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
@@ -273,7 +288,7 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
std::vector<uint32_t> params,
std::vector<wgpu::BindGroupEntry> bind_group_entries,
uint32_t wg_x,
bool submit_imm = false) {
bool submit_and_wait = false) {
webgpu_param_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();

ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
@@ -304,17 +319,18 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
pass.DispatchWorkgroups(wg_x, 1, 1);
pass.End();
wgpu::CommandBuffer commands = encoder.Finish();
if (submit_imm) {
// Submit immediately
if (submit_and_wait) {
// Submit and wait immediately
ctx->queue.Submit(1, &commands);
ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous,
[ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
if (status != wgpu::QueueWorkDoneStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
message.data);
}
ctx->param_buf_pool.free_bufs({ params_bufs });
});
ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(
wgpu::CallbackMode::AllowSpontaneous,
[ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
if (status != wgpu::QueueWorkDoneStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data);
}
ctx->param_buf_pool.free_bufs({ params_bufs });
}),
UINT64_MAX);
} else {
// Lock the context mutex when pushing to the staging vectors.
std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
@@ -579,6 +595,9 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
// memset the remaining bytes
ggml_backend_webgpu_buffer_memset(
webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size);
} else {
// wait for WriteBuffer to complete
ggml_backend_webgpu_wait_on_submission(webgpu_ctx);
}
}

@@ -602,7 +621,7 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
final_size = size + (4 - (size % 4));
}

std::lock_guard<std::mutex> lock(webgpu_ctx->get_tensor_mutex);
std::lock_guard<std::recursive_mutex> lock(webgpu_ctx->mutex);

if (webgpu_ctx->get_tensor_staging_buf == nullptr || webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) {
// Create a new staging buffer if it doesn't exist or is too small
@@ -768,10 +787,11 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co
webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx;

// Multiple threads may try to initialize the device
std::lock_guard<std::mutex> lock(webgpu_ctx->init_mutex);
std::lock_guard<std::recursive_mutex> lock(webgpu_ctx->mutex);
if (!webgpu_ctx->device_init) {
// Initialize device
std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16, wgpu::FeatureName::ImplicitDeviceSynchronization };
std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16,
wgpu::FeatureName::ImplicitDeviceSynchronization };
wgpu::DeviceDescriptor dev_desc;
dev_desc.requiredLimits = &webgpu_ctx->limits;
dev_desc.requiredFeatures = required_features.data();
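For readers skimming the diff, the C++ sketch below condenses the submission/wait pattern this change introduces: each queue submission records the `wgpu::Future` returned by `OnSubmittedWorkDone` in `callback_futures`, and a later wait drains those futures (which also runs their cleanup callbacks) instead of waiting on the queue alone. It is a minimal sketch assuming Dawn's `webgpu_cpp.h` wrapper as used in this file; the `demo_*` names and the trimmed-down context struct are illustrative stand-ins for the PR's `webgpu_context_struct`, `ggml_backend_webgpu_submit_queue`, and `ggml_backend_webgpu_wait_on_submission`, not the exact code.

```cpp
#include <webgpu/webgpu_cpp.h>

#include <cstdint>
#include <cstdio>
#include <mutex>
#include <vector>

// Illustrative stand-in for the PR's webgpu_context_struct (only the members
// relevant to the future-tracking pattern are kept).
struct demo_webgpu_context {
    wgpu::Instance                    instance;
    wgpu::Queue                       queue;
    std::recursive_mutex              mutex;
    std::vector<wgpu::FutureWaitInfo> callback_futures;  // pending OnSubmittedWorkDone callbacks
};

// Submit command buffers and remember the completion future so a later wait
// also drives the completion callback (per-submission cleanup).
static void demo_submit(demo_webgpu_context & ctx, const wgpu::CommandBuffer * bufs, size_t n) {
    std::lock_guard<std::recursive_mutex> lock(ctx.mutex);
    if (n == 0) {
        return;  // nothing to submit
    }
    ctx.queue.Submit(n, bufs);
    wgpu::Future f = ctx.queue.OnSubmittedWorkDone(
        wgpu::CallbackMode::AllowSpontaneous,
        [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
            if (status != wgpu::QueueWorkDoneStatus::Success) {
                fprintf(stderr, "demo: queue work failed: %s\n", message.data);
            }
            // per-submission cleanup would go here (the PR frees staged parameter buffers)
        });
    ctx.callback_futures.push_back({ f });
}

// Block until all previously submitted work, and its callbacks, has completed.
static void demo_wait(demo_webgpu_context & ctx) {
    std::lock_guard<std::recursive_mutex> lock(ctx.mutex);
    if (ctx.callback_futures.empty()) {
        // No tracked callbacks: wait on the queue directly.
        ctx.instance.WaitAny(ctx.queue.OnSubmittedWorkDone(
                                 wgpu::CallbackMode::AllowSpontaneous,
                                 [](wgpu::QueueWorkDoneStatus, wgpu::StringView) {}),
                             UINT64_MAX);
    } else {
        // Wait on every tracked future; this also runs their callbacks.
        ctx.instance.WaitAny(ctx.callback_futures.size(), ctx.callback_futures.data(), UINT64_MAX);
        ctx.callback_futures.clear();
    }
}
```

As far as the diff shows, the point of tracking the futures is that a wait then also guarantees the parameter-buffer-freeing callbacks have run before the pool is reused, rather than only guaranteeing that the GPU work itself finished.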