Skip to content

Commit 13e017e

Browse files
committed
Fix Merge
1 parent e50c7f9 commit 13e017e

File tree

1 file changed

+63
-2
lines changed

1 file changed

+63
-2
lines changed

ggml/src/ggml-cuda/cpy-utils.cuh

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,27 @@ static __device__ void quantize_f32_q8_0_block(const float * __restrict__ x, blo
161161
}
162162
}
163163

164+
static __device__ const int8_t iq4nl_index[241] = {
165+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
166+
1, 17, 17, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 18, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
167+
3, 3, 3, 3, 3, 3, 19, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 20, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
168+
5, 5, 21, 21, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 22, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 23, 23, 8, 8, 8, 8,
169+
8, 8, 8, 8, 8, 8, 24, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 25, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 26, 26,
170+
11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 27, 27, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 28, 13, 13, 13,
171+
13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 29, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
172+
14, 14, 14, 14, 30, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15
173+
};
174+
static __device__ __forceinline__ int best_index_iq4nl(const int8_t * values, float x) {
175+
int ix = (int)x - values[0];
176+
if (ix < 0 || ix >= 241) return ix < 0 ? 0 : 15;
177+
ix = iq4nl_index[ix];
178+
return ix < 16 ? ix : x - values[ix-16] < values[ix-15] - x ? ix-16 : ix-15;
179+
}
180+
164181
static __device__ void quantize_f32_iq4_nl_block(const float * __restrict__ x, block_iq4_nl * __restrict__ y) {
182+
// const float * xi = (const float *) cxi;
183+
// block_iq4_nl * dsti = (block_iq4_nl *) cdsti;
184+
165185
float amax = 0.0f;
166186
float vmax = 0.0f;
167187

@@ -176,12 +196,14 @@ static __device__ void quantize_f32_iq4_nl_block(const float * __restrict__ x, b
176196
float d = vmax / kvalues_iq4nl[0];
177197
const float id = d ? 1.0f/d : 0.0f;
178198

199+
//dsti->d = d;
200+
179201
float sumqx = 0, sumq2 = 0;
180202
for (int j = 0; j < QK4_NL/2; ++j) {
181203
const float x0 = x[0 + j]*id;
182204
const float x1 = x[QK4_NL/2 + j]*id;
183-
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
184-
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
205+
const uint8_t xi0 = best_index_iq4nl(kvalues_iq4nl, x0);
206+
const uint8_t xi1 = best_index_iq4nl(kvalues_iq4nl, x1);
185207
y->qs[j] = xi0 | (xi1 << 4);
186208
const float v0 = kvalues_iq4nl[xi0];
187209
const float v1 = kvalues_iq4nl[xi1];
@@ -194,6 +216,41 @@ static __device__ void quantize_f32_iq4_nl_block(const float * __restrict__ x, b
194216
y->d = sumq2 > 0 ? sumqx/sumq2 : d;
195217
}
196218

219+
static __device__ void quantize_f32_q6_0_block(const float * __restrict__ x, block_q6_0 * __restrict__ y) {
220+
// const float * xi = (const float *) cxi;
221+
// block_q6_0 * dsti = (block_q6_0 *) cdsti;
222+
223+
float amax = 0.0f;
224+
float vmax = 0.0f;
225+
226+
for (int j = 0; j < QK6_0; ++j) {
227+
const float v = x[j];
228+
const float av = fabsf(x[j]);
229+
if (amax < av) {
230+
amax = av;
231+
vmax = v;
232+
}
233+
}
234+
235+
const float d = vmax / -32;
236+
const float id = d ? 1.0f/d : 0.0f;
237+
238+
y->d = d;
239+
memset(y->qh, 0, QK6_0/4);
240+
241+
for (int j = 0; j < QK6_0/2; ++j) {
242+
const float x0 = x[0 + j]*id;
243+
const float x1 = x[QK4_0/2 + j]*id;
244+
245+
const uint8_t xi0 = min(63, (int8_t)(x0 + 32.5f));
246+
const uint8_t xi1 = min(63, (int8_t)(x1 + 32.5f));
247+
248+
y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
249+
const uint8_t h = (xi0 >> 4) | ((xi1 >> 4) << 2);
250+
y->qh[j%(QK6_0/4)] |= (h << 4*(j/(QK6_0/4)));
251+
}
252+
}
253+
197254
// Wrapper functions for cpy.cu compatibility
198255
static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
199256
quantize_f32_q4_0_block((const float *)cxi, (block_q4_0 *)cdsti);
@@ -211,6 +268,10 @@ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
211268
quantize_f32_q5_1_block((const float *)cxi, (block_q5_1 *)cdsti);
212269
}
213270

271+
static __device__ void cpy_blck_f32_q6_0(const char * cxi, char * cdsti) {
272+
quantize_f32_q6_0_block((const float *)cxi, (block_q6_0 *)cdsti);
273+
}
274+
214275
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
215276
quantize_f32_q8_0_block((const float *)cxi, (block_q8_0 *)cdsti);
216277
}

0 commit comments

Comments
 (0)