
 #ifdef GGML_WEBGPU_DEBUG
 #    define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
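+// Number of u32 slots shaders can write debug values into (read back by the host-side debug helper).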
+#    define WEBGPU_DEBUG_BUF_ELEMS 32
 #else
 #    define WEBGPU_LOG_DEBUG(msg) ((void) 0)
 #endif  // GGML_WEBGPU_DEBUG
@@ -125,6 +126,7 @@ struct webgpu_context_struct {

     wgpu::ComputePipeline memset_pipeline;
     wgpu::ComputePipeline mul_mat_pipeline;
+    wgpu::ComputePipeline set_rows_pipeline;
     wgpu::ComputePipeline cpy_pipeline;

     size_t memset_bytes_per_thread;
@@ -139,6 +141,12 @@ struct webgpu_context_struct {
     std::vector<webgpu_param_bufs> staged_param_bufs;

     std::vector<wgpu::FutureWaitInfo> callback_futures;
+
+#ifdef GGML_WEBGPU_DEBUG
+    wgpu::Buffer debug_host_buf;
+    wgpu::Buffer debug_dev_buf;
+#endif
+
 };

 typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
@@ -283,6 +291,27 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
                           UINT64_MAX);
 }

+#ifdef GGML_WEBGPU_DEBUG
+// This helper reads back debug information written by shaders, since WebGPU does not support printing
+// from shaders directly. To use it, add a bind group entry for the debug buffer to the setup of the
+// shader you are debugging, declare the buffer and write debug values in the shader, and then call this
+// function after encoding the commands (a shader-side sketch follows this function).
+static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
+    wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
+    encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
+    wgpu::CommandBuffer commands = encoder.Finish();
+    ctx->queue.Submit(1, &commands);
+
+    ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize());
+    const uint32_t * debug_data = (const uint32_t *) ctx->debug_host_buf.GetConstMappedRange();
+    std::cout << "debug data:";
+    for (size_t i = 0; i < WEBGPU_DEBUG_BUF_ELEMS; i++) {
+        std::cout << " " << i << ": " << debug_data[i];
+    }
+    std::cout << "\n";
+    ctx->debug_host_buf.Unmap();
+}
+#endif
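+// Shader-side sketch (illustrative, not part of this change): declare a matching storage buffer and
+// store whatever values you want to inspect, e.g.
+//     @group(0) @binding(3) var<storage, read_write> debug_buf: array<u32>;
+//     debug_buf[0] = some_value;
+// The binding index must match the extra wgpu::BindGroupEntry added on the host side.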
+
 static void ggml_backend_webgpu_build_and_enqueue(webgpu_context & ctx,
                                                   wgpu::ComputePipeline & pipeline,
                                                   std::vector<uint32_t> params,
@@ -429,6 +458,74 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor
     ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x);
 }

+static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) {
+    // For set_rows specifically, we need to check whether src and idx are empty tensors.
+    if (ggml_is_empty(src) || ggml_is_empty(idx)) {
+        return;
+    }
+
+    size_t src_offset = ggml_backend_webgpu_tensor_offset(src);
+    // assumes power-of-2 offset alignment
+    size_t src_misalignment = src_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
+    // align down to the minimum offset alignment
+    src_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
+    size_t idx_offset = ggml_backend_webgpu_tensor_offset(idx);
+    size_t idx_misalignment = idx_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
+    idx_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
+    size_t dst_offset = ggml_backend_webgpu_tensor_offset(dst);
+    size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
+    dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
+
+    std::vector<uint32_t> params = {
+        (uint32_t) (src_misalignment / ggml_type_size(src->type)),
+        (uint32_t) (idx_misalignment / ggml_type_size(idx->type)),
+        (uint32_t) (dst_misalignment / ggml_type_size(dst->type)),
+        // Convert byte strides to element strides
+        (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
+        (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
+        (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
+        (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
+        (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)),
+        (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
+        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
+        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
+        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+        // Shape of src
+        (uint32_t) src->ne[0],
+        (uint32_t) src->ne[1],
+        (uint32_t) src->ne[2],
+        (uint32_t) src->ne[3],
+        // Broadcast shape of idx
+        (uint32_t) (src->ne[2] / idx->ne[1]),
+        (uint32_t) (src->ne[3] / idx->ne[2])
+    };
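+    // Note: the order above is assumed to match the parameter struct declared in the wgsl_set_rows
+    // shader (element offsets, strides, src shape, idx broadcast factors); the shader source is not
+    // part of this hunk.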
+
+    std::vector<wgpu::BindGroupEntry> entries = {
+        // Bind at the aligned offsets; the shader applies the per-tensor element offsets from params.
+        { .binding = 0,
+          .buffer  = ggml_backend_webgpu_tensor_buf(src),
+          .offset  = src_offset,
+          .size    = ggml_nbytes(src) + src_misalignment },
+        { .binding = 1,
+          .buffer  = ggml_backend_webgpu_tensor_buf(idx),
+          .offset  = idx_offset,
+          .size    = ggml_nbytes(idx) + idx_misalignment },
+        { .binding = 2,
+          .buffer  = ggml_backend_webgpu_tensor_buf(dst),
+          .offset  = dst_offset,
+          .size    = ggml_nbytes(dst) + dst_misalignment },
+        // Extra debug binding (see ggml_backend_webgpu_debug above); requires GGML_WEBGPU_DEBUG and a
+        // matching declaration in the shader.
+        { .binding = 3,
+          .buffer  = ctx->debug_dev_buf,
+          .offset  = 0,
+          .size    = ctx->debug_dev_buf.GetSize() }
+    };
+
+    size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
+    // One invocation per source row, rounded up to whole workgroups.
+    uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x);
+    // Flush and read back the debug buffer right away so shader debug output is visible per dispatch.
+    ggml_backend_webgpu_submit_queue(ctx);
+    ggml_backend_webgpu_debug(ctx);
+}
+
 static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
     std::vector<uint32_t> params = {
         (uint32_t) dst->ne[1],  // number of rows in result (M)
@@ -487,6 +584,11 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
                 ggml_webgpu_cpy(ctx, src0, node);
                 break;
             }
+        case GGML_OP_SET_ROWS:
+            {
+                ggml_webgpu_set_rows(ctx, src0, src1, node);
+                break;
+            }
         case GGML_OP_MUL_MAT:
             {
                 ggml_webgpu_mul_mat(ctx, src0, src1, node);
@@ -771,6 +873,13 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline, wgsl_mul_mat, "mul_mat");
 }

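+// wg_size is presumably consumed as an override constant for the shader's workgroup size, mirroring
+// the cpy pipeline setup below.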
+static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
+    std::vector<wgpu::ConstantEntry> constants(1);
+    constants[0].key   = "wg_size";
+    constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows", constants);
+}
+
 static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants(1);
     constants[0].key   = "wg_size";
@@ -831,7 +940,22 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co

         ggml_webgpu_init_memset_pipeline(webgpu_ctx);
         ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx);
+        ggml_webgpu_init_set_rows_pipeline(webgpu_ctx);
         ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
+
+#ifdef GGML_WEBGPU_DEBUG
+        // Initialize debug buffers
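+        // debug_dev_buf is written by shaders (Storage | CopySrc); debug_host_buf is the mappable
+        // staging copy (CopyDst | MapRead) that ggml_backend_webgpu_debug reads back on the CPU.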
+        ggml_webgpu_create_buffer(webgpu_ctx->device,
+                                  webgpu_ctx->debug_host_buf,
+                                  WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
+                                  wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
+                                  "debug_host_buf");
+        ggml_webgpu_create_buffer(webgpu_ctx->device,
+                                  webgpu_ctx->debug_dev_buf,
+                                  WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
+                                  wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc,
+                                  "debug_dev_buf");
+#endif
         webgpu_ctx->device_init = true;
     }

@@ -882,7 +1006,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
             return true;
         case GGML_OP_CPY:
+        case GGML_OP_SET_ROWS:
             return op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_MUL_MAT:
             return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;