@@ -1079,19 +1079,26 @@ static void ggml_vk_mul_mat_q4_k(
10791079 const std::shared_ptr<kp::Tensor>& inB,
10801080 const std::shared_ptr<kp::Tensor>& out,
10811081 uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
1082- int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne10,
1083- int32_t ne11, int32_t ne12, int32_t ne13, int32_t ne0,
1084- int32_t ne1, int32_t r2, int32_t r3
1082+ int32_t ne00, int32_t ne01, int32_t ne02,
1083+ int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
1084+ int32_t ne0, int32_t ne1,
1085+ uint32_t nb01, uint32_t nb02, uint32_t nb03,
1086+ uint32_t nb11, uint32_t nb12, uint32_t nb13,
1087+ uint32_t r2, uint32_t r3
10851088) {
10861089 const static auto spirv = getSpirvShader (kp::shader_data::op_mul_mat_q4_k_comp_spv,
10871090 kp::shader_data::op_mul_mat_q4_k_comp_spv_len);
10881091
10891092 struct PushConstants {
10901093 uint32_t inAOff, inBOff, outOff;
1091- int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3;
1094+ int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12;
1095+ uint32_t nb01, nb02, nb03, nb11, nb12, nb13;
1096+ uint32_t r2, r3;
10921097 } pushConsts {
1093- 0 , 0 , 0 ,
1094- ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3
1098+ inAOff, safe_divide (inBOff, 4 ), safe_divide (outOff, 4 ),
1099+ ne00, ne10, ne0, ne1, ne01, ne02, ne12,
1100+ nb01, nb02, nb03, nb11, nb12, nb13,
1101+ r2, r3
10951102 };
10961103
10971104 std::shared_ptr<kp::Algorithm> s_algo = nullptr ;
@@ -1705,7 +1712,8 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
17051712 case GGML_TYPE_Q4_K:
17061713 ggml_vk_mul_mat_q4_k (
17071714 seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
1708- ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, ne12/ne02, ne13/ne03
1715+ ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
1716+ nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
17091717 );
17101718 break ;
17111719 case GGML_TYPE_Q6_K:
0 commit comments