@@ -50,6 +50,13 @@ static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) {
5050
5151/* Struct definitions */
5252
53+ struct webgpu_pipeline_info {
54+ std::string name;
55+ const char * shader_code;
56+ ggml_type src0_type;
57+ ggml_type src1_type;
58+ };
59+
5360// Forward reference
5461static void ggml_webgpu_create_buffer (wgpu::Device & device,
5562 wgpu::Buffer & buffer,
@@ -124,7 +131,8 @@ struct webgpu_context_struct {
124131 webgpu_buf_pool set_rows_error_buf_pool;
125132
126133 wgpu::ComputePipeline memset_pipeline;
127- wgpu::ComputePipeline mul_mat_pipeline;
134+ // [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16]
135+ wgpu::ComputePipeline mul_mat_pipeline[2 ][2 ];
128136 wgpu::ComputePipeline set_rows_pipeline;
129137 wgpu::ComputePipeline cpy_pipeline;
130138
@@ -227,6 +235,15 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
227235
228236/* * End WebGPU object initializations */
229237
238+ /* * Utility Functions */
239+
240+ size_t ggml_webgpu_binding_size (ggml_tensor * t, size_t misalignment) {
241+ return (ggml_nbytes (t) + misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1 ) &
242+ ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1 );
243+ }
244+
245+ /* * End Utility Functions */
246+
230247/* * WebGPU Actions */
231248
232249// Wait for the queue to finish processing all submitted work
@@ -479,13 +496,11 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor
479496 { .binding = 0 ,
480497 .buffer = ggml_backend_webgpu_tensor_buf (src),
481498 .offset = src_offset,
482- .size = (ggml_nbytes (src) + src_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1 ) &
483- ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1 ) },
499+ .size = ggml_webgpu_binding_size (src, src_misalignment) },
484500 { .binding = 1 ,
485501 .buffer = ggml_backend_webgpu_tensor_buf (dst),
486502 .offset = dst_offset,
487- .size = (ggml_nbytes (dst) + dst_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1 ) &
488- ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1 ) }
503+ .size = ggml_webgpu_binding_size (dst, dst_misalignment) }
489504 };
490505
491506 size_t max_wg_size = ctx->limits .maxComputeWorkgroupSizeX ;
@@ -542,15 +557,15 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
542557 { .binding = 0 ,
543558 .buffer = ggml_backend_webgpu_tensor_buf (src),
544559 .offset = ggml_backend_webgpu_tensor_offset (src),
545- .size = ggml_nbytes (src) },
560+ .size = ggml_webgpu_binding_size (src, src_misalignment) },
546561 { .binding = 1 ,
547562 .buffer = ggml_backend_webgpu_tensor_buf (idx),
548563 .offset = ggml_backend_webgpu_tensor_offset (idx),
549- .size = ggml_nbytes (idx) },
564+ .size = ggml_webgpu_binding_size (idx, idx_misalignment) },
550565 { .binding = 2 ,
551566 .buffer = ggml_backend_webgpu_tensor_buf (dst),
552567 .offset = ggml_backend_webgpu_tensor_offset (dst),
553- .size = ggml_nbytes (dst) },
568+ .size = ggml_webgpu_binding_size (dst, dst_misalignment) },
554569 { .binding = 3 , .buffer = error_bufs.dev_buf , .offset = 0 , .size = error_bufs.dev_buf .GetSize () }
555570 };
556571
@@ -564,7 +579,21 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
564579}
565580
566581static void ggml_webgpu_mul_mat (webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
582+ size_t src0_offset = ggml_backend_webgpu_tensor_offset (src0);
583+ size_t src0_misalignment = src0_offset & (ctx->limits .minStorageBufferOffsetAlignment - 1 );
584+ // align to minimum offset alignment
585+ src0_offset &= ~(ctx->limits .minStorageBufferOffsetAlignment - 1 );
586+ size_t src1_offset = ggml_backend_webgpu_tensor_offset (src1);
587+ size_t src1_misalignment = src1_offset & (ctx->limits .minStorageBufferOffsetAlignment - 1 );
588+ src1_offset &= ~(ctx->limits .minStorageBufferOffsetAlignment - 1 );
589+ size_t dst_offset = ggml_backend_webgpu_tensor_offset (dst);
590+ size_t dst_misalignment = dst_offset & (ctx->limits .minStorageBufferOffsetAlignment - 1 );
591+ dst_offset &= ~(ctx->limits .minStorageBufferOffsetAlignment - 1 );
592+
567593 std::vector<uint32_t > params = {
594+ (uint32_t ) (src0_misalignment / ggml_type_size (src0->type )),
595+ (uint32_t ) (src1_misalignment / ggml_type_size (src1->type )),
596+ (uint32_t ) (dst_misalignment / ggml_type_size (dst->type )),
568597 (uint32_t ) dst->ne [1 ], // number of rows in result (M)
569598 (uint32_t ) dst->ne [0 ], // number of columns in result (N)
570599 (uint32_t ) src0->ne [0 ], // number of columns in src0/src1 (K)
@@ -584,20 +613,20 @@ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_t
584613 { .binding = 0 ,
585614 .buffer = ggml_backend_webgpu_tensor_buf (src0),
586615 .offset = ggml_backend_webgpu_tensor_offset (src0),
587- .size = ggml_nbytes (src0) },
616+ .size = ggml_webgpu_binding_size (src0, src0_misalignment ) },
588617 { .binding = 1 ,
589618 .buffer = ggml_backend_webgpu_tensor_buf (src1),
590619 .offset = ggml_backend_webgpu_tensor_offset (src1),
591- .size = ggml_nbytes (src1) },
620+ .size = ggml_webgpu_binding_size (src1, src1_misalignment ) },
592621 { .binding = 2 ,
593622 .buffer = ggml_backend_webgpu_tensor_buf (dst),
594623 .offset = ggml_backend_webgpu_tensor_offset (dst),
595- .size = ggml_nbytes (dst) }
624+ .size = ggml_webgpu_binding_size (dst, dst_misalignment) }
596625 };
597626
598627 uint32_t wg_x =
599628 (dst->ne [0 ] * dst->ne [1 ] * dst->ne [2 ] * dst->ne [3 ] + WEBGPU_MUL_MAT_WG_SIZE - 1 ) / WEBGPU_MUL_MAT_WG_SIZE;
600- ggml_backend_webgpu_build_and_enqueue (ctx, ctx->mul_mat_pipeline , params, entries, wg_x);
629+ ggml_backend_webgpu_build_and_enqueue (ctx, ctx->mul_mat_pipeline [src0-> type ][src1-> type ] , params, entries, wg_x);
601630}
602631
603632// Returns true if node has enqueued work into the queue, false otherwise
@@ -907,7 +936,31 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
907936}
908937
909938static void ggml_webgpu_init_mul_mat_pipeline (webgpu_context & webgpu_ctx) {
910- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->mul_mat_pipeline , wgsl_mul_mat, " mul_mat" );
939+ webgpu_pipeline_info pipeline_infos[4 ] = {
940+ { .name = " mul_mat_f32_f32" ,
941+ .shader_code = wgsl_mul_mat_f32_f32,
942+ .src0_type = GGML_TYPE_F32,
943+ .src1_type = GGML_TYPE_F32 },
944+ { .name = " mul_mat_f16_f16" ,
945+ .shader_code = wgsl_mul_mat_f16_f16,
946+ .src0_type = GGML_TYPE_F16,
947+ .src1_type = GGML_TYPE_F16 },
948+ { .name = " mul_mat_f32_f16" ,
949+ .shader_code = wgsl_mul_mat_f32_f16,
950+ .src0_type = GGML_TYPE_F32,
951+ .src1_type = GGML_TYPE_F16 },
952+ { .name = " mul_mat_f16_f32" ,
953+ .shader_code = wgsl_mul_mat_f16_f32,
954+ .src0_type = GGML_TYPE_F16,
955+ .src1_type = GGML_TYPE_F32 }
956+ };
957+
958+ for (auto & pipeline_info : pipeline_infos) {
959+ ggml_webgpu_create_pipeline (webgpu_ctx->device ,
960+ webgpu_ctx->mul_mat_pipeline [pipeline_info.src0_type ][pipeline_info.src1_type ],
961+ pipeline_info.shader_code ,
962+ pipeline_info.name .data ());
963+ }
911964}
912965
913966static void ggml_webgpu_init_set_rows_pipeline (webgpu_context & webgpu_ctx) {
@@ -1056,7 +1109,8 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
10561109 case GGML_OP_CPY | GGML_OP_SET_ROWS:
10571110 return op->type == GGML_TYPE_F16 && op->src [0 ]->type == GGML_TYPE_F32;
10581111 case GGML_OP_MUL_MAT:
1059- return op->src [0 ]->type == GGML_TYPE_F32 && op->src [1 ]->type == GGML_TYPE_F32;
1112+ return (op->src [0 ]->type == GGML_TYPE_F32 || op->src [0 ]->type == GGML_TYPE_F16) &&
1113+ (op->src [1 ]->type == GGML_TYPE_F32 || op->src [1 ]->type == GGML_TYPE_F16);
10601114 default :
10611115 return false ;
10621116 }
0 commit comments