Skip to content

Commit 248f7a5

Browse files
committed
Add error buffers for reporting unsupported SET_ROWS indices
1 parent b2dbfcd commit 248f7a5

File tree

2 files changed

+107
-45
lines changed

2 files changed

+107
-45
lines changed

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 97 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,21 @@
1919
#include <vector>
2020

2121
#ifdef GGML_WEBGPU_DEBUG
22-
# define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
22+
# define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
2323
# define WEBGPU_DEBUG_BUF_ELEMS 32
2424
#else
2525
# define WEBGPU_LOG_DEBUG(msg) ((void) 0)
2626
#endif // GGML_WEBGPU_DEBUG
2727

2828
/* Constants */
2929

30-
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16
31-
#define WEBGPU_MUL_MAT_WG_SIZE 64
32-
#define WEBGPU_NUM_PARAM_BUFS 100
33-
#define WEBGPU_PARAMS_BUF_SIZE_BYTES 256
34-
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
30+
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16
31+
#define WEBGPU_MUL_MAT_WG_SIZE 64
32+
#define WEBGPU_NUM_PARAM_BUFS 100
33+
#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
34+
#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 32
35+
#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
36+
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
3537

3638
/* End Constants */
3739

@@ -55,46 +57,42 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
5557
wgpu::BufferUsage usage,
5658
const char * label);
5759

58-
struct webgpu_param_bufs {
60+
struct webgpu_pool_bufs {
5961
wgpu::Buffer host_buf;
6062
wgpu::Buffer dev_buf;
6163
};
6264

6365
// Holds a pool of parameter buffers for WebGPU operations
64-
struct webgpu_param_buf_pool {
65-
std::vector<webgpu_param_bufs> free;
66+
struct webgpu_buf_pool {
67+
std::vector<webgpu_pool_bufs> free;
6668

6769
std::mutex mutex;
6870

6971
std::condition_variable cv;
7072

71-
void init(wgpu::Device device) {
72-
for (int i = 0; i < WEBGPU_NUM_PARAM_BUFS; i++) {
73+
void init(wgpu::Device device,
74+
int num_bufs,
75+
size_t buf_size,
76+
wgpu::BufferUsage dev_buf_usage,
77+
wgpu::BufferUsage host_buf_usage) {
78+
for (int i = 0; i < num_bufs; i++) {
7379
wgpu::Buffer host_buf;
7480
wgpu::Buffer dev_buf;
75-
ggml_webgpu_create_buffer(device,
76-
host_buf,
77-
WEBGPU_PARAMS_BUF_SIZE_BYTES,
78-
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite,
79-
"ggml_webgpu_host_params_buf");
80-
ggml_webgpu_create_buffer(device,
81-
dev_buf,
82-
WEBGPU_PARAMS_BUF_SIZE_BYTES,
83-
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
84-
"ggml_webgpu_dev_params_buf");
81+
ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
82+
ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
8583
free.push_back({ host_buf, dev_buf });
8684
}
8785
}
8886

89-
webgpu_param_bufs alloc_bufs() {
87+
webgpu_pool_bufs alloc_bufs() {
9088
std::unique_lock<std::mutex> lock(mutex);
9189
cv.wait(lock, [this] { return !free.empty(); });
92-
webgpu_param_bufs bufs = free.back();
90+
webgpu_pool_bufs bufs = free.back();
9391
free.pop_back();
9492
return bufs;
9593
}
9694

97-
void free_bufs(std::vector<webgpu_param_bufs> bufs) {
95+
void free_bufs(std::vector<webgpu_pool_bufs> bufs) {
9896
std::lock_guard<std::mutex> lock(mutex);
9997
free.insert(free.end(), bufs.begin(), bufs.end());
10098
cv.notify_all();
@@ -122,7 +120,8 @@ struct webgpu_context_struct {
122120

123121
bool device_init = false;
124122

125-
webgpu_param_buf_pool param_buf_pool;
123+
webgpu_buf_pool param_buf_pool;
124+
webgpu_buf_pool set_rows_error_buf_pool;
126125

127126
wgpu::ComputePipeline memset_pipeline;
128127
wgpu::ComputePipeline mul_mat_pipeline;
@@ -138,15 +137,16 @@ struct webgpu_context_struct {
138137
std::vector<wgpu::CommandBuffer> staged_command_bufs;
139138

140139
// Parameter buffers associated with the staged command buffers
141-
std::vector<webgpu_param_bufs> staged_param_bufs;
140+
std::vector<webgpu_pool_bufs> staged_param_bufs;
141+
// Buffers associated with set_rows operations, used to store potential errors
142+
std::vector<webgpu_pool_bufs> staged_set_row_error_bufs;
142143

143144
std::vector<wgpu::FutureWaitInfo> callback_futures;
144145

145146
#ifdef GGML_WEBGPU_DEBUG
146147
wgpu::Buffer debug_host_buf;
147148
wgpu::Buffer debug_dev_buf;
148149
#endif
149-
150150
};
151151

152152
typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
@@ -257,20 +257,55 @@ static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
257257
return;
258258
}
259259
ctx->queue.Submit(ctx->staged_command_bufs.size(), ctx->staged_command_bufs.data());
260+
261+
// If there are SET_ROWS operations in this submission, copy their error buffers to the host.
262+
if (ctx->staged_set_row_error_bufs.size() > 0) {
263+
wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
264+
for (auto & error_bufs : ctx->staged_set_row_error_bufs) {
265+
// Copy the error buffer to the host buffer
266+
encoder.CopyBufferToBuffer(error_bufs.dev_buf, 0, error_bufs.host_buf, 0, error_bufs.host_buf.GetSize());
267+
}
268+
wgpu::CommandBuffer commands = encoder.Finish();
269+
ctx->queue.Submit(1, &commands);
270+
}
271+
260272
ctx->staged_command_bufs.clear();
261-
std::vector<webgpu_param_bufs> staged_param_bufs = std::move(ctx->staged_param_bufs);
273+
std::vector<webgpu_pool_bufs> staged_param_bufs = std::move(ctx->staged_param_bufs);
274+
std::vector<webgpu_pool_bufs> staged_set_row_error_bufs = std::move(ctx->staged_set_row_error_bufs);
262275

263276
// Free the staged parameter buffers once the submission completes
264-
wgpu::Future f = ctx->queue.OnSubmittedWorkDone(
277+
wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
265278
wgpu::CallbackMode::AllowSpontaneous,
266279
[ctx, staged_param_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
267280
if (status != wgpu::QueueWorkDoneStatus::Success) {
268281
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data);
269282
}
270-
// Free the staged parameter buffers
283+
// Free the staged buffers
271284
ctx->param_buf_pool.free_bufs(staged_param_bufs);
272285
});
273-
ctx->callback_futures.push_back({ f });
286+
ctx->callback_futures.push_back({ p_f });
287+
288+
// Check for errrors in SET_ROWS operations
289+
for (auto & error_bufs : staged_set_row_error_bufs) {
290+
wgpu::Future f = error_bufs.host_buf.MapAsync(
291+
wgpu::MapMode::Read,
292+
0,
293+
error_bufs.host_buf.GetSize(),
294+
wgpu::CallbackMode::AllowSpontaneous,
295+
[ctx, error_bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
296+
if (status != wgpu::MapAsyncStatus::Success) {
297+
GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", message.data);
298+
} else {
299+
const uint32_t * error_data = (const uint32_t *) error_bufs.host_buf.GetConstMappedRange();
300+
if (*error_data) {
301+
GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
302+
}
303+
// We can't unmap in here due to WebGPU reentrancy limitations.
304+
ctx->set_rows_error_buf_pool.free_bufs({ error_bufs });
305+
}
306+
});
307+
ctx->callback_futures.push_back({ f });
308+
}
274309
}
275310

276311
static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
@@ -294,7 +329,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
294329
#ifdef GGML_WEBGPU_DEBUG
295330
// This function adds debugging information to shaders, as WebGPU does not support printing directly.
296331
// To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and
297-
// debug statements in the shader, and then call this function after encoding the commands.
332+
// debug statements in the shader, and then call this function after encoding the commands and submitting them.
298333
static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
299334
wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
300335
encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
@@ -318,7 +353,7 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
318353
std::vector<wgpu::BindGroupEntry> bind_group_entries,
319354
uint32_t wg_x,
320355
bool submit_and_wait = false) {
321-
webgpu_param_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
356+
webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
322357

323358
ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
324359
uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
@@ -464,6 +499,12 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
464499
return;
465500
}
466501

502+
// allocate error bufs
503+
webgpu_pool_bufs error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
504+
if (error_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
505+
error_bufs.host_buf.Unmap();
506+
}
507+
467508
size_t src_offset = ggml_backend_webgpu_tensor_offset(src);
468509
// assumes power of 2 offset alignment
469510
size_t src_misalignment = src_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
@@ -476,8 +517,7 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
476517
size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
477518
dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
478519

479-
std::vector<uint32_t> params = {
480-
(uint32_t) (src_misalignment / ggml_type_size(src->type)),
520+
std::vector<uint32_t> params = { (uint32_t) (src_misalignment / ggml_type_size(src->type)),
481521
(uint32_t) (idx_misalignment / ggml_type_size(idx->type)),
482522
(uint32_t) (dst_misalignment / ggml_type_size(dst->type)),
483523
// Convert byte-strides to element-strides
@@ -497,28 +537,31 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
497537
(uint32_t) src->ne[3],
498538
// Shape of idx
499539
(uint32_t) (idx->ne[1]),
500-
(uint32_t) (idx->ne[2])
501-
};
540+
(uint32_t) (idx->ne[2]) };
502541

503542
std::vector<wgpu::BindGroupEntry> entries = {
504543
{ .binding = 0,
505544
.buffer = ggml_backend_webgpu_tensor_buf(src),
506545
.offset = ggml_backend_webgpu_tensor_offset(src),
507-
.size = ggml_nbytes(src) },
546+
.size = ggml_nbytes(src) },
508547
{ .binding = 1,
509548
.buffer = ggml_backend_webgpu_tensor_buf(idx),
510549
.offset = ggml_backend_webgpu_tensor_offset(idx),
511-
.size = ggml_nbytes(idx) },
550+
.size = ggml_nbytes(idx) },
512551
{ .binding = 2,
513552
.buffer = ggml_backend_webgpu_tensor_buf(dst),
514553
.offset = ggml_backend_webgpu_tensor_offset(dst),
515-
.size = ggml_nbytes(dst) }
554+
.size = ggml_nbytes(dst) },
555+
{ .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() }
516556
};
517557

518558
size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
519-
uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
559+
uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
560+
561+
std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
562+
ctx->staged_set_row_error_bufs.push_back(error_bufs);
563+
520564
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x);
521-
ggml_backend_webgpu_submit_queue(ctx);
522565
}
523566

524567
static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
@@ -872,7 +915,8 @@ static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
872915
std::vector<wgpu::ConstantEntry> constants(1);
873916
constants[0].key = "wg_size";
874917
constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
875-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows", constants);
918+
ggml_webgpu_create_pipeline(
919+
webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows", constants);
876920
}
877921

878922
static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
@@ -931,7 +975,16 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co
931975
webgpu_ctx->queue = webgpu_ctx->device.GetQueue();
932976

933977
// Create buffer pool for shader parameters
934-
webgpu_ctx->param_buf_pool.init(webgpu_ctx->device);
978+
webgpu_ctx->param_buf_pool.init(webgpu_ctx->device,
979+
WEBGPU_NUM_PARAM_BUFS,
980+
WEBGPU_PARAMS_BUF_SIZE_BYTES,
981+
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
982+
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
983+
webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->device,
984+
WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
985+
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
986+
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
987+
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
935988

936989
ggml_webgpu_init_memset_pipeline(webgpu_ctx);
937990
ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx);

ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ var<storage, read_write> idx: array<u32>;
99
@group(0) @binding(2)
1010
var<storage, read_write> dst: array<f16>;
1111

12+
@group(0) @binding(3)
13+
var<storage, read_write> error: atomic<u32>;
14+
1215
struct Params {
1316
offset_src: u32, // in elements
1417
offset_idx: u32, // in elements
@@ -38,7 +41,7 @@ struct Params {
3841
idx2: u32,
3942
};
4043

41-
@group(0) @binding(3)
44+
@group(0) @binding(4)
4245
var<uniform> params: Params;
4346

4447
override wg_size: u32;
@@ -64,6 +67,12 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
6467
let idx_high_val = idx[idx_high];
6568
let idx_low_val = idx[idx_high + 1];
6669

70+
if (idx_low_val != 0) {
71+
// Upper bits of index are not zero, output will be incorrect
72+
atomicStore(&error, 1);
73+
return;
74+
}
75+
6776
let i_dst_row = params.offset_dst + idx_high_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
6877
let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;
6978

0 commit comments

Comments
 (0)