Skip to content

Commit f90131e

Browse files
ggerganovdmahurin
authored andcommitted
A few updates by just pattern matching with the Q2_K kernel
1 parent f12f803 commit f90131e

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5075,15 +5075,15 @@ void kernel_mul_mv_tq2_0_f32_impl(
50755075
const int im = tgpig.z;
50765076

50775077
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
5078-
const int ib_row = first_row * nb;
50795078

50805079
const uint i12 = im%args.ne12;
50815080
const uint i13 = im/args.ne12;
50825081

5083-
const uint offset0 = (i12/args.r2)*(nb*args.ne01) + (i13/args.r3)*(nb*args.ne01*args.ne02);
5082+
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
5083+
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
50845084

5085-
device const block_tq2_0 * x = (device const block_tq2_0 *) src0 + ib_row + offset0;
5086-
device const float * y = (device const float *) src1 + r1*args.ne10 + im*args.ne00*args.ne1;
5085+
device const block_tq2_0 * x = (device const block_tq2_0 *) (src0 + offset0);
5086+
device const float * y = (device const float *) (src1 + offset1);
50875087

50885088
float yl[32];
50895089
float sumf[N_DST]={0.f}, all_sum;
@@ -5139,7 +5139,7 @@ void kernel_mul_mv_tq2_0_f32_impl(
51395139

51405140
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
51415141

5142-
for (int row = 0; row < N_DST; ++row) {
5142+
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
51435143
all_sum = simd_sum(sumf[row]);
51445144
if (tiisg == 0) {
51455145
dst_f32[first_row + row] = all_sum;

0 commit comments

Comments
 (0)