@@ -216,6 +216,15 @@ struct vk_device_struct {
216216
217217 vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
218218 vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
219+
220+ vk_pipeline pipeline_emb_mul_mat_mat_f32;
221+ vk_pipeline pipeline_emb_mul_mat_f16;
222+ vk_pipeline pipeline_emb_mul_mat_q4_0;
223+ vk_pipeline pipeline_emb_mul_mat_q4_1;
224+ vk_pipeline pipeline_emb_mul_mat_q4_k;
225+ vk_pipeline pipeline_emb_mul_mat_q6_k;
226+ vk_pipeline pipeline_emb_mul_mat_q8_0;
227+
219228 vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
220229 vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
221230 vk_pipeline pipeline_acc_f32;
@@ -1965,6 +1974,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
19651974 ggml_vk_create_pipeline (device, device->pipeline_mul_mat_vec_p021_f16_f32 , " mul_mat_vec_p021_f16_f32" , mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, " main" , 3 , 6 * sizeof (uint32_t ), {1 , 1 , 1 }, {}, 1 );
19661975 ggml_vk_create_pipeline (device, device->pipeline_mul_mat_vec_nc_f16_f32 , " mul_mat_vec_nc_f16_f32" , mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, " main" , 3 , 7 * sizeof (uint32_t ), {1 , 1 , 1 }, {}, 1 );
19671976
1977+ ggml_vk_create_pipeline (device, device->pipeline_emb_mul_mat_mat_f32 , " emb_mul_mat_mat_f32" , emb_mul_mat_mat_f32_len, emb_mul_mat_mat_f32_data, " main" , 3 , 14 * sizeof (uint32_t ), {1 , 1 , 1 }, {device->subgroup_size }, 1 );
1978+ ggml_vk_create_pipeline (device, device->pipeline_emb_mul_mat_f16 , " emb_mul_mat_f16" , emb_mul_mat_f16_len, emb_mul_mat_f16_data, " main" , 3 , 21 * sizeof (uint32_t ), {1 , 1 , 1 }, {device->subgroup_size * 2 }, 1 );
1979+ ggml_vk_create_pipeline (device, device->pipeline_emb_mul_mat_q4_0 , " emb_mul_mat_q4_0" , emb_mul_mat_q4_0_len, emb_mul_mat_q4_0_data, " main" , 3 , 18 * sizeof (uint32_t ), {1 , 1 , 1 }, {(device->subgroup_size * 2 ) / 8 }, 1 );
1980+ ggml_vk_create_pipeline (device, device->pipeline_emb_mul_mat_q4_1 , " emb_mul_mat_q4_1" , emb_mul_mat_q4_1_len, emb_mul_mat_q4_1_data, " main" , 3 , 18 * sizeof (uint32_t ), {1 , 1 , 1 }, {(device->subgroup_size * 2 ) / 8 }, 1 );
1981+ ggml_vk_create_pipeline (device, device->pipeline_emb_mul_mat_q4_k , " emb_mul_mat_q4_k" , emb_mul_mat_q4_k_len, emb_mul_mat_q4_k_data, " main" , 3 , 18 * sizeof (uint32_t ), {1 , 1 , 1 }, {}, 1 );
1982+ ggml_vk_create_pipeline (device, device->pipeline_emb_mul_mat_q6_k , " emb_mul_mat_q6_k" , emb_mul_mat_q6_k_len, emb_mul_mat_q6_k_data, " main" , 3 , 18 * sizeof (uint32_t ), {1 , 1 , 1 }, {2 , device->subgroup_size }, 1 );
1983+ ggml_vk_create_pipeline (device, device->pipeline_emb_mul_mat_q8_0 , " emb_mul_mat_q8_0" , emb_mul_mat_q8_0_len, emb_mul_mat_q8_0_data, " main" , 3 , 18 * sizeof (uint32_t ), {1 , 1 , 1 }, {(device->subgroup_size * 2 ) / 8 }, 1 );
1984+
19681985 ggml_vk_create_pipeline (device, device->pipeline_norm_f32 , " norm_f32" , norm_f32_len, norm_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {1 , 1 , 1 }, {}, 1 );
19691986 ggml_vk_create_pipeline (device, device->pipeline_group_norm_f32 , " group_norm_f32" , group_norm_f32_len, group_norm_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {1 , 1 , 1 }, {}, 1 );
19701987 ggml_vk_create_pipeline (device, device->pipeline_rms_norm_f32 , " rms_norm_f32" , rms_norm_f32_len, rms_norm_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {1 , 1 , 1 }, {}, 1 );
@@ -4371,6 +4388,167 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
43714388 }
43724389}
43734390
static void ggml_vkemb_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
    // MUL_MAT implementation for embedded devices (ctx->device->embedded): picks one of
    // the dedicated emb_mul_mat_* compute pipelines based on the quantization type of
    // src0 and records a single dispatch into subctx.
    // src1 and dst are expected to be F32 (gated by ggml_backend_vkemb_device_supports_op;
    // d_sz below also assumes an F32 destination).
    // When dryrun is true, only the descriptor-set requirement for the selected pipeline
    // is registered and no work is recorded.
    // NOTE: the std::cerr statements below are part of the VK_LOG_DEBUG macro argument
    // (the parentheses stay balanced across the lines), so they compile out together
    // with the macro when debug logging is disabled.
    VK_LOG_DEBUG("ggml_vkemb_mul_mat(" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
    std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
    std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");

    // Element counts (ne*) and byte strides (nb*) of the three operands,
    // following the usual ggml naming: neXY = src X, dimension Y.
    const uint64_t ne00 = src0->ne[0];
    const uint64_t ne01 = src0->ne[1];
    const uint64_t ne02 = src0->ne[2];
    const uint64_t ne03 = src0->ne[3];

    const uint64_t ne10 = src1->ne[0];
    const uint64_t ne11 = src1->ne[1];
    const uint64_t ne12 = src1->ne[2];
    const uint64_t ne13 = src1->ne[3];

    const uint64_t ne0 = dst->ne[0];
    const uint64_t ne1 = dst->ne[1];
    const uint64_t ne2 = dst->ne[2];
    const uint64_t ne3 = dst->ne[3];

    const uint64_t nb00 = src0->nb[0];
    const uint64_t nb01 = src0->nb[1];
    const uint64_t nb02 = src0->nb[2];
    const uint64_t nb03 = src0->nb[3];

    const uint64_t nb10 = src1->nb[0];
    const uint64_t nb11 = src1->nb[1];
    const uint64_t nb12 = src1->nb[2];
    const uint64_t nb13 = src1->nb[3];

    // Only the row and plane strides of dst are passed to the shaders.
    const uint64_t nb1 = dst->nb[1];
    const uint64_t nb2 = dst->nb[2];

    // Batch broadcast ratios: how many src1 planes map onto each src0 plane
    // (ggml broadcast semantics for dims 2 and 3).
    const uint64_t r2 = ne12 / ne02;
    const uint64_t r3 = ne13 / ne03;

    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
    ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
    ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;

    vk_buffer d_Qx = nullptr;
    size_t qx_buf_offset = 0;
    vk_buffer d_Qy = nullptr;
    size_t qy_buf_offset = 0;

    bool src0_uma = false;
    bool src1_uma = false;

    // On unified-memory devices the sources may already be reachable through a
    // host-visible mapping; prefer that over the device-side buffers when found.
    if (ctx->device->uma) {
        ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
        ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
        src0_uma = d_Qx != nullptr;
        src1_uma = d_Qy != nullptr;
    }

    // Total element counts per tensor and the byte size of each storage binding.
    const uint64_t x_ne = ne00 * ne01 * ne02 * ne03;
    const uint64_t y_ne = ne10 * ne11 * ne12 * ne13;
    const uint64_t d_ne = ne0 * ne1 * ne2 * ne3;

    // src0 size is rounded up to the storage-buffer offset alignment; the binding
    // ranges below are extended by the shader offsets computed after the dryrun check.
    const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
    const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
    const uint64_t d_sz = sizeof(float) * d_ne;

    // Select the pipeline matching src0's (quantization) type. Must stay in sync
    // with the set of pipelines created in ggml_vk_load_shaders and with the
    // supported-op check in ggml_backend_vkemb_device_supports_op.
    vk_pipeline pipeline;
    switch (src0->type) {
        case GGML_TYPE_F32:
            pipeline = ctx->device->pipeline_emb_mul_mat_mat_f32;
            break;
        case GGML_TYPE_F16:
            pipeline = ctx->device->pipeline_emb_mul_mat_f16;
            break;
        case GGML_TYPE_Q4_0:
            pipeline = ctx->device->pipeline_emb_mul_mat_q4_0;
            break;
        case GGML_TYPE_Q4_1:
            pipeline = ctx->device->pipeline_emb_mul_mat_q4_1;
            break;
        case GGML_TYPE_Q4_K:
            pipeline = ctx->device->pipeline_emb_mul_mat_q4_k;
            break;
        case GGML_TYPE_Q6_K:
            pipeline = ctx->device->pipeline_emb_mul_mat_q6_k;
            break;
        case GGML_TYPE_Q8_0:
            pipeline = ctx->device->pipeline_emb_mul_mat_q8_0;
            break;
        default:
            GGML_ABORT("vkemb_mul_mat: unsupported quantization type: %d", src0->type);
    }

    if (dryrun) {
        // Request descriptor sets
        ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
        return;
    }

    // Resolve device buffers and byte offsets for all three tensors (the UMA path
    // above may already have filled in d_Qx / d_Qy).
    vk_buffer d_D = dst_buf_ctx->dev_buffer;
    const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
    GGML_ASSERT(d_D != nullptr);
    GGML_ASSERT(d_D->size >= d_buf_offset + d_sz);
    if (!src0_uma) {
        d_Qx = src0_buf_ctx->dev_buffer;
        qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
        GGML_ASSERT(d_Qx != nullptr);
    }
    if (!src1_uma) {
        d_Qy = src1_buf_ctx->dev_buffer;
        qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
        GGML_ASSERT(d_Qy != nullptr);
    }

    // Split each tensor offset into an aligned buffer binding offset (a multiple of
    // minStorageBufferOffsetAlignment) plus a remainder (*_shader_offset) that is
    // handed to the shader through the push constants.
    const uint64_t qx_buffer_offset = (qx_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
    const uint64_t qx_shader_offset = qx_buf_offset - qx_buffer_offset;

    const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
    const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset;

    const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
    const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;

    // compute
    ggml_vk_sync_buffers(subctx);
    // The push-constant layout and the workgroup grid are type-specific and must match
    // the corresponding emb_mul_mat_* shader and the push-constant sizes declared at
    // pipeline creation (14 u32s for F32, 21 for F16, 18 for the quantized types).
    // The qy/d shader offsets are converted from bytes to elements; qx stays in bytes.
    // The per-type row groupings ((ne01+7)/8, (ne01+3)/4, (ne01+1)/2) presumably match
    // each shader's rows-per-workgroup tiling — confirm against the shader sources.
    switch (src0->type) {
        case GGML_TYPE_F32:
        {
            const std::array<uint32_t, 14> pc = { (uint32_t)qx_shader_offset, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne11, (uint32_t)ne12, (uint32_t)nb01, (uint32_t)nb02, (uint32_t)nb11, (uint32_t)nb12, (uint32_t)nb1, (uint32_t)nb2 };
            ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Qx, qx_buffer_offset, qx_sz + qx_shader_offset }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 14 * sizeof(uint32_t), &pc, { (uint32_t)ne01, (uint32_t)ne11, (uint32_t)std::max(ne12, ne02) });
            break;
        }
        case GGML_TYPE_F16:
        {
            const std::array<uint32_t, 21> pc = { (uint32_t)qx_shader_offset, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)nb00, (uint32_t)nb01, (uint32_t)nb02, (uint32_t)nb03, (uint32_t)ne10, (uint32_t)ne11, (uint32_t)ne12, (uint32_t)nb10, (uint32_t)nb11, (uint32_t)nb12, (uint32_t)nb13, (uint32_t)ne0, (uint32_t)ne1, (uint32_t)r2, (uint32_t)r3 };
            ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Qx, qx_buffer_offset, qx_sz + qx_shader_offset }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 21 * sizeof(uint32_t), &pc, { (uint32_t)ne01, (uint32_t)((ne11 + 3) / 4), (uint32_t)(ne12 * ne13) });
            break;
        }
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q8_0:
        {
            const std::array<uint32_t, 18> pc = { (uint32_t)qx_shader_offset, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne10, (uint32_t)ne12, (uint32_t)ne0, (uint32_t)ne1, (uint32_t)nb01, (uint32_t)nb02, (uint32_t)nb03, (uint32_t)nb11, (uint32_t)nb12, (uint32_t)nb13, (uint32_t)r2, (uint32_t)r3 };
            ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Qx, qx_buffer_offset, qx_sz + qx_shader_offset }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 18 * sizeof(uint32_t), &pc, { (uint32_t)((ne01 + 7) / 8), (uint32_t)ne11, (uint32_t)(ne12 * ne13) });
            break;
        }
        case GGML_TYPE_Q4_K:
        {
            // NOTE: the Q4_K/Q6_K push-constant order differs from the 4-bit/8-bit
            // block-quant layout above (ne10/ne0/ne1 come before ne01/ne02/ne12).
            const std::array<uint32_t, 18> pc = { (uint32_t)qx_shader_offset, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne0, (uint32_t)ne1, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)nb01, (uint32_t)nb02, (uint32_t)nb03, (uint32_t)nb11, (uint32_t)nb12, (uint32_t)nb13, (uint32_t)r2, (uint32_t)r3 };
            ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Qx, qx_buffer_offset, qx_sz + qx_shader_offset }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 18 * sizeof(uint32_t), &pc, { (uint32_t)((ne01 + 3) / 4), (uint32_t)ne11, (uint32_t)(ne12 * ne13) });
            break;
        }
        case GGML_TYPE_Q6_K:
        {
            const std::array<uint32_t, 18> pc = { (uint32_t)qx_shader_offset, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne0, (uint32_t)ne1, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)nb01, (uint32_t)nb02, (uint32_t)nb03, (uint32_t)nb11, (uint32_t)nb12, (uint32_t)nb13, (uint32_t)r2, (uint32_t)r3 };
            ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Qx, qx_buffer_offset, qx_sz + qx_shader_offset }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 18 * sizeof(uint32_t), &pc, { (uint32_t)((ne01 + 1) / 2), (uint32_t)ne11, (uint32_t)(ne12 * ne13) });
            break;
        }
        default:
            // Unreachable: already validated by the pipeline-selection switch above.
            GGML_ABORT("vkemb_mul_mat: unsupported quantization type: %d", src0->type);
    }
}
4551+
43744552static void ggml_vk_mul_mat_id_q_f16 (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false ) {
43754553 VK_LOG_DEBUG (" ggml_vk_mul_mat_id_q_f16((" << src0 << " , name=" << src0->name << " , type=" << src0->type << " , ne0=" << src0->ne [0 ] << " , ne1=" << src0->ne [1 ] << " , ne2=" << src0->ne [2 ] << " , ne3=" << src0->ne [3 ] << " , nb0=" << src0->nb [0 ] << " , nb1=" << src0->nb [1 ] << " , nb2=" << src0->nb [2 ] << " , nb3=" << src0->nb [3 ];
43764554 std::cerr << " ), (" << src1 << " , name=" << src1->name << " , type=" << src1->type << " , ne0=" << src1->ne [0 ] << " , ne1=" << src1->ne [1 ] << " , ne2=" << src1->ne [2 ] << " , ne3=" << src1->ne [3 ] << " , nb0=" << src1->nb [0 ] << " , nb1=" << src1->nb [1 ] << " , nb2=" << src1->nb [2 ] << " , nb3=" << src1->nb [3 ];
@@ -7086,8 +7264,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
70867264
70877265 break ;
70887266 case GGML_OP_MUL_MAT:
7089- ggml_vk_mul_mat (ctx, compute_ctx, src0, src1, node, dryrun);
7090-
7267+ if (ctx->device ->embedded ) {
7268+ ggml_vkemb_mul_mat (ctx, compute_ctx, src0, src1, node, dryrun);
7269+ } else {
7270+ ggml_vk_mul_mat (ctx, compute_ctx, src0, src1, node, dryrun);
7271+ }
70917272 break ;
70927273 case GGML_OP_MUL_MAT_ID:
70937274 ggml_vk_mul_mat_id (ctx, compute_ctx, src0, src1, src2, node, dryrun);
@@ -8108,7 +8289,33 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
81088289}
81098290
81108291static bool ggml_backend_vkemb_device_supports_op (ggml_backend_dev_t dev, const ggml_tensor * op) {
8111- return ggml_backend_vk_device_supports_op (dev, op);
8292+ switch (op->op ) {
8293+ case GGML_OP_MUL_MAT:
8294+ if (op->src [1 ]->type != GGML_TYPE_F32 || ggml_is_transposed (op->src [0 ]) || ggml_is_transposed (op->src [1 ]))
8295+ return false ;
8296+
8297+ switch (op->src [0 ]->type ) {
8298+ case GGML_TYPE_F32:
8299+ return op->ne [3 ] == 1 ;
8300+ case GGML_TYPE_Q8_0:
8301+ // TODO (slp) - Fix Q8_0 with permutations
8302+ if (ggml_is_permuted (op->src [0 ]) || ggml_is_permuted (op->src [1 ])) {
8303+ return false ;
8304+ }
8305+ case GGML_TYPE_Q6_K:
8306+ case GGML_TYPE_F16:
8307+ case GGML_TYPE_Q4_0:
8308+ case GGML_TYPE_Q4_1:
8309+ case GGML_TYPE_Q4_K:
8310+ return true ;
8311+ default :
8312+ return false ;
8313+ }
8314+ case GGML_OP_MUL_MAT_ID:
8315+ return false ;
8316+ default :
8317+ return ggml_backend_vk_device_supports_op (dev, op);
8318+ }
81128319}
81138320
81148321static bool ggml_backend_vk_device_supports_buft (ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
0 commit comments