1919#include < vector>
2020
2121#ifdef GGML_WEBGPU_DEBUG
22- # define WEBGPU_LOG_DEBUG (msg ) std::cout << msg << std::endl
22+ # define WEBGPU_LOG_DEBUG (msg ) std::cout << msg << std::endl
2323# define WEBGPU_DEBUG_BUF_ELEMS 32
2424#else
2525# define WEBGPU_LOG_DEBUG (msg ) ((void ) 0 )
2626#endif // GGML_WEBGPU_DEBUG
2727
2828/* Constants */
2929
30- #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16
31- #define WEBGPU_MUL_MAT_WG_SIZE 64
32- #define WEBGPU_NUM_PARAM_BUFS 100
33- #define WEBGPU_PARAMS_BUF_SIZE_BYTES 256
34- #define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
30+ #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16
31+ #define WEBGPU_MUL_MAT_WG_SIZE 64
32+ #define WEBGPU_NUM_PARAM_BUFS 100
33+ #define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
34+ #define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 32
35+ #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
36+ #define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
3537
3638/* End Constants */
3739
@@ -55,46 +57,42 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
5557 wgpu::BufferUsage usage,
5658 const char * label);
5759
58- struct webgpu_param_bufs {
60+ struct webgpu_pool_bufs {
5961 wgpu::Buffer host_buf;
6062 wgpu::Buffer dev_buf;
6163};
6264
6365// Holds a pool of host/device buffer pairs for WebGPU operations
64- struct webgpu_param_buf_pool {
65- std::vector<webgpu_param_bufs > free;
66+ struct webgpu_buf_pool {
67+ std::vector<webgpu_pool_bufs > free;
6668
6769 std::mutex mutex;
6870
6971 std::condition_variable cv;
7072
71- void init (wgpu::Device device) {
72- for (int i = 0 ; i < WEBGPU_NUM_PARAM_BUFS; i++) {
73+ void init (wgpu::Device device,
74+ int num_bufs,
75+ size_t buf_size,
76+ wgpu::BufferUsage dev_buf_usage,
77+ wgpu::BufferUsage host_buf_usage) {
78+ for (int i = 0 ; i < num_bufs; i++) {
7379 wgpu::Buffer host_buf;
7480 wgpu::Buffer dev_buf;
75- ggml_webgpu_create_buffer (device,
76- host_buf,
77- WEBGPU_PARAMS_BUF_SIZE_BYTES,
78- wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite,
79- " ggml_webgpu_host_params_buf" );
80- ggml_webgpu_create_buffer (device,
81- dev_buf,
82- WEBGPU_PARAMS_BUF_SIZE_BYTES,
83- wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
84- " ggml_webgpu_dev_params_buf" );
81+ ggml_webgpu_create_buffer (device, host_buf, buf_size, host_buf_usage, " ggml_webgpu_host_pool_buf" );
82+ ggml_webgpu_create_buffer (device, dev_buf, buf_size, dev_buf_usage, " ggml_webgpu_dev_pool_buf" );
8583 free.push_back ({ host_buf, dev_buf });
8684 }
8785 }
8886
89- webgpu_param_bufs alloc_bufs () {
87+ webgpu_pool_bufs alloc_bufs () {
9088 std::unique_lock<std::mutex> lock (mutex);
9189 cv.wait (lock, [this ] { return !free.empty (); });
92- webgpu_param_bufs bufs = free.back ();
90+ webgpu_pool_bufs bufs = free.back ();
9391 free.pop_back ();
9492 return bufs;
9593 }
9694
97- void free_bufs (std::vector<webgpu_param_bufs > bufs) {
95+ void free_bufs (std::vector<webgpu_pool_bufs > bufs) {
9896 std::lock_guard<std::mutex> lock (mutex);
9997 free.insert (free.end (), bufs.begin (), bufs.end ());
10098 cv.notify_all ();
@@ -122,7 +120,8 @@ struct webgpu_context_struct {
122120
123121 bool device_init = false ;
124122
125- webgpu_param_buf_pool param_buf_pool;
123+ webgpu_buf_pool param_buf_pool;
124+ webgpu_buf_pool set_rows_error_buf_pool;
126125
127126 wgpu::ComputePipeline memset_pipeline;
128127 wgpu::ComputePipeline mul_mat_pipeline;
@@ -138,15 +137,16 @@ struct webgpu_context_struct {
138137 std::vector<wgpu::CommandBuffer> staged_command_bufs;
139138
140139 // Parameter buffers associated with the staged command buffers
141- std::vector<webgpu_param_bufs> staged_param_bufs;
140+ std::vector<webgpu_pool_bufs> staged_param_bufs;
141+ // Buffers associated with set_rows operations, used to store potential errors
142+ std::vector<webgpu_pool_bufs> staged_set_row_error_bufs;
142143
143144 std::vector<wgpu::FutureWaitInfo> callback_futures;
144145
145146#ifdef GGML_WEBGPU_DEBUG
146147 wgpu::Buffer debug_host_buf;
147148 wgpu::Buffer debug_dev_buf;
148149#endif
149-
150150};
151151
152152typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
@@ -257,20 +257,55 @@ static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
257257 return ;
258258 }
259259 ctx->queue .Submit (ctx->staged_command_bufs .size (), ctx->staged_command_bufs .data ());
260+
261+ // If there are SET_ROWS operations in this submission, copy their error buffers to the host.
262+ if (ctx->staged_set_row_error_bufs .size () > 0 ) {
263+ wgpu::CommandEncoder encoder = ctx->device .CreateCommandEncoder ();
264+ for (auto & error_bufs : ctx->staged_set_row_error_bufs ) {
265+ // Copy the error buffer to the host buffer
266+ encoder.CopyBufferToBuffer (error_bufs.dev_buf , 0 , error_bufs.host_buf , 0 , error_bufs.host_buf .GetSize ());
267+ }
268+ wgpu::CommandBuffer commands = encoder.Finish ();
269+ ctx->queue .Submit (1 , &commands);
270+ }
271+
260272 ctx->staged_command_bufs .clear ();
261- std::vector<webgpu_param_bufs> staged_param_bufs = std::move (ctx->staged_param_bufs );
273+ std::vector<webgpu_pool_bufs> staged_param_bufs = std::move (ctx->staged_param_bufs );
274+ std::vector<webgpu_pool_bufs> staged_set_row_error_bufs = std::move (ctx->staged_set_row_error_bufs );
262275
263276 // Free the staged parameter buffers once the submission completes
264- wgpu::Future f = ctx->queue .OnSubmittedWorkDone (
277+ wgpu::Future p_f = ctx->queue .OnSubmittedWorkDone (
265278 wgpu::CallbackMode::AllowSpontaneous,
266279 [ctx, staged_param_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
267280 if (status != wgpu::QueueWorkDoneStatus::Success) {
268281 GGML_LOG_ERROR (" ggml_webgpu: Failed to submit commands: %s\n " , message.data );
269282 }
270- // Free the staged parameter buffers
283+ // Free the staged buffers
271284 ctx->param_buf_pool .free_bufs (staged_param_bufs);
272285 });
273- ctx->callback_futures .push_back ({ f });
286+ ctx->callback_futures .push_back ({ p_f });
287+
288+ // Check for errors in SET_ROWS operations
289+ for (auto & error_bufs : staged_set_row_error_bufs) {
290+ wgpu::Future f = error_bufs.host_buf .MapAsync (
291+ wgpu::MapMode::Read,
292+ 0 ,
293+ error_bufs.host_buf .GetSize (),
294+ wgpu::CallbackMode::AllowSpontaneous,
295+ [ctx, error_bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
296+ if (status != wgpu::MapAsyncStatus::Success) {
297+ GGML_LOG_ERROR (" ggml_webgpu: Failed to map error buffer: %s\n " , message.data );
298+ } else {
299+ const uint32_t * error_data = (const uint32_t *) error_bufs.host_buf .GetConstMappedRange ();
300+ if (*error_data) {
301+ GGML_ABORT (" ggml_webgpu: SET_ROWS index > 2^32, unsupported." );
302+ }
303+ // We can't unmap in here due to WebGPU reentrancy limitations.
304+ ctx->set_rows_error_buf_pool .free_bufs ({ error_bufs });
305+ }
306+ });
307+ ctx->callback_futures .push_back ({ f });
308+ }
274309}
275310
276311static void ggml_backend_webgpu_map_buffer (webgpu_context & ctx,
@@ -294,7 +329,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
294329#ifdef GGML_WEBGPU_DEBUG
295330// This function adds debugging information to shaders, as WebGPU does not support printing directly.
296331// To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and
297- // debug statements in the shader, and then call this function after encoding the commands.
332+ // debug statements in the shader, and then call this function after encoding the commands and submitting them.
298333static void ggml_backend_webgpu_debug (webgpu_context & ctx) {
299334 wgpu::CommandEncoder encoder = ctx->device .CreateCommandEncoder ();
300335 encoder.CopyBufferToBuffer (ctx->debug_dev_buf , 0 , ctx->debug_host_buf , 0 , ctx->debug_host_buf .GetSize ());
@@ -318,7 +353,7 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
318353 std::vector<wgpu::BindGroupEntry> bind_group_entries,
319354 uint32_t wg_x,
320355 bool submit_and_wait = false ) {
321- webgpu_param_bufs params_bufs = ctx->param_buf_pool .alloc_bufs ();
356+ webgpu_pool_bufs params_bufs = ctx->param_buf_pool .alloc_bufs ();
322357
323358 ggml_backend_webgpu_map_buffer (ctx, params_bufs.host_buf , wgpu::MapMode::Write, 0 , params_bufs.host_buf .GetSize ());
324359 uint32_t * _params = (uint32_t *) params_bufs.host_buf .GetMappedRange ();
@@ -464,6 +499,12 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
464499 return ;
465500 }
466501
502+ // allocate error bufs
503+ webgpu_pool_bufs error_bufs = ctx->set_rows_error_buf_pool .alloc_bufs ();
504+ if (error_bufs.host_buf .GetMapState () == wgpu::BufferMapState::Mapped) {
505+ error_bufs.host_buf .Unmap ();
506+ }
507+
467508 size_t src_offset = ggml_backend_webgpu_tensor_offset (src);
468509 // assumes power of 2 offset alignment
469510 size_t src_misalignment = src_offset & (ctx->limits .minStorageBufferOffsetAlignment - 1 );
@@ -476,8 +517,7 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
476517 size_t dst_misalignment = dst_offset & (ctx->limits .minStorageBufferOffsetAlignment - 1 );
477518 dst_offset &= ~(ctx->limits .minStorageBufferOffsetAlignment - 1 );
478519
479- std::vector<uint32_t > params = {
480- (uint32_t ) (src_misalignment / ggml_type_size (src->type )),
520+ std::vector<uint32_t > params = { (uint32_t ) (src_misalignment / ggml_type_size (src->type )),
481521 (uint32_t ) (idx_misalignment / ggml_type_size (idx->type )),
482522 (uint32_t ) (dst_misalignment / ggml_type_size (dst->type )),
483523 // Convert byte-strides to element-strides
@@ -497,28 +537,31 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
497537 (uint32_t ) src->ne [3 ],
498538 // Shape of idx
499539 (uint32_t ) (idx->ne [1 ]),
500- (uint32_t ) (idx->ne [2 ])
501- };
540+ (uint32_t ) (idx->ne [2 ]) };
502541
503542 std::vector<wgpu::BindGroupEntry> entries = {
504543 { .binding = 0 ,
505544 .buffer = ggml_backend_webgpu_tensor_buf (src),
506545 .offset = ggml_backend_webgpu_tensor_offset (src),
507- .size = ggml_nbytes (src) },
546+ .size = ggml_nbytes (src) },
508547 { .binding = 1 ,
509548 .buffer = ggml_backend_webgpu_tensor_buf (idx),
510549 .offset = ggml_backend_webgpu_tensor_offset (idx),
511- .size = ggml_nbytes (idx) },
550+ .size = ggml_nbytes (idx) },
512551 { .binding = 2 ,
513552 .buffer = ggml_backend_webgpu_tensor_buf (dst),
514553 .offset = ggml_backend_webgpu_tensor_offset (dst),
515- .size = ggml_nbytes (dst) }
554+ .size = ggml_nbytes (dst) },
555+ { .binding = 3 , .buffer = error_bufs.dev_buf , .offset = 0 , .size = error_bufs.dev_buf .GetSize () }
516556 };
517557
518558 size_t max_wg_size = ctx->limits .maxComputeWorkgroupSizeX ;
519- uint32_t wg_x = (src->ne [1 ] * src->ne [2 ] * src->ne [3 ] + max_wg_size - 1 ) / max_wg_size;
559+ uint32_t wg_x = (src->ne [1 ] * src->ne [2 ] * src->ne [3 ] + max_wg_size - 1 ) / max_wg_size;
560+
561+ std::lock_guard<std::recursive_mutex> lock (ctx->mutex );
562+ ctx->staged_set_row_error_bufs .push_back (error_bufs);
563+
520564 ggml_backend_webgpu_build_and_enqueue (ctx, ctx->set_rows_pipeline , params, entries, wg_x);
521- ggml_backend_webgpu_submit_queue (ctx);
522565}
523566
524567static void ggml_webgpu_mul_mat (webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
@@ -872,7 +915,8 @@ static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
872915 std::vector<wgpu::ConstantEntry> constants (1 );
873916 constants[0 ].key = " wg_size" ;
874917 constants[0 ].value = webgpu_ctx->limits .maxComputeWorkgroupSizeX ;
875- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->set_rows_pipeline , wgsl_set_rows, " set_rows" , constants);
918+ ggml_webgpu_create_pipeline (
919+ webgpu_ctx->device , webgpu_ctx->set_rows_pipeline , wgsl_set_rows, " set_rows" , constants);
876920}
877921
878922static void ggml_webgpu_init_cpy_pipeline (webgpu_context & webgpu_ctx) {
@@ -931,7 +975,16 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co
931975 webgpu_ctx->queue = webgpu_ctx->device .GetQueue ();
932976
933977 // Create buffer pool for shader parameters
934- webgpu_ctx->param_buf_pool .init (webgpu_ctx->device );
978+ webgpu_ctx->param_buf_pool .init (webgpu_ctx->device ,
979+ WEBGPU_NUM_PARAM_BUFS,
980+ WEBGPU_PARAMS_BUF_SIZE_BYTES,
981+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
982+ wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
983+ webgpu_ctx->set_rows_error_buf_pool .init (webgpu_ctx->device ,
984+ WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
985+ WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
986+ wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
987+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
935988
936989 ggml_webgpu_init_memset_pipeline (webgpu_ctx);
937990 ggml_webgpu_init_mul_mat_pipeline (webgpu_ctx);
0 commit comments