#include <vector>

#ifdef GGML_WEBGPU_DEBUG
-#    define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
+#    define WEBGPU_LOG_DEBUG(msg)  std::cout << msg << std::endl
+#    define WEBGPU_DEBUG_BUF_ELEMS 32
#else
#    define WEBGPU_LOG_DEBUG(msg) ((void) 0)
#endif  // GGML_WEBGPU_DEBUG

/* Constants */

-#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16
-#define WEBGPU_MUL_MAT_WG_SIZE           64
-#define WEBGPU_NUM_PARAM_BUFS            100
-#define WEBGPU_PARAMS_BUF_SIZE_BYTES     256
-#define WEBGPU_STORAGE_BUF_BINDING_MULT  4  // a storage buffer binding size must be a multiple of 4
+#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE     16
+#define WEBGPU_MUL_MAT_WG_SIZE               64
+#define WEBGPU_NUM_PARAM_BUFS                100
+#define WEBGPU_PARAMS_BUF_SIZE_BYTES         128  // enough for 32 parameters
+#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS       32
+#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
+#define WEBGPU_STORAGE_BUF_BINDING_MULT      4  // a storage buffer binding size must be a multiple of 4

/* End Constants */
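Why 128 bytes: the shaders receive their parameters as packed 32-bit words, so 128 bytes covers the 32 parameters the comment promises, and each SET_ROWS error buffer is a single u32 flag. A hypothetical sanity check (not part of the patch) of that arithmetic:

#include <cstdint>

// Hypothetical compile-time checks mirroring the new constants above.
static_assert(128 / sizeof(uint32_t) == 32, "a 128-byte params buffer holds exactly 32 u32 parameters");
static_assert(4 == sizeof(uint32_t), "each SET_ROWS error buffer is a single u32 flag");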
@@ -54,46 +57,42 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
                                       wgpu::BufferUsage usage,
                                       const char * label);

-struct webgpu_param_bufs {
+struct webgpu_pool_bufs {
     wgpu::Buffer host_buf;
     wgpu::Buffer dev_buf;
 };

 // Holds a pool of parameter buffers for WebGPU operations
-struct webgpu_param_buf_pool {
-    std::vector<webgpu_param_bufs> free;
+struct webgpu_buf_pool {
+    std::vector<webgpu_pool_bufs> free;

     std::mutex mutex;

     std::condition_variable cv;

-    void init(wgpu::Device device) {
-        for (int i = 0; i < WEBGPU_NUM_PARAM_BUFS; i++) {
+    void init(wgpu::Device device,
+              int num_bufs,
+              size_t buf_size,
+              wgpu::BufferUsage dev_buf_usage,
+              wgpu::BufferUsage host_buf_usage) {
+        for (int i = 0; i < num_bufs; i++) {
             wgpu::Buffer host_buf;
             wgpu::Buffer dev_buf;
-            ggml_webgpu_create_buffer(device,
-                                      host_buf,
-                                      WEBGPU_PARAMS_BUF_SIZE_BYTES,
-                                      wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite,
-                                      "ggml_webgpu_host_params_buf");
-            ggml_webgpu_create_buffer(device,
-                                      dev_buf,
-                                      WEBGPU_PARAMS_BUF_SIZE_BYTES,
-                                      wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
-                                      "ggml_webgpu_dev_params_buf");
+            ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
+            ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
             free.push_back({ host_buf, dev_buf });
         }
     }

-    webgpu_param_bufs alloc_bufs() {
+    webgpu_pool_bufs alloc_bufs() {
         std::unique_lock<std::mutex> lock(mutex);
         cv.wait(lock, [this] { return !free.empty(); });
-        webgpu_param_bufs bufs = free.back();
+        webgpu_pool_bufs bufs = free.back();
         free.pop_back();
         return bufs;
     }

-    void free_bufs(std::vector<webgpu_param_bufs> bufs) {
+    void free_bufs(std::vector<webgpu_pool_bufs> bufs) {
         std::lock_guard<std::mutex> lock(mutex);
         free.insert(free.end(), bufs.begin(), bufs.end());
         cv.notify_all();
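The pool above is a classic bounded free-list guarded by a mutex and condition variable: alloc_bufs() blocks while every buffer is in flight, and free_bufs() wakes the waiters. A minimal, WebGPU-free sketch of the same pattern, for illustration only:

#include <condition_variable>
#include <mutex>
#include <vector>

// Minimal sketch of the blocking-pool pattern used by webgpu_buf_pool.
template <typename T> struct blocking_pool {
    std::vector<T>          free;
    std::mutex              mutex;
    std::condition_variable cv;

    T alloc() {
        std::unique_lock<std::mutex> lock(mutex);
        cv.wait(lock, [this] { return !free.empty(); });  // sleep while pool is exhausted
        T item = free.back();
        free.pop_back();
        return item;
    }

    void release(T item) {
        std::lock_guard<std::mutex> lock(mutex);
        free.push_back(item);
        cv.notify_all();  // wake any thread blocked in alloc()
    }
};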
@@ -121,10 +120,12 @@ struct webgpu_context_struct {

     bool device_init = false;

-    webgpu_param_buf_pool param_buf_pool;
+    webgpu_buf_pool param_buf_pool;
+    webgpu_buf_pool set_rows_error_buf_pool;

     wgpu::ComputePipeline memset_pipeline;
     wgpu::ComputePipeline mul_mat_pipeline;
+    wgpu::ComputePipeline set_rows_pipeline;
     wgpu::ComputePipeline cpy_pipeline;

     size_t memset_bytes_per_thread;
@@ -136,9 +137,16 @@ struct webgpu_context_struct {
     std::vector<wgpu::CommandBuffer> staged_command_bufs;

     // Parameter buffers associated with the staged command buffers
-    std::vector<webgpu_param_bufs> staged_param_bufs;
+    std::vector<webgpu_pool_bufs> staged_param_bufs;
+    // Buffers associated with set_rows operations, used to store potential errors
+    std::vector<webgpu_pool_bufs> staged_set_row_error_bufs;

     std::vector<wgpu::FutureWaitInfo> callback_futures;
+
+#ifdef GGML_WEBGPU_DEBUG
+    wgpu::Buffer debug_host_buf;
+    wgpu::Buffer debug_dev_buf;
+#endif
 };

 typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
@@ -249,20 +257,55 @@ static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
         return;
     }
     ctx->queue.Submit(ctx->staged_command_bufs.size(), ctx->staged_command_bufs.data());
+
+    // If there are SET_ROWS operations in this submission, copy their error buffers to the host.
+    if (ctx->staged_set_row_error_bufs.size() > 0) {
+        wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
+        for (auto & error_bufs : ctx->staged_set_row_error_bufs) {
+            // Copy the error buffer to the host buffer
+            encoder.CopyBufferToBuffer(error_bufs.dev_buf, 0, error_bufs.host_buf, 0, error_bufs.host_buf.GetSize());
+        }
+        wgpu::CommandBuffer commands = encoder.Finish();
+        ctx->queue.Submit(1, &commands);
+    }
+
     ctx->staged_command_bufs.clear();
-    std::vector<webgpu_param_bufs> staged_param_bufs = std::move(ctx->staged_param_bufs);
+    std::vector<webgpu_pool_bufs> staged_param_bufs         = std::move(ctx->staged_param_bufs);
+    std::vector<webgpu_pool_bufs> staged_set_row_error_bufs = std::move(ctx->staged_set_row_error_bufs);

     // Free the staged parameter buffers once the submission completes
-    wgpu::Future f = ctx->queue.OnSubmittedWorkDone(
+    wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
         wgpu::CallbackMode::AllowSpontaneous,
         [ctx, staged_param_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
             if (status != wgpu::QueueWorkDoneStatus::Success) {
                 GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data);
             }
-            // Free the staged parameter buffers
+            // Free the staged buffers
             ctx->param_buf_pool.free_bufs(staged_param_bufs);
         });
-    ctx->callback_futures.push_back({ f });
+    ctx->callback_futures.push_back({ p_f });
+
+    // Check for errors in SET_ROWS operations
+    for (auto & error_bufs : staged_set_row_error_bufs) {
+        wgpu::Future f = error_bufs.host_buf.MapAsync(
+            wgpu::MapMode::Read,
+            0,
+            error_bufs.host_buf.GetSize(),
+            wgpu::CallbackMode::AllowSpontaneous,
+            [ctx, error_bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
+                if (status != wgpu::MapAsyncStatus::Success) {
+                    GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", message.data);
+                } else {
+                    const uint32_t * error_data = (const uint32_t *) error_bufs.host_buf.GetConstMappedRange();
+                    if (*error_data) {
+                        GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
+                    }
+                    // We can't unmap in here due to WebGPU reentrancy limitations.
+                    ctx->set_rows_error_buf_pool.free_bufs({ error_bufs });
+                }
+            });
+        ctx->callback_futures.push_back({ f });
+    }
 }

 static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
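Note that the error reporting here is asynchronous: the MapAsync callbacks only fire once the futures in callback_futures are waited on, so a bad SET_ROWS index aborts at a later wait point, not at the submit that enqueued it. A hedged sketch of how those futures might be drained — the real wait function is not shown in this diff, and the `instance` member is an assumption:

// Sketch only: assumes webgpu_context_struct also holds a wgpu::Instance
// `instance` member; the actual wait loop in this file may differ.
static void wait_on_callback_futures_sketch(webgpu_context & ctx) {
    std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
    // Blocks until every queued OnSubmittedWorkDone/MapAsync callback has run.
    ctx->instance.WaitAny(ctx->callback_futures.size(), ctx->callback_futures.data(), UINT64_MAX);
    ctx->callback_futures.clear();
}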
@@ -283,13 +326,34 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
                                UINT64_MAX);
 }

+#ifdef GGML_WEBGPU_DEBUG
+// This function adds debugging information to shaders, as WebGPU does not support printing directly.
+// To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and
+// debug statements in the shader, and then call this function after encoding the commands and submitting them.
+static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
+    wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
+    encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
+    wgpu::CommandBuffer commands = encoder.Finish();
+    ctx->queue.Submit(1, &commands);
+
+    ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize());
+    const uint32_t * debug_data = (const uint32_t *) ctx->debug_host_buf.GetConstMappedRange();
+    std::cout << "debug data:";
+    for (size_t i = 0; i < WEBGPU_DEBUG_BUF_ELEMS; i++) {
+        std::cout << "  " << i << ": " << debug_data[i];
+    }
+    std::cout << "\n";
+    ctx->debug_host_buf.Unmap();
+}
+#endif
+
 static void ggml_backend_webgpu_build_and_enqueue(webgpu_context & ctx,
                                                   wgpu::ComputePipeline & pipeline,
                                                   std::vector<uint32_t> params,
                                                   std::vector<wgpu::BindGroupEntry> bind_group_entries,
                                                   uint32_t wg_x,
                                                   bool submit_and_wait = false) {
-    webgpu_param_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
+    webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();

     ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
     uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
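As the comment on ggml_backend_webgpu_debug() describes, using it means binding debug_dev_buf into the shader under inspection. A hypothetical host-side wiring — the binding index 4 is an arbitrary choice, not part of the patch, and must match a corresponding declaration in the WGSL source:

// Hypothetical: add the debug buffer as one more bind group entry when
// setting up the shader being debugged.
entries.push_back({ .binding = 4,
                    .buffer  = ctx->debug_dev_buf,
                    .offset  = 0,
                    .size    = ctx->debug_dev_buf.GetSize() });

// ... encode and submit the commands as usual, then read the values back:
ggml_backend_webgpu_debug(ctx);  // copies dev -> host, maps, and prints all 32 elements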
@@ -429,6 +493,76 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor
     ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x);
 }

+static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) {
+    // For set rows specifically, we need to check if src and idx are empty tensors.
+    if (ggml_is_empty(src) || ggml_is_empty(idx)) {
+        return;
+    }
+
+    webgpu_pool_bufs error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
+    if (error_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
+        error_bufs.host_buf.Unmap();
+    }
+
+    size_t src_offset = ggml_backend_webgpu_tensor_offset(src);
+    // assumes power of 2 offset alignment
+    size_t src_misalignment = src_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
+    // align to minimum offset alignment
+    src_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
+    size_t idx_offset       = ggml_backend_webgpu_tensor_offset(idx);
+    size_t idx_misalignment = idx_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
+    idx_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
+    size_t dst_offset       = ggml_backend_webgpu_tensor_offset(dst);
+    size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
+    dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
+
+    std::vector<uint32_t> params = { (uint32_t) (src_misalignment / ggml_type_size(src->type)),
+                                     (uint32_t) (idx_misalignment / ggml_type_size(idx->type)),
+                                     (uint32_t) (dst_misalignment / ggml_type_size(dst->type)),
+                                     // Convert byte-strides to element-strides
+                                     (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
+                                     (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
+                                     (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
+                                     (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
+                                     (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)),
+                                     (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
+                                     (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
+                                     (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
+                                     (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+                                     // Shape of src
+                                     (uint32_t) src->ne[0],
+                                     (uint32_t) src->ne[1],
+                                     (uint32_t) src->ne[2],
+                                     (uint32_t) src->ne[3],
+                                     // Shape of idx
+                                     (uint32_t) (idx->ne[1]),
+                                     (uint32_t) (idx->ne[2]) };
+
+    // Bind the aligned offsets computed above; each binding size is padded by
+    // the misalignment so the shader can index past the skipped leading bytes.
+    std::vector<wgpu::BindGroupEntry> entries = {
+        { .binding = 0,
+          .buffer  = ggml_backend_webgpu_tensor_buf(src),
+          .offset  = src_offset,
+          .size    = ggml_nbytes(src) + src_misalignment },
+        { .binding = 1,
+          .buffer  = ggml_backend_webgpu_tensor_buf(idx),
+          .offset  = idx_offset,
+          .size    = ggml_nbytes(idx) + idx_misalignment },
+        { .binding = 2,
+          .buffer  = ggml_backend_webgpu_tensor_buf(dst),
+          .offset  = dst_offset,
+          .size    = ggml_nbytes(dst) + dst_misalignment },
+        { .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() }
+    };
+
+    size_t   max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
+    uint32_t wg_x        = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
+
+    std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
+    ctx->staged_set_row_error_bufs.push_back(error_bufs);
+
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x);
+}
+
 static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
     std::vector<uint32_t> params = {
         (uint32_t) dst->ne[1],  // number of rows in result (M)
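The misalignment handling in ggml_webgpu_set_rows() relies on the alignment being a power of two, so `offset & (align - 1)` extracts the remainder and `offset & ~(align - 1)` rounds down. A worked example, assuming the common minStorageBufferOffsetAlignment of 256:

#include <cassert>
#include <cstddef>

int main() {
    const size_t align  = 256;  // minStorageBufferOffsetAlignment (power of two)
    size_t       offset = 600;  // raw tensor offset in bytes

    size_t misalignment = offset & (align - 1);  // 600 % 256 == 88
    offset &= ~(align - 1);                      // rounds down to 512

    assert(misalignment == 88 && offset == 512);
    // The bind group uses the aligned 512; the shader skips the remaining
    // 88 bytes' worth of elements via the misalignment parameter.
    return 0;
}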
@@ -487,6 +621,11 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
                 ggml_webgpu_cpy(ctx, src0, node);
                 break;
             }
+        case GGML_OP_SET_ROWS:
+            {
+                ggml_webgpu_set_rows(ctx, src0, src1, node);
+                break;
+            }
         case GGML_OP_MUL_MAT:
             {
                 ggml_webgpu_mul_mat(ctx, src0, src1, node);
@@ -771,6 +910,14 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline, wgsl_mul_mat, "mul_mat");
 }

+static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
+    std::vector<wgpu::ConstantEntry> constants(1);
+    constants[0].key   = "wg_size";
+    constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
+    ggml_webgpu_create_pipeline(
+        webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows", constants);
+}
+
 static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants(1);
     constants[0].key = "wg_size";
@@ -827,11 +974,35 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co
     webgpu_ctx->queue = webgpu_ctx->device.GetQueue();

     // Create buffer pool for shader parameters
-    webgpu_ctx->param_buf_pool.init(webgpu_ctx->device);
+    webgpu_ctx->param_buf_pool.init(webgpu_ctx->device,
+                                    WEBGPU_NUM_PARAM_BUFS,
+                                    WEBGPU_PARAMS_BUF_SIZE_BYTES,
+                                    wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
+                                    wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
+    webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->device,
+                                             WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
+                                             WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
+                                             wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
+                                             wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);

     ggml_webgpu_init_memset_pipeline(webgpu_ctx);
     ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx);
+    ggml_webgpu_init_set_rows_pipeline(webgpu_ctx);
     ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
+
+#ifdef GGML_WEBGPU_DEBUG
+    // Initialize debug buffers
+    ggml_webgpu_create_buffer(webgpu_ctx->device,
+                              webgpu_ctx->debug_host_buf,
+                              WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
+                              wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
+                              "debug_host_buf");
+    ggml_webgpu_create_buffer(webgpu_ctx->device,
+                              webgpu_ctx->debug_dev_buf,
+                              WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
+                              wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc,
+                              "debug_dev_buf");
+#endif

     webgpu_ctx->device_init = true;
 }
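The usage flags above mirror the data direction: param buffers flow host to device (MapWrite/CopySrc paired with CopyDst/Uniform), while error buffers flow device to host (Storage/CopySrc paired with CopyDst/MapRead). For scale, both pools are tiny; a back-of-envelope check (not in the patch), using the constants from the top of the file:

#include <cstddef>

// Rough footprint of the pools created above, per side (host and device).
constexpr size_t param_pool_bytes = 100 * 128;  // 12 800 B (~12.5 KiB) per side
constexpr size_t error_pool_bytes = 32 * 4;     //    128 B per side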
@@ -882,7 +1053,8 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
             return true;
-        case GGML_OP_CPY:
+        case GGML_OP_CPY:
+        case GGML_OP_SET_ROWS:
             return op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_MUL_MAT:
             return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;