@@ -1018,26 +1018,32 @@ static void ggml_vk_mul_mat_impl(
10181018 int32_t ne00, int32_t ne01, int32_t ne02,
10191019 int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
10201020 int32_t ne0, int32_t ne1,
1021+ uint32_t nb01, uint32_t nb02, uint32_t nb03,
1022+ uint32_t nb11, uint32_t nb12, uint32_t nb13,
10211023 uint32_t r2, uint32_t r3
10221024) {
10231025 struct PushConstants {
10241026 uint32_t inAOff, inBOff, outOff;
10251027 int32_t ne00, ne01, ne02;
10261028 int32_t ne10, ne12;
10271029 int32_t ne0, ne1;
1030+ uint32_t nb01, nb02, nb03;
1031+ uint32_t nb11, nb12, nb13;
10281032 uint32_t r2, r3;
10291033 } pushConsts {
10301034 safe_divide (inAOff, block_size), safe_divide (inBOff, 4 ), safe_divide (outOff, 4 ),
10311035 ne00, ne01, ne02,
10321036 ne10, ne12,
10331037 ne0, ne1,
1038+ nb01, nb02, nb03,
1039+ nb11, nb12, nb13,
10341040 r2, r3
10351041 };
10361042
10371043 auto name = std::string (__func__) + " _" + suffix;
10381044 std::shared_ptr<kp::Algorithm> s_algo = nullptr ;
10391045 if (!komputeManager ()->hasAlgorithm (name)) {
1040- const uint32_t local_x = ggml_vk_current_device ().subgroupSize * 2 ;
1046+ const uint32_t local_x = ( ggml_vk_current_device ().subgroupSize * 2 ) / 8 ;
10411047 s_algo = komputeManager ()->algorithm <uint32_t , PushConstants>(name, s_kompute_context->pool .get (), {inA, inB, out}, spirv, {unsigned ((ne01 + 7 )/8 ), unsigned (ne11), unsigned (ne12*ne13)}, {local_x}, {pushConsts});
10421048 } else {
10431049 s_algo = komputeManager ()->getAlgorithm (name);
@@ -1694,19 +1700,22 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
16941700 case GGML_TYPE_Q8_0:
16951701 ggml_vk_mul_mat_q8_0 (
16961702 seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
1697- ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
1703+ ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
1704+ nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
16981705 );
16991706 break ;
17001707 case GGML_TYPE_Q4_0:
17011708 ggml_vk_mul_mat_q4_0 (
17021709 seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
1703- ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
1710+ ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
1711+ nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
17041712 );
17051713 break ;
17061714 case GGML_TYPE_Q4_1:
17071715 ggml_vk_mul_mat_q4_1 (
17081716 seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
1709- ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
1717+ ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
1718+ nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
17101719 );
17111720 break ;
17121721 case GGML_TYPE_Q4_K:
0 commit comments