 #include <vector>
 
 #ifdef GGML_WEBGPU_DEBUG
-#    define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
+#    define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
+#    define WEBGPU_DEBUG_BUF_ELEMS 32
 #else
 #    define WEBGPU_LOG_DEBUG(msg) ((void) 0)
 #endif  // GGML_WEBGPU_DEBUG
 
 /* Constants */
 
-#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16
-#define WEBGPU_MUL_MAT_WG_SIZE           64
-#define WEBGPU_NUM_PARAM_BUFS            100
-#define WEBGPU_PARAMS_BUF_SIZE_BYTES     256
-#define WEBGPU_STORAGE_BUF_BINDING_MULT  4  // a storage buffer binding size must be a multiple of 4
+#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE     16
+#define WEBGPU_MUL_MAT_WG_SIZE               64
+#define WEBGPU_NUM_PARAM_BUFS                100
+#define WEBGPU_PARAMS_BUF_SIZE_BYTES         128  // enough for 32 parameters
+#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS       32
+#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
+#define WEBGPU_STORAGE_BUF_BINDING_MULT      4  // a storage buffer binding size must be a multiple of 4
 
 /* End Constants */
 
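The sizes above are tightly coupled: 128 bytes of uniform parameters is exactly 32 `u32`/`f32` values, and the 4-byte error buffer is the smallest legal storage binding. A pair of hypothetical compile-time checks (not part of the patch) makes these relationships explicit:

```cpp
#include <cstdint>

// Hypothetical sanity checks; they assume 4-byte (u32/f32) shader parameters.
static_assert(WEBGPU_PARAMS_BUF_SIZE_BYTES / sizeof(uint32_t) == 32,
              "params buffer holds exactly 32 u32 parameters");
static_assert(WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES % WEBGPU_STORAGE_BUF_BINDING_MULT == 0,
              "storage binding sizes must be multiples of 4");
```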
@@ -54,46 +57,42 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
                                       wgpu::BufferUsage usage,
                                       const char * label);
 
-struct webgpu_param_bufs {
+struct webgpu_pool_bufs {
     wgpu::Buffer host_buf;
     wgpu::Buffer dev_buf;
 };
 
 // Holds a pool of parameter buffers for WebGPU operations
-struct webgpu_param_buf_pool {
-    std::vector<webgpu_param_bufs> free;
+struct webgpu_buf_pool {
+    std::vector<webgpu_pool_bufs> free;
 
     std::mutex mutex;
 
     std::condition_variable cv;
 
-    void init(wgpu::Device device) {
-        for (int i = 0; i < WEBGPU_NUM_PARAM_BUFS; i++) {
+    void init(wgpu::Device device,
+              int num_bufs,
+              size_t buf_size,
+              wgpu::BufferUsage dev_buf_usage,
+              wgpu::BufferUsage host_buf_usage) {
+        for (int i = 0; i < num_bufs; i++) {
             wgpu::Buffer host_buf;
             wgpu::Buffer dev_buf;
-            ggml_webgpu_create_buffer(device,
-                                      host_buf,
-                                      WEBGPU_PARAMS_BUF_SIZE_BYTES,
-                                      wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite,
-                                      "ggml_webgpu_host_params_buf");
-            ggml_webgpu_create_buffer(device,
-                                      dev_buf,
-                                      WEBGPU_PARAMS_BUF_SIZE_BYTES,
-                                      wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
-                                      "ggml_webgpu_dev_params_buf");
+            ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
+            ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
             free.push_back({ host_buf, dev_buf });
         }
     }
 
-    webgpu_param_bufs alloc_bufs() {
+    webgpu_pool_bufs alloc_bufs() {
         std::unique_lock<std::mutex> lock(mutex);
         cv.wait(lock, [this] { return !free.empty(); });
-        webgpu_param_bufs bufs = free.back();
+        webgpu_pool_bufs bufs = free.back();
         free.pop_back();
         return bufs;
     }
 
-    void free_bufs(std::vector<webgpu_param_bufs> bufs) {
+    void free_bufs(std::vector<webgpu_pool_bufs> bufs) {
         std::lock_guard<std::mutex> lock(mutex);
         free.insert(free.end(), bufs.begin(), bufs.end());
         cv.notify_all();
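The pool generalizes the old parameter-buffer pool into a reusable host/device pair allocator: `alloc_bufs()` blocks on the condition variable until `free_bufs()` returns buffers from a completed submission. A minimal usage sketch, with the usage flags taken from the device-init call later in this diff:

```cpp
// Sketch: driving the generalized pool (names and flags from this patch).
webgpu_buf_pool pool;
pool.init(device,
          WEBGPU_NUM_PARAM_BUFS,
          WEBGPU_PARAMS_BUF_SIZE_BYTES,
          wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,    // dev_buf_usage
          wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);  // host_buf_usage

webgpu_pool_bufs bufs = pool.alloc_bufs();  // blocks while all buffers are in flight
// ... map and fill bufs.host_buf, copy to bufs.dev_buf, stage the submission ...
pool.free_bufs({ bufs });                   // returns the pair and wakes waiters
```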
@@ -121,10 +120,12 @@ struct webgpu_context_struct {
 
     bool device_init = false;
 
-    webgpu_param_buf_pool param_buf_pool;
+    webgpu_buf_pool param_buf_pool;
+    webgpu_buf_pool set_rows_error_buf_pool;
 
     wgpu::ComputePipeline memset_pipeline;
     wgpu::ComputePipeline mul_mat_pipeline;
+    wgpu::ComputePipeline set_rows_pipeline;
     wgpu::ComputePipeline cpy_pipeline;
 
     size_t memset_bytes_per_thread;
@@ -136,9 +137,16 @@ struct webgpu_context_struct {
     std::vector<wgpu::CommandBuffer> staged_command_bufs;
 
     // Parameter buffers associated with the staged command buffers
-    std::vector<webgpu_param_bufs> staged_param_bufs;
+    std::vector<webgpu_pool_bufs> staged_param_bufs;
+    // Buffers associated with set_rows operations, used to store potential errors
+    std::vector<webgpu_pool_bufs> staged_set_row_error_bufs;
 
     std::vector<wgpu::FutureWaitInfo> callback_futures;
+
+#ifdef GGML_WEBGPU_DEBUG
+    wgpu::Buffer debug_host_buf;
+    wgpu::Buffer debug_dev_buf;
+#endif
 };
 
 typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
@@ -249,20 +257,55 @@ static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
         return;
     }
     ctx->queue.Submit(ctx->staged_command_bufs.size(), ctx->staged_command_bufs.data());
+
+    // If there are SET_ROWS operations in this submission, copy their error buffers to the host.
+    if (ctx->staged_set_row_error_bufs.size() > 0) {
+        wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
+        for (auto & error_bufs : ctx->staged_set_row_error_bufs) {
+            // Copy the error buffer to the host buffer
+            encoder.CopyBufferToBuffer(error_bufs.dev_buf, 0, error_bufs.host_buf, 0, error_bufs.host_buf.GetSize());
+        }
+        wgpu::CommandBuffer commands = encoder.Finish();
+        ctx->queue.Submit(1, &commands);
+    }
+
     ctx->staged_command_bufs.clear();
-    std::vector<webgpu_param_bufs> staged_param_bufs = std::move(ctx->staged_param_bufs);
+    std::vector<webgpu_pool_bufs> staged_param_bufs         = std::move(ctx->staged_param_bufs);
+    std::vector<webgpu_pool_bufs> staged_set_row_error_bufs = std::move(ctx->staged_set_row_error_bufs);
 
     // Free the staged parameter buffers once the submission completes
-    wgpu::Future f = ctx->queue.OnSubmittedWorkDone(
+    wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
         wgpu::CallbackMode::AllowSpontaneous,
         [ctx, staged_param_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
             if (status != wgpu::QueueWorkDoneStatus::Success) {
                 GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data);
             }
-            // Free the staged parameter buffers
+            // Free the staged buffers
            ctx->param_buf_pool.free_bufs(staged_param_bufs);
         });
-    ctx->callback_futures.push_back({ f });
+    ctx->callback_futures.push_back({ p_f });
+
+    // Check for errors in SET_ROWS operations
+    for (auto & error_bufs : staged_set_row_error_bufs) {
+        wgpu::Future f = error_bufs.host_buf.MapAsync(
+            wgpu::MapMode::Read,
+            0,
+            error_bufs.host_buf.GetSize(),
+            wgpu::CallbackMode::AllowSpontaneous,
+            [ctx, error_bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
+                if (status != wgpu::MapAsyncStatus::Success) {
+                    GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", message.data);
+                } else {
+                    const uint32_t * error_data = (const uint32_t *) error_bufs.host_buf.GetConstMappedRange();
+                    if (*error_data) {
+                        GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
+                    }
+                    // We can't unmap in here due to WebGPU reentrancy limitations.
+                    ctx->set_rows_error_buf_pool.free_bufs({ error_bufs });
+                }
+            });
+        ctx->callback_futures.push_back({ f });
+    }
 }
 
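One subtlety in the error path above: the `MapAsync` callback returns the error buffer to its pool while the host buffer is still mapped, because WebGPU forbids calling `Unmap` from inside a mapping callback. The unmap is deferred to the next allocation, which is why `ggml_webgpu_set_rows` (further down in this diff) checks the map state first. A condensed view of the round trip, using only names from this patch:

```cpp
// Condensed lifecycle of a set_rows error buffer (summary sketch, not new code):
webgpu_pool_bufs error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
if (error_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
    error_bufs.host_buf.Unmap();  // deferred unmap from a previous MapAsync callback
}
// ... the shader writes a nonzero u32 on index overflow ...
// submit_queue: dev_buf -> host_buf copy, then MapAsync(Read); the callback checks
// the flag, aborts on error, and calls free_bufs() while still mapped (reentrancy rule).
```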
 static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
@@ -283,13 +326,34 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
                                            UINT64_MAX);
 }
 
+#ifdef GGML_WEBGPU_DEBUG
+// This function adds debugging information to shaders, as WebGPU does not support printing directly.
+// To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and
+// debug statements in the shader, and then call this function after encoding the commands and submitting them.
+static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
+    wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
+    encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
+    wgpu::CommandBuffer commands = encoder.Finish();
+    ctx->queue.Submit(1, &commands);
+
+    ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize());
+    const uint32_t * debug_data = (const uint32_t *) ctx->debug_host_buf.GetConstMappedRange();
+    std::cout << "debug data:";
+    for (size_t i = 0; i < WEBGPU_DEBUG_BUF_ELEMS; i++) {
+        std::cout << " " << i << ": " << debug_data[i];
+    }
+    std::cout << "\n";
+    ctx->debug_host_buf.Unmap();
+}
+#endif
+
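As a concrete illustration of the workflow the comment above describes, the shader-side half might look like the following; the binding index, buffer name, and kernel details are hypothetical, not part of this patch:

```cpp
// Hypothetical WGSL counterpart, embedded as a C++ string for reference only.
static const char * wgsl_debug_example = R"(
    @group(0) @binding(4)  // example binding; pick any index unused by the shader
    var<storage, read_write> debug_buf: array<u32, 32>;  // WEBGPU_DEBUG_BUF_ELEMS

    // inside the kernel body, record values for later inspection:
    //     debug_buf[0] = some_value;
)";
// Host side: bind ctx->debug_dev_buf at the chosen binding when building the bind
// group, then call ggml_backend_webgpu_debug(ctx) after submitting the commands.
```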
 static void ggml_backend_webgpu_build_and_enqueue(webgpu_context & ctx,
                                                   wgpu::ComputePipeline & pipeline,
                                                   std::vector<uint32_t> params,
                                                   std::vector<wgpu::BindGroupEntry> bind_group_entries,
                                                   uint32_t wg_x,
                                                   bool submit_and_wait = false) {
-    webgpu_param_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
+    webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
 
     ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
     uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
@@ -429,6 +493,76 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor
     ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x);
 }
 
+static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) {
+    // For set rows specifically, we need to check if src and idx are empty tensors.
+    if (ggml_is_empty(src) || ggml_is_empty(idx)) {
+        return;
+    }
+
+    webgpu_pool_bufs error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
+    if (error_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
+        error_bufs.host_buf.Unmap();
+    }
+
+    size_t src_offset = ggml_backend_webgpu_tensor_offset(src);
+    // assumes power of 2 offset alignment
+    size_t src_misalignment = src_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
+    // align to minimum offset alignment
+    src_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
+    size_t idx_offset = ggml_backend_webgpu_tensor_offset(idx);
+    size_t idx_misalignment = idx_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
+    idx_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
+    size_t dst_offset = ggml_backend_webgpu_tensor_offset(dst);
+    size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
+    dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
+
+    std::vector<uint32_t> params = { (uint32_t) (src_misalignment / ggml_type_size(src->type)),
+                                     (uint32_t) (idx_misalignment / ggml_type_size(idx->type)),
+                                     (uint32_t) (dst_misalignment / ggml_type_size(dst->type)),
+                                     // Convert byte-strides to element-strides
+                                     (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
+                                     (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
+                                     (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
+                                     (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
+                                     (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)),
+                                     (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
+                                     (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
+                                     (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
+                                     (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+                                     // Shape of src
+                                     (uint32_t) src->ne[0],
+                                     (uint32_t) src->ne[1],
+                                     (uint32_t) src->ne[2],
+                                     (uint32_t) src->ne[3],
+                                     // Shape of idx
+                                     (uint32_t) (idx->ne[1]),
+                                     (uint32_t) (idx->ne[2]) };
+
+    std::vector<wgpu::BindGroupEntry> entries = {
+        { .binding = 0,
+          .buffer  = ggml_backend_webgpu_tensor_buf(src),
+          .offset  = src_offset,
+          .size    = ggml_nbytes(src) + src_misalignment },
+        { .binding = 1,
+          .buffer  = ggml_backend_webgpu_tensor_buf(idx),
+          .offset  = idx_offset,
+          .size    = ggml_nbytes(idx) + idx_misalignment },
+        { .binding = 2,
+          .buffer  = ggml_backend_webgpu_tensor_buf(dst),
+          .offset  = dst_offset,
+          .size    = ggml_nbytes(dst) + dst_misalignment },
+        { .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() }
+    };
+
+    size_t   max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
+    uint32_t wg_x        = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
+
+    std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
+    ctx->staged_set_row_error_bufs.push_back(error_bufs);
+
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x);
+}
+
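The mask arithmetic in `ggml_webgpu_set_rows` splits each raw byte offset into an aligned binding offset plus an in-buffer element offset, since WebGPU requires storage-buffer binding offsets to be multiples of `minStorageBufferOffsetAlignment`. A worked example with illustrative values (256 is a common alignment; the tensor here is f32):

```cpp
// Worked example of the misalignment split (all values illustrative).
size_t align  = 256;     // e.g. ctx->limits.minStorageBufferOffsetAlignment
size_t offset = 0x1234;  // raw tensor byte offset (4660)

size_t misalignment = offset & (align - 1);  // 0x34 = 52 bytes past the boundary
offset &= ~(align - 1);                      // 0x1200 = aligned binding offset

// The shader indexes from element misalignment / ggml_type_size(type):
// 52 / 4 = 13 f32 elements, which is what the first params entries carry.
```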
 static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
     std::vector<uint32_t> params = {
         (uint32_t) dst->ne[1],  // number of rows in result (M)
@@ -487,6 +621,11 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
                 ggml_webgpu_cpy(ctx, src0, node);
                 break;
             }
+        case GGML_OP_SET_ROWS:
+            {
+                ggml_webgpu_set_rows(ctx, src0, src1, node);
+                break;
+            }
         case GGML_OP_MUL_MAT:
             {
                 ggml_webgpu_mul_mat(ctx, src0, src1, node);
@@ -771,6 +910,14 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline, wgsl_mul_mat, "mul_mat");
 }
 
+static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
+    std::vector<wgpu::ConstantEntry> constants(1);
+    constants[0].key   = "wg_size";
+    constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows", constants);
+}
+
 static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants(1);
     constants[0].key = "wg_size";
@@ -827,11 +974,35 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co
         webgpu_ctx->queue = webgpu_ctx->device.GetQueue();
 
         // Create buffer pool for shader parameters
-        webgpu_ctx->param_buf_pool.init(webgpu_ctx->device);
+        webgpu_ctx->param_buf_pool.init(webgpu_ctx->device,
+                                        WEBGPU_NUM_PARAM_BUFS,
+                                        WEBGPU_PARAMS_BUF_SIZE_BYTES,
+                                        wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
+                                        wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
+        webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->device,
+                                                 WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
+                                                 WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
+                                                 wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
+                                                 wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
 
         ggml_webgpu_init_memset_pipeline(webgpu_ctx);
         ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx);
+        ggml_webgpu_init_set_rows_pipeline(webgpu_ctx);
         ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
+
+#ifdef GGML_WEBGPU_DEBUG
+        // Initialize debug buffers
+        ggml_webgpu_create_buffer(webgpu_ctx->device,
+                                  webgpu_ctx->debug_host_buf,
+                                  WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
+                                  wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
+                                  "debug_host_buf");
+        ggml_webgpu_create_buffer(webgpu_ctx->device,
+                                  webgpu_ctx->debug_dev_buf,
+                                  WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
+                                  wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc,
+                                  "debug_dev_buf");
+#endif
         webgpu_ctx->device_init = true;
     }
 
@@ -882,7 +1053,8 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
             return true;
-        case GGML_OP_CPY:
+        case GGML_OP_CPY:
+        case GGML_OP_SET_ROWS:
             return op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_MUL_MAT:
             return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;