@@ -916,9 +916,9 @@ static void ggml_vk_mul_mat_f16(
916916 const std::shared_ptr<kp::Tensor>& out,
917917 uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
918918 int32_t ne00, int32_t ne01, int32_t ne02,
919- uint32_t nb00, uint32_t nb01, uint32_t nb02,
919+ uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
920920 int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
921- uint32_t nb10, uint32_t nb11, uint32_t nb12,
921+ uint32_t nb10, uint32_t nb11, uint32_t nb12, uint32_t nb13,
922922 int32_t ne0, int32_t ne1,
923923 uint32_t r2, uint32_t r3
924924) {
@@ -928,17 +928,17 @@ static void ggml_vk_mul_mat_f16(
928928 struct PushConstants {
929929 uint32_t inAOff, inBOff, outOff;
930930 int32_t ne00, ne01, ne02;
931- uint32_t nb00, nb01, nb02;
931+ uint32_t nb00, nb01, nb02, nb03 ;
932932 int32_t ne10, ne11, ne12;
933- uint32_t nb10, nb11, nb12;
933+ uint32_t nb10, nb11, nb12, nb13 ;
934934 int32_t ne0, ne1;
935935 uint32_t r2, r3;
936936 } pushConsts {
937937 safe_divide (inAOff, 2 ), safe_divide (inBOff, 4 ), safe_divide (outOff, 4 ),
938938 ne00, ne01, ne02,
939- nb00, nb01, nb02,
939+ nb00, nb01, nb02, nb03,
940940 ne10, ne11, ne12,
941- nb10, nb11, nb12,
941+ nb10, nb11, nb12, nb13,
942942 ne0, ne1,
943943 r2, r3
944944 };
@@ -1693,7 +1693,8 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
16931693 case GGML_TYPE_F16:
16941694 ggml_vk_mul_mat_f16 (
16951695 seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
1696- ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, ne13, nb10, nb11, nb12,
1696+ ne00, ne01, ne02, nb00, nb01, nb02, nb03,
1697+ ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
16971698 ne0, ne1, r2, r3
16981699 );
16991700 break ;
0 commit comments