Skip to content

Commit 51dea76

Browse files
committed
metal : fix nr constant [no ci]
1 parent 982c82f commit 51dea76

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4727,7 +4727,7 @@ void kernel_mul_mv_q4_K_f32_impl(
47274727
float yl[16];
47284728
float yh[16];
47294729

4730-
float sumf[N_R0_Q4_K]={0.f};
4730+
float sumf[nr0]={0.f};
47314731

47324732
device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
47334733

@@ -4737,7 +4737,6 @@ void kernel_mul_mv_q4_K_f32_impl(
47374737
for (int ib = ix; ib < nb; ib += 4) {
47384738
float4 sumy = {0.f, 0.f, 0.f, 0.f};
47394739

4740-
#pragma unroll(8)
47414740
for (short i = 0; i < 8; ++i) {
47424741
yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
47434742
yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
@@ -4749,8 +4748,7 @@ void kernel_mul_mv_q4_K_f32_impl(
47494748
device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
47504749
device const half * dh = &x[ib].d;
47514750

4752-
#pragma unroll(N_R0_Q4_K)
4753-
for (short row = 0; row < N_R0_Q4_K; row++) {
4751+
for (short row = 0; row < nr0; row++) {
47544752
sc16[0] = sc[0] & kmask1;
47554753
sc16[1] = sc[2] & kmask1;
47564754
sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
@@ -4761,7 +4759,6 @@ void kernel_mul_mv_q4_K_f32_impl(
47614759
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
47624760
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
47634761

4764-
#pragma unroll(4)
47654762
for (short i = 0; i < 4; ++i) {
47664763
acc1[0] += yl[2*i + 0] * (q1[i] & 0x000F);
47674764
acc1[1] += yl[2*i + 1] * (q1[i] & 0x0F00);
@@ -4792,7 +4789,7 @@ void kernel_mul_mv_q4_K_f32_impl(
47924789

47934790
device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
47944791

4795-
for (int row = 0; row < N_R0_Q4_K && first_row + row < args.ne0; ++row) {
4792+
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
47964793
float sum_all = simd_sum(sumf[row]);
47974794
if (tiisg == 0) {
47984795
dst_f32[first_row + row] = sum_all;

0 commit comments

Comments
 (0)