 #define WEBGPU_LOG_DEBUG(msg) ((void) 0)
 #endif // GGML_WEBGPU_DEBUG
 
+/* Constants */
+
 // TODO: find a better way to get the memory available
 #define WEBGPU_MAX_BUFFERS 32
 
+#define WEBGPU_MUL_MAT_WG_SIZE 64
+#define WEBGPU_MUL_MAT_PARAMS_SIZE (7 * sizeof(uint32_t)) // M, N, K, batch sizes, broadcasts
+
+/* End Constants */
+
 // This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
 static void * const webgpu_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
 
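For reference, the seven `uint32_t` values behind `WEBGPU_MUL_MAT_PARAMS_SIZE` are the ones packed positionally into `params[0..6]` later in this diff. A struct view of that layout (field names here are illustrative, not from the source):

```cpp
#include <cstdint>

// Struct view of the mul_mat uniform parameters (names are illustrative;
// the diff fills params[0..6] positionally in ggml_webgpu_encode_node).
struct mul_mat_params {
    uint32_t m;      // rows in the result        (node->ne[1])
    uint32_t n;      // columns in the result     (node->ne[0])
    uint32_t k;      // shared inner dimension    (src0->ne[0])
    uint32_t batch2; // src0 batch size in dim 2  (src0->ne[2])
    uint32_t batch3; // src0 batch size in dim 3  (src0->ne[3])
    uint32_t bcast2; // broadcast ratio in dim 2  (src1->ne[2] / src0->ne[2])
    uint32_t bcast3; // broadcast ratio in dim 3  (src1->ne[3] / src0->ne[3])
};
static_assert(sizeof(mul_mat_params) == 7 * sizeof(uint32_t), "must match WEBGPU_MUL_MAT_PARAMS_SIZE");
```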
@@ -138,18 +145,16 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context ctx, wgpu::Buffer buff
     );
 }
 
-static void ggml_backend_webgpu_buffer_memset(webgpu_context ctx, wgpu::Buffer buf, uint8_t value, size_t offset, size_t size) {
+static void ggml_backend_webgpu_buffer_memset(webgpu_context ctx, wgpu::Buffer buf, uint32_t value, size_t offset, size_t size) {
     wgpu::Device device = ctx->device;
 
     // map the host parameters buffer
     ggml_backend_webgpu_map_buffer(ctx, ctx->memset_params_host_buf, wgpu::MapMode::Write, 0, ctx->memset_params_host_buf.GetSize());
     uint32_t * params = (uint32_t *) ctx->memset_params_host_buf.GetMappedRange();
 
-    // This is a trick to set all bytes of a u32 to the same 1 byte value.
-    uint32_t val32 = (uint32_t)value * 0x01010101;
     params[0] = (uint32_t)offset;
     params[1] = (uint32_t)size;
-    params[2] = val32;
+    params[2] = value;
     ctx->memset_params_host_buf.Unmap();
 
     wgpu::BindGroupEntry entries[2];
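Widening `value` from `uint8_t` to `uint32_t` lets callers pass an arbitrary 4-byte pattern instead of a single splatted byte, which the new `set_tensor` tail handling later in this diff relies on. A host-side model of the word-granular store the memset shader presumably performs (an assumption; the WGSL source is not shown in this hunk):

```cpp
#include <cstdint>
#include <vector>

// Host-side sketch of a word-granular memset (assumes offset and size are
// in bytes and that the shader writes one whole u32 per element; how the
// real shader handles a sub-word tail is not visible here).
static void memset_words(std::vector<uint32_t> & buf, uint32_t value,
                         size_t offset_bytes, size_t size_bytes) {
    for (size_t i = 0; i < size_bytes / 4; i++) {
        buf[offset_bytes / 4 + i] = value;
    }
}
```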
@@ -191,7 +196,6 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_context ctx, wgpu::Buffer b
 /** GGML Backend Interface */
 
 static const char * ggml_backend_webgpu_name(ggml_backend_t backend) {
-    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_name()");
     ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *)backend->context;
     return ctx->name.c_str();
 }
@@ -201,6 +205,7 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) {
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")");
 
     // TODO: cleanup
+    GGML_UNUSED(ctx);
 }
 
 // Returns true if node has enqueued work into the queue, false otherwise
@@ -244,6 +249,11 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node){
     params[0] = (uint32_t)node->ne[1]; // number of rows in result (M)
     params[1] = (uint32_t)node->ne[0]; // number of columns in result (N)
     params[2] = (uint32_t)src0->ne[0]; // number of columns in src0/src1 (K)
+    params[3] = (uint32_t)src0->ne[2]; // batch size in dimension 2
+    params[4] = (uint32_t)src0->ne[3]; // batch size in dimension 3
+    params[5] = (uint32_t)(src1->ne[2]/src0->ne[2]); // broadcast in dimension 2
+    params[6] = (uint32_t)(src1->ne[3]/src0->ne[3]); // broadcast in dimension 3
+
     ctx->mul_mat_params_host_buf.Unmap();
 
     wgpu::BindGroupEntry entries[4];
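The two ratios follow ggml's mul_mat broadcasting convention, where src1's batch dimensions are whole multiples of src0's: each src0 matrix is shared by `bcast2` (resp. `bcast3`) consecutive result batches. A sketch of the intended mapping (the actual WGSL indexing may differ):

```cpp
#include <cstdint>

// Map a result batch index to the src0 batch it reads (a sketch mirroring
// ggml's broadcasting, e.g. i02 = i12 / (ne12 / ne02) in other backends).
static inline uint32_t src0_batch(uint32_t dst_batch, uint32_t bcast) {
    return dst_batch / bcast;
}
// Example: src0->ne[2] = 2, src1->ne[2] = 6 gives bcast2 = 3, so result
// batches 0..2 read src0 batch 0 and batches 3..5 read src0 batch 1.
```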
@@ -282,7 +292,7 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node){
     wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
     pass.SetPipeline(ctx->mul_mat_pipeline);
     pass.SetBindGroup(0, bind_group);
-    pass.DispatchWorkgroups(node->ne[0] * node->ne[1]);
+    pass.DispatchWorkgroups((node->ne[0] * node->ne[1] * node->ne[2] * node->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE);
     pass.End();
     wgpu::CommandBuffer commands = encoder.Finish();
 
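The new dispatch count is a ceiling division: enough workgroups of `WEBGPU_MUL_MAT_WG_SIZE` invocations to cover every element of the (now 4-dimensional) result. Extracted as a helper for clarity:

```cpp
#include <cstdint>

// Ceiling division used by the dispatch above: covers n_elems with
// workgroups of wg_size invocations, rounding up for the last partial group.
static uint32_t n_workgroups(uint64_t n_elems, uint32_t wg_size) {
    return (uint32_t)((n_elems + wg_size - 1) / wg_size);
}
// e.g. 64 elements at wg_size 64 -> 1 workgroup; 65 -> 2. Invocations past
// the element count presumably early-out in the shader (an assumption here).
```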
@@ -352,7 +362,9 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe
 
     ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
     size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
-    ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, value, total_offset, size);
+    // This is a trick to set all bytes of a u32 to the same 1-byte value.
+    uint32_t val32 = (uint32_t)value * 0x01010101;
+    ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size);
 }
 
 static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
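The multiply-by-`0x01010101` splat moved here from the memset helper; it replicates one byte into all four bytes of a word, which a quick check confirms:

```cpp
#include <cassert>
#include <cstdint>

int main() {
    // 0xAB * 0x01010101 = 0xAB + (0xAB << 8) + (0xAB << 16) + (0xAB << 24)
    uint32_t val32 = (uint32_t)0xAB * 0x01010101u;
    assert(val32 == 0xABABABABu);
    return 0;
}
```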
@@ -363,10 +375,21 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
     size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
 
     // TODO: wait on this?
-    webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, size);
+    webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size/4)*4);
+
+    if (size % 4 != 0) {
+        // If size is not a multiple of 4, we need to memset the remaining bytes
+        size_t remaining_size = size % 4;
+        // pack the remaining bytes into a uint32_t
+        uint32_t val32 = 0;
+        for (size_t i = 0; i < remaining_size; i++) {
+            ((uint8_t *)&val32)[i] = ((const uint8_t *)data)[size - remaining_size + i];
+        }
+        // memset the remaining bytes
+        ggml_backend_webgpu_buffer_memset(webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size);
+    }
 }
 
-// TODO: we need a staging buffer for this, since WebGPU does not allow reading from storage buffers directly.
 static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
 
@@ -376,33 +399,39 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
 
     size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
 
+    size_t final_size = size;
+    if (size % 4 != 0) {
+        // If size is not a multiple of 4, we need to round it up to the next multiple of 4
+        final_size = size + (4 - (size % 4));
+    }
+
     if (webgpu_ctx->get_tensor_staging_buf == nullptr ||
-        webgpu_ctx->get_tensor_staging_buf.GetSize() < size) {
+        webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) {
         // Create a new staging buffer if it doesn't exist or is too small
         if (webgpu_ctx->get_tensor_staging_buf) {
             webgpu_ctx->get_tensor_staging_buf.Destroy();
         }
-        ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, size,
+        ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, final_size,
                                   wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
     }
 
     // Copy the data from the buffer to the staging buffer
     wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
-    encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, webgpu_ctx->get_tensor_staging_buf, 0, size);
+    encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, webgpu_ctx->get_tensor_staging_buf, 0, final_size);
     wgpu::CommandBuffer commands = encoder.Finish();
     // Submit the command buffer to the queue
     webgpu_ctx->queue.Submit(1, &commands);
 
     // Map the staging buffer to read the data
-    ggml_backend_webgpu_map_buffer(webgpu_ctx, webgpu_ctx->get_tensor_staging_buf, wgpu::MapMode::Read, 0, size);
-    const void * mapped_range = webgpu_ctx->get_tensor_staging_buf.GetConstMappedRange();
+    ggml_backend_webgpu_map_buffer(webgpu_ctx, webgpu_ctx->get_tensor_staging_buf, wgpu::MapMode::Read, 0, final_size);
+    // Must specify size here since the staging buffer might be larger than the tensor size
+    const void * mapped_range = webgpu_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size);
 
     // Copy the data from the mapped range to the output buffer
     std::memcpy(data, mapped_range, size);
     webgpu_ctx->get_tensor_staging_buf.Unmap();
 }
 
-// TODO
 static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << value << ")");
 
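`get_tensor` rounds the copy size up because WebGPU's `copyBufferToBuffer` also requires a size that is a multiple of 4; the staging buffer absorbs the overshoot and the final `memcpy` trims back to the true size. The branchy round-up is equivalent to the usual align-to-4 bit trick:

```cpp
#include <cassert>
#include <cstddef>

int main() {
    for (size_t size = 0; size < 64; size++) {
        size_t a = (size % 4 != 0) ? size + (4 - (size % 4)) : size;
        size_t b = (size + 3) & ~(size_t)3; // branch-free equivalent
        assert(a == b);
    }
    return 0;
}
```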
@@ -427,7 +456,6 @@ static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
 /* GGML Backend Buffer Type Interface */
 
 static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
-    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_get_name()");
     ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
     return ctx->device_name.c_str();
 }
@@ -446,14 +474,12 @@ static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_b
 }
 
 static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
-    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_get_alignment()");
     ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
     return ctx->webgpu_ctx->limits.minStorageBufferOffsetAlignment;
 }
 
 // maxBufferSize might be larger, but you can't bind more than maxStorageBufferBindingSize to a single binding.
 static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
-    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_get_max_size()");
     ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
     return ctx->webgpu_ctx->limits.maxStorageBufferBindingSize;
 }
@@ -473,16 +499,13 @@ static const char * ggml_backend_webgpu_device_get_description(ggml_backend_dev_
 }
 
 static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
-    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_get_memory()");
-
     ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
     // TODO: what do we actually want to return here?
     *free = ctx->webgpu_ctx->limits.maxBufferSize * WEBGPU_MAX_BUFFERS;
     *total = ctx->webgpu_ctx->limits.maxBufferSize * WEBGPU_MAX_BUFFERS;
 }
 
 static enum ggml_backend_dev_type ggml_backend_webgpu_device_get_type(ggml_backend_dev_t dev) {
-    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_get_type()");
     GGML_UNUSED(dev);
     return GGML_BACKEND_DEVICE_TYPE_GPU;
 }
@@ -526,11 +549,10 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context webgpu_ctx) {
 
 static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context webgpu_ctx) {
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline, wgsl_mul_mat, "mul_mat");
-    ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_dev_buf,
-                              3 * sizeof(uint32_t), // 3 parameters: M, N, K
+    ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_dev_buf, WEBGPU_MUL_MAT_PARAMS_SIZE,
                               wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst);
-    ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_host_buf,
-                              3 * sizeof(uint32_t), wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc);
+    ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_host_buf, WEBGPU_MUL_MAT_PARAMS_SIZE,
+                              wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc);
 }
 
 // TODO: Does this need to be thread safe? Is it only called once?
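Both parameter buffers now share `WEBGPU_MUL_MAT_PARAMS_SIZE`, keeping the host staging buffer and the device uniform buffer in sync. They follow the standard WebGPU upload pattern, since uniform buffers cannot be mapped directly: fill the mappable host buffer, then copy into the device buffer. Roughly (a sketch using the same wgpu C++ API as the surrounding code):

```cpp
#include <webgpu/webgpu_cpp.h>

// Sketch of the upload path these two buffers enable: fill the MapWrite
// host buffer, then copy it into the Uniform|CopyDst device buffer.
static void upload_params(wgpu::Device device, wgpu::Queue queue,
                          wgpu::Buffer host_buf, wgpu::Buffer dev_buf, uint64_t size) {
    wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
    encoder.CopyBufferToBuffer(host_buf, 0, dev_buf, 0, size);
    wgpu::CommandBuffer commands = encoder.Finish();
    queue.Submit(1, &commands);
}
```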
@@ -617,13 +639,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
     // what should we support first?
     switch (op->op) {
         case GGML_OP_NONE:
-        case GGML_OP_RESHAPE:
-        case GGML_OP_VIEW:
-        case GGML_OP_PERMUTE:
-        case GGML_OP_TRANSPOSE:
-        case GGML_OP_MUL_MAT:
             return true;
-
+        case GGML_OP_MUL_MAT:
+            return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
         default:
             return false;
     }
@@ -652,13 +670,11 @@ static struct ggml_backend_device_i ggml_backend_webgpu_device_i = {
 /* GGML Backend Registration Interface */
 
 static const char * ggml_backend_webgpu_reg_get_name(ggml_backend_reg_t reg) {
-    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg_get_name()");
     ggml_backend_webgpu_reg_context * ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
     return ctx->name;
 }
 
 static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) {
-    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg_get_device_count()");
     ggml_backend_webgpu_reg_context * ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
     return ctx->device_count;
 }