@@ -50,13 +50,6 @@ static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) {
5050
5151/* Struct definitions */
5252
53- struct webgpu_pipeline_info {
54- std::string name;
55- const char * shader_code;
56- ggml_type src0_type;
57- ggml_type src1_type;
58- };
59-
6053// Forward reference
6154static void ggml_webgpu_create_buffer (wgpu::Device & device,
6255 wgpu::Buffer & buffer,
@@ -571,12 +564,12 @@ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_t
571564 (uint32_t ) dst->ne [1 ], // number of rows in result (M)
572565 (uint32_t ) dst->ne [0 ], // number of columns in result (N)
573566 (uint32_t ) src0->ne [0 ], // number of columns in src0/src1 (K)
574- (uint32_t ) (src0->nb [1 ] / ggml_type_size (src0->type )), // stride (elements) of src0 in dimension 1
575- (uint32_t ) (src1->nb [1 ] / ggml_type_size (src1->type )), // stride (elements) of src1 in dimension 1
576- (uint32_t ) (src0->nb [2 ] / ggml_type_size (src0->type )), // stride (elements) of src0 in dimension 2
577- (uint32_t ) (src1->nb [2 ] / ggml_type_size (src1->type )), // stride (elements) of src1 in dimension 2
578- (uint32_t ) (src0->nb [3 ] / ggml_type_size (src0->type )), // stride (elements) of src0 in dimension 3
579- (uint32_t ) (src1->nb [3 ] / ggml_type_size (src1->type )), // stride (elements) of src1 in dimension 3
567+ (uint32_t ) (src0->nb [1 ] / ggml_type_size (src0->type )), // stride (elements/blocks ) of src0 in dimension 1
568+ (uint32_t ) (src1->nb [1 ] / ggml_type_size (src1->type )), // stride (elements/blocks ) of src1 in dimension 1
569+ (uint32_t ) (src0->nb [2 ] / ggml_type_size (src0->type )), // stride (elements/blocks ) of src0 in dimension 2
570+ (uint32_t ) (src1->nb [2 ] / ggml_type_size (src1->type )), // stride (elements/blocks ) of src1 in dimension 2
571+ (uint32_t ) (src0->nb [3 ] / ggml_type_size (src0->type )), // stride (elements/blocks ) of src0 in dimension 3
572+ (uint32_t ) (src1->nb [3 ] / ggml_type_size (src1->type )), // stride (elements/blocks ) of src1 in dimension 3
580573 (uint32_t ) src0->ne [2 ], // batch size in dimension 2
581574 (uint32_t ) src0->ne [3 ], // batch size in dimension 3
582575 (uint32_t ) (src1->ne [2 ] / src0->ne [2 ]), // broadcast in dimension 2
@@ -596,16 +589,11 @@ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_t
596589 .buffer = ggml_webgpu_tensor_buf (dst),
597590 .offset = ggml_webgpu_tensor_align_offset (ctx, dst),
598591 .size = ggml_webgpu_tensor_binding_size (ctx, dst) },
599- // { .binding = 3,
600- // .buffer = ctx->debug_dev_buf,
601- // .offset = 0,
602- // .size = ctx->debug_dev_buf.GetSize() }
603592 };
604593
605594 uint32_t wg_x =
606595 (dst->ne [0 ] * dst->ne [1 ] * dst->ne [2 ] * dst->ne [3 ] + WEBGPU_MUL_MAT_WG_SIZE - 1 ) / WEBGPU_MUL_MAT_WG_SIZE;
607596 ggml_backend_webgpu_build_and_enqueue (ctx, ctx->mul_mat_pipeline [src0->type ][src1->type ], params, entries, wg_x);
608- // ggml_backend_webgpu_debug(ctx);
609597}
610598
611599// Returns true if node has enqueued work into the queue, false otherwise
@@ -915,103 +903,94 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
915903}
916904
917905static void ggml_webgpu_init_mul_mat_pipeline (webgpu_context & webgpu_ctx) {
918- webgpu_pipeline_info pipeline_infos[22 ] = {
919- { .name = " mul_mat_f32_f32" ,
920- .shader_code = wgsl_mul_mat_f32_f32,
921- .src0_type = GGML_TYPE_F32,
922- .src1_type = GGML_TYPE_F32 },
923- { .name = " mul_mat_f16_f16" ,
924- .shader_code = wgsl_mul_mat_f16_f16,
925- .src0_type = GGML_TYPE_F16,
926- .src1_type = GGML_TYPE_F16 },
927- { .name = " mul_mat_f16_f32" ,
928- .shader_code = wgsl_mul_mat_f16_f32,
929- .src0_type = GGML_TYPE_F16,
930- .src1_type = GGML_TYPE_F32 },
931- { .name = " mul_mat_q4_0_f32" ,
932- .shader_code = wgsl_mul_mat_q4_0_f32,
933- .src0_type = GGML_TYPE_Q4_0,
934- .src1_type = GGML_TYPE_F32 },
935- { .name = " mul_mat_q4_1_f32" ,
936- .shader_code = wgsl_mul_mat_q4_1_f32,
937- .src0_type = GGML_TYPE_Q4_1,
938- .src1_type = GGML_TYPE_F32 },
939- { .name = " mul_mat_q5_0_f32" ,
940- .shader_code = wgsl_mul_mat_q5_0_f32,
941- .src0_type = GGML_TYPE_Q5_0,
942- .src1_type = GGML_TYPE_F32 },
943- { .name = " mul_mat_q5_1_f32" ,
944- .shader_code = wgsl_mul_mat_q5_1_f32,
945- .src0_type = GGML_TYPE_Q5_1,
946- .src1_type = GGML_TYPE_F32 },
947- { .name = " mul_mat_q8_0_f32" ,
948- .shader_code = wgsl_mul_mat_q8_0_f32,
949- .src0_type = GGML_TYPE_Q8_0,
950- .src1_type = GGML_TYPE_F32 },
951- { .name = " mul_mat_q2_k_f32" ,
952- .shader_code = wgsl_mul_mat_q2_k_f32,
953- .src0_type = GGML_TYPE_Q2_K,
954- .src1_type = GGML_TYPE_F32 },
955- { .name = " mul_mat_q3_k_f32" ,
956- .shader_code = wgsl_mul_mat_q3_k_f32,
957- .src0_type = GGML_TYPE_Q3_K,
958- .src1_type = GGML_TYPE_F32 },
959- { .name = " mul_mat_q4_k_f32" ,
960- .shader_code = wgsl_mul_mat_q4_k_f32,
961- .src0_type = GGML_TYPE_Q4_K,
962- .src1_type = GGML_TYPE_F32 },
963- { .name = " mul_mat_q5_k_f32" ,
964- .shader_code = wgsl_mul_mat_q5_k_f32,
965- .src0_type = GGML_TYPE_Q5_K,
966- .src1_type = GGML_TYPE_F32 },
967- { .name = " mul_mat_q6_k_f32" ,
968- .shader_code = wgsl_mul_mat_q6_k_f32,
969- .src0_type = GGML_TYPE_Q6_K,
970- .src1_type = GGML_TYPE_F32 },
971- { .name = " mul_mat_iq2_xxs_f32" ,
972- .shader_code = wgsl_mul_mat_iq2_xxs_f32,
973- .src0_type = GGML_TYPE_IQ2_XXS,
974- .src1_type = GGML_TYPE_F32 },
975- { .name = " mul_mat_iq2_xs_f32" ,
976- .shader_code = wgsl_mul_mat_iq2_xs_f32,
977- .src0_type = GGML_TYPE_IQ2_XS,
978- .src1_type = GGML_TYPE_F32 },
979- { .name = " mul_mat_iq2_s_f32" ,
980- .shader_code = wgsl_mul_mat_iq2_s_f32,
981- .src0_type = GGML_TYPE_IQ2_S,
982- .src1_type = GGML_TYPE_F32 },
983- { .name = " mul_mat_iq3_xxs_f32" ,
984- .shader_code = wgsl_mul_mat_iq3_xxs_f32,
985- .src0_type = GGML_TYPE_IQ3_XXS,
986- .src1_type = GGML_TYPE_F32 },
987- { .name = " mul_mat_iq3_s_f32" ,
988- .shader_code = wgsl_mul_mat_iq3_s_f32,
989- .src0_type = GGML_TYPE_IQ3_S,
990- .src1_type = GGML_TYPE_F32 },
991- { .name = " mul_mat_iq1_s_f32" ,
992- .shader_code = wgsl_mul_mat_iq1_s_f32,
993- .src0_type = GGML_TYPE_IQ1_S,
994- .src1_type = GGML_TYPE_F32 },
995- { .name = " mul_mat_iq1_m_f32" ,
996- .shader_code = wgsl_mul_mat_iq1_m_f32,
997- .src0_type = GGML_TYPE_IQ1_M,
998- .src1_type = GGML_TYPE_F32 },
999- { .name = " mul_mat_iq4_nl_f32" ,
1000- .shader_code = wgsl_mul_mat_iq4_nl_f32,
1001- .src0_type = GGML_TYPE_IQ4_NL,
1002- .src1_type = GGML_TYPE_F32 },
1003- { .name = " mul_mat_iq4_xs_f32" ,
1004- .shader_code = wgsl_mul_mat_iq4_xs_f32,
1005- .src0_type = GGML_TYPE_IQ4_XS,
1006- .src1_type = GGML_TYPE_F32 }
1007- };
1008-
1009- for (auto & pipeline_info : pipeline_infos) {
1010- ggml_webgpu_create_pipeline (webgpu_ctx->device ,
1011- webgpu_ctx->mul_mat_pipeline [pipeline_info.src0_type ][pipeline_info.src1_type ],
1012- pipeline_info.shader_code ,
1013- pipeline_info.name .data ());
1014- }
906+ ggml_webgpu_create_pipeline (webgpu_ctx->device ,
907+ webgpu_ctx->mul_mat_pipeline [GGML_TYPE_F32][GGML_TYPE_F32],
908+ wgsl_mul_mat_f32_f32,
909+ " mul_mat_f32_f32" );
910+ ggml_webgpu_create_pipeline (webgpu_ctx->device ,
911+ webgpu_ctx->mul_mat_pipeline [GGML_TYPE_F16][GGML_TYPE_F16],
912+ wgsl_mul_mat_f16_f16,
913+ " mul_mat_f16_f16" );
914+ ggml_webgpu_create_pipeline (webgpu_ctx->device ,
915+ webgpu_ctx->mul_mat_pipeline [GGML_TYPE_F16][GGML_TYPE_F32],
916+ wgsl_mul_mat_f16_f32,
917+ " mul_mat_f16_f32" );
918+ ggml_webgpu_create_pipeline (webgpu_ctx->device ,
919+ webgpu_ctx->mul_mat_pipeline [GGML_TYPE_Q4_0][GGML_TYPE_F32],
920+ wgsl_mul_mat_q4_0_f32,
921+ " mul_mat_q4_0_f32" );
922+ ggml_webgpu_create_pipeline (webgpu_ctx->device ,
923+ webgpu_ctx->mul_mat_pipeline [GGML_TYPE_Q4_1][GGML_TYPE_F32],
924+ wgsl_mul_mat_q4_1_f32,
925+ " mul_mat_q4_1_f32" );
926+ ggml_webgpu_create_pipeline (webgpu_ctx->device ,
927+ webgpu_ctx->mul_mat_pipeline [GGML_TYPE_Q5_0][GGML_TYPE_F32],
928+ wgsl_mul_mat_q5_0_f32,
929+ " mul_mat_q5_0_f32" );
930+ ggml_webgpu_create_pipeline (webgpu_ctx->device ,
931+ webgpu_ctx->mul_mat_pipeline [GGML_TYPE_Q5_1][GGML_TYPE_F32],
932+ wgsl_mul_mat_q5_1_f32,
933+ " mul_mat_q5_1_f32" );
934+ ggml_webgpu_create_pipeline (webgpu_ctx->device ,
935+ webgpu_ctx->mul_mat_pipeline [GGML_TYPE_Q8_0][GGML_TYPE_F32],
936+ wgsl_mul_mat_q8_0_f32,
937+ " mul_mat_q8_0_f32" );
938+ ggml_webgpu_create_pipeline (webgpu_ctx->device ,
939+ webgpu_ctx->mul_mat_pipeline [GGML_TYPE_Q2_K][GGML_TYPE_F32],
940+ wgsl_mul_mat_q2_k_f32,
941+ " mul_mat_q2_k_f32" );
942+ ggml_webgpu_create_pipeline (webgpu_ctx->device ,
943+ webgpu_ctx->mul_mat_pipeline [GGML_TYPE_Q3_K][GGML_TYPE_F32],
944+ wgsl_mul_mat_q3_k_f32,
945+ " mul_mat_q3_k_f32" );
946+ ggml_webgpu_create_pipeline (webgpu_ctx->device ,
947+ webgpu_ctx->mul_mat_pipeline [GGML_TYPE_Q4_K][GGML_TYPE_F32],
948+ wgsl_mul_mat_q4_k_f32,
949+ " mul_mat_q4_k_f32" );
950+ ggml_webgpu_create_pipeline (webgpu_ctx->device ,
951+ webgpu_ctx->mul_mat_pipeline [GGML_TYPE_Q5_K][GGML_TYPE_F32],
952+ wgsl_mul_mat_q5_k_f32,
953+ " mul_mat_q5_k_f32" );
954+ ggml_webgpu_create_pipeline (webgpu_ctx->device ,
955+ webgpu_ctx->mul_mat_pipeline [GGML_TYPE_Q6_K][GGML_TYPE_F32],
956+ wgsl_mul_mat_q6_k_f32,
957+ " mul_mat_q6_k_f32" );
958+ ggml_webgpu_create_pipeline (webgpu_ctx->device ,
959+ webgpu_ctx->mul_mat_pipeline [GGML_TYPE_IQ2_XXS][GGML_TYPE_F32],
960+ wgsl_mul_mat_iq2_xxs_f32,
961+ " mul_mat_iq2_xxs_f32" );
962+ ggml_webgpu_create_pipeline (webgpu_ctx->device ,
963+ webgpu_ctx->mul_mat_pipeline [GGML_TYPE_IQ2_XS][GGML_TYPE_F32],
964+ wgsl_mul_mat_iq2_xs_f32,
965+ " mul_mat_iq2_xs_f32" );
966+ ggml_webgpu_create_pipeline (webgpu_ctx->device ,
967+ webgpu_ctx->mul_mat_pipeline [GGML_TYPE_IQ2_S][GGML_TYPE_F32],
968+ wgsl_mul_mat_iq2_s_f32,
969+ " mul_mat_iq2_s_f32" );
970+ ggml_webgpu_create_pipeline (webgpu_ctx->device ,
971+ webgpu_ctx->mul_mat_pipeline [GGML_TYPE_IQ3_XXS][GGML_TYPE_F32],
972+ wgsl_mul_mat_iq3_xxs_f32,
973+ " mul_mat_iq3_xxs_f32" );
974+ ggml_webgpu_create_pipeline (webgpu_ctx->device ,
975+ webgpu_ctx->mul_mat_pipeline [GGML_TYPE_IQ3_S][GGML_TYPE_F32],
976+ wgsl_mul_mat_iq3_s_f32,
977+ " mul_mat_iq3_s_f32" );
978+ ggml_webgpu_create_pipeline (webgpu_ctx->device ,
979+ webgpu_ctx->mul_mat_pipeline [GGML_TYPE_IQ1_S][GGML_TYPE_F32],
980+ wgsl_mul_mat_iq1_s_f32,
981+ " mul_mat_iq1_s_f32" );
982+ ggml_webgpu_create_pipeline (webgpu_ctx->device ,
983+ webgpu_ctx->mul_mat_pipeline [GGML_TYPE_IQ1_M][GGML_TYPE_F32],
984+ wgsl_mul_mat_iq1_m_f32,
985+ " mul_mat_iq1_m_f32" );
986+ ggml_webgpu_create_pipeline (webgpu_ctx->device ,
987+ webgpu_ctx->mul_mat_pipeline [GGML_TYPE_IQ4_NL][GGML_TYPE_F32],
988+ wgsl_mul_mat_iq4_nl_f32,
989+ " mul_mat_iq4_nl_f32" );
990+ ggml_webgpu_create_pipeline (webgpu_ctx->device ,
991+ webgpu_ctx->mul_mat_pipeline [GGML_TYPE_IQ4_XS][GGML_TYPE_F32],
992+ wgsl_mul_mat_iq4_xs_f32,
993+ " mul_mat_iq4_xs_f32" );
1015994}
1016995
1017996static void ggml_webgpu_init_set_rows_pipeline (webgpu_context & webgpu_ctx) {
0 commit comments