@@ -167,6 +167,101 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4
167
167
block_q4_K_packed128 block;
168
168
};
169
169
170
+ #if defined(IS_MUL_MM2)
171
+
172
+ // For Q4_K and Q5_K in the mat-mul shader, we decode a tile's worth of scales
173
+ // into shared memory and then process the whole tile using those scales.
174
+ // There is a fetch function that loads into private variables and then a store
175
+ // function that stores into shared memory.
176
+ // Q4_K and Q5_K have the same encoding of scales, so everything is shared except
177
+ // the part that fetches from the structure (which has a different block layout).
178
+ #if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
179
+ const uint shAscales_stride = (BM + 2);
180
+ // 1 scale per 32 elements -> 8 scales per block, per row
181
+ shared vec2 shAscales[8 * shAscales_stride];
182
+ uvec4 row_v;
183
+ #endif
184
+
185
+ #if defined(DATA_A_Q4_K)
186
+ layout (binding = 0) readonly buffer A_Q4_K_128 {block_q4_K_packed128 data_a_q4_k_packed128[];};
187
+
188
+ void fetch_scalesQ4_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
189
+ {
190
+ uint tids_per_row = BLOCK_SIZE / BM;
191
+ uint is_per_tid = 8 / tids_per_row;
192
+ uint is_start = is_per_tid * (tid % tids_per_row);
193
+ uint tid_row = tid / tids_per_row;
194
+
195
+ uint row = ir_BM + tid_row;
196
+ uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
197
+ if (in_bounds || row < p.M) {
198
+ row_v = data_a_q4_k_packed128[block_index].q4k[0];
199
+ }
200
+ }
201
+ #endif
202
+ #if defined(DATA_A_Q5_K)
203
+ layout (binding = 0) readonly buffer A_Q5_K_128 {block_q5_K_packed128 data_a_q5_k_packed128[];};
204
+
205
+ void fetch_scalesQ5_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
206
+ {
207
+ uint tids_per_row = BLOCK_SIZE / BM;
208
+ uint is_per_tid = 8 / tids_per_row;
209
+ uint is_start = is_per_tid * (tid % tids_per_row);
210
+ uint tid_row = tid / tids_per_row;
211
+
212
+ uint row = ir_BM + tid_row;
213
+ uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
214
+ if (in_bounds || row < p.M) {
215
+ row_v = data_a_q5_k_packed128[block_index].q5k[0];
216
+ }
217
+ }
218
+ #endif
219
+
220
+ #if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
221
+ void store_scalesQ4_K(uint tid)
222
+ {
223
+ barrier();
224
+
225
+ uint tids_per_row = BLOCK_SIZE / BM;
226
+ uint is_per_tid = 8 / tids_per_row;
227
+ uint is_start = is_per_tid * (tid % tids_per_row);
228
+ uint tid_row = tid / tids_per_row;
229
+
230
+ [[unroll]] for (uint idx = 0; idx < is_per_tid; ++idx) {
231
+ uint is = idx + is_start;
232
+ uvec4 v = row_v;
233
+ const vec2 loadd = vec2(unpackFloat2x16(v.x));
234
+
235
+ uint32_t sc;
236
+ uint32_t mbyte;
237
+
238
+ uint32_t scale0 = v.y;
239
+ uint32_t scale4 = v.z;
240
+ uint32_t scale8 = v.w;
241
+
242
+ uint32_t sc_lo = scale0;
243
+ uint32_t mb_lo = scale4;
244
+ uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
245
+ uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
246
+
247
+ sc = is < 4 ? sc_lo : sc_hi;
248
+ mbyte = is < 4 ? mb_lo : mb_hi;
249
+ sc = sc >> (8 * (is & 3));
250
+ mbyte = mbyte >> (8 * (is & 3));
251
+ sc &= 0x3F;
252
+ mbyte &= 0x3F;
253
+
254
+ const float d = loadd.x * float(sc);
255
+ const float m = loadd.y * float(mbyte);
256
+ shAscales[is * shAscales_stride + tid_row] = vec2(d,m);
257
+ }
258
+
259
+ barrier();
260
+ }
261
+ #endif
262
+
263
+ #endif
264
+
170
265
float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
171
266
{
172
267
decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);
@@ -176,8 +271,12 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
176
271
const uint b = (idx & 0x20) >> 5; // 0,1
177
272
const uint is = (idx & 0xE0) >> 5; // 0..7
178
273
274
+ #if defined(IS_MUL_MM2) && defined(DATA_A_Q4_K)
275
+ vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
276
+ float d = v.x;
277
+ float m = v.y;
278
+ #else
179
279
uvec4 v = bl128.block.q4k[0];
180
-
181
280
const vec2 loadd = vec2(unpackFloat2x16(v.x));
182
281
183
282
uint32_t sc;
@@ -201,6 +300,7 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
201
300
202
301
const float d = loadd.x * float(sc);
203
302
const float m = loadd.y * float(mbyte);
303
+ #endif
204
304
205
305
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
206
306
qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF;
@@ -231,6 +331,11 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
231
331
const uint b = (idx & 0x20) >> 5; // 0,1
232
332
const uint is = (idx & 0xE0) >> 5; // 0..7
233
333
334
+ #if defined(IS_MUL_MM2) && defined(DATA_A_Q5_K)
335
+ vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
336
+ float d = v.x;
337
+ float m = v.y;
338
+ #else
234
339
uvec4 v = bl128.block.q5k[0];
235
340
236
341
const f16vec2 loadd = unpackFloat2x16(v.x);
@@ -256,6 +361,7 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
256
361
257
362
const float16_t d = loadd.x * float16_t(sc);
258
363
const float16_t m = loadd.y * float16_t(mbyte);
364
+ #endif
259
365
260
366
uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);
261
367
qh = ((qh >> is) & 0x101) << 4;
@@ -264,9 +370,9 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
264
370
qs = (qs >> (b * 4)) & 0x0F0F;
265
371
qs = unpack8(qs | qh)[idx & 1];
266
372
267
- float16_t ret = d * (float16_t(qs) ) - m;
373
+ float ret = d * float(qs ) - m;
268
374
269
- return ret;
375
+ return float16_t( ret) ;
270
376
}
271
377
272
378
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K {
@@ -564,8 +670,12 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
564
670
#define dequantFuncA dequantFuncQ3_K
565
671
#elif defined(DATA_A_Q4_K)
566
672
#define dequantFuncA dequantFuncQ4_K
673
+ #define fetch_scales fetch_scalesQ4_K
674
+ #define store_scales store_scalesQ4_K
567
675
#elif defined(DATA_A_Q5_K)
568
676
#define dequantFuncA dequantFuncQ5_K
677
+ #define fetch_scales fetch_scalesQ5_K
678
+ #define store_scales store_scalesQ4_K
569
679
#elif defined(DATA_A_Q6_K)
570
680
#define dequantFuncA dequantFuncQ6_K
571
681
#elif defined(DATA_A_IQ1_S)
0 commit comments