@@ -167,6 +167,101 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4
167167 block_q4_K_packed128 block;
168168};
169169
170+ #if defined(IS_MUL_MM2)
171+
172+ // For Q4_K and Q5_K in the mat-mul shader, we decode a tile's worth of scales
173+ // into shared memory and then process the whole tile using those scales.
174+ // There is a fetch function that loads into private variables and then a store
175+ // function that stores into shared memory.
176+ // Q4_K and Q5_K have the same encoding of scales, so everything is shared except
177+ // the part that fetches from the structure (which has a different block layout).
178+ #if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
179+ const uint shAscales_stride = (BM + 2);
180+ // 1 scale per 32 elements -> 8 scales per block, per row
181+ shared vec2 shAscales[8 * shAscales_stride];
182+ uvec4 row_v;
183+ #endif
184+
185+ #if defined(DATA_A_Q4_K)
186+ layout (binding = 0) readonly buffer A_Q4_K_128 {block_q4_K_packed128 data_a_q4_k_packed128[];};
187+
188+ void fetch_scalesQ4_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
189+ {
190+ uint tids_per_row = BLOCK_SIZE / BM;
191+ uint is_per_tid = 8 / tids_per_row;
192+ uint is_start = is_per_tid * (tid % tids_per_row);
193+ uint tid_row = tid / tids_per_row;
194+
195+ uint row = ir_BM + tid_row;
196+ uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
197+ if (in_bounds || row < p.M) {
198+ row_v = data_a_q4_k_packed128[block_index].q4k[0];
199+ }
200+ }
201+ #endif
202+ #if defined(DATA_A_Q5_K)
203+ layout (binding = 0) readonly buffer A_Q5_K_128 {block_q5_K_packed128 data_a_q5_k_packed128[];};
204+
205+ void fetch_scalesQ5_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
206+ {
207+ uint tids_per_row = BLOCK_SIZE / BM;
208+ uint is_per_tid = 8 / tids_per_row;
209+ uint is_start = is_per_tid * (tid % tids_per_row);
210+ uint tid_row = tid / tids_per_row;
211+
212+ uint row = ir_BM + tid_row;
213+ uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
214+ if (in_bounds || row < p.M) {
215+ row_v = data_a_q5_k_packed128[block_index].q5k[0];
216+ }
217+ }
218+ #endif
219+
220+ #if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
221+ void store_scalesQ4_K(uint tid)
222+ {
223+ barrier();
224+
225+ uint tids_per_row = BLOCK_SIZE / BM;
226+ uint is_per_tid = 8 / tids_per_row;
227+ uint is_start = is_per_tid * (tid % tids_per_row);
228+ uint tid_row = tid / tids_per_row;
229+
230+ [[unroll]] for (uint idx = 0; idx < is_per_tid; ++idx) {
231+ uint is = idx + is_start;
232+ uvec4 v = row_v;
233+ const vec2 loadd = vec2(unpackFloat2x16(v.x));
234+
235+ uint32_t sc;
236+ uint32_t mbyte;
237+
238+ uint32_t scale0 = v.y;
239+ uint32_t scale4 = v.z;
240+ uint32_t scale8 = v.w;
241+
242+ uint32_t sc_lo = scale0;
243+ uint32_t mb_lo = scale4;
244+ uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
245+ uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
246+
247+ sc = is < 4 ? sc_lo : sc_hi;
248+ mbyte = is < 4 ? mb_lo : mb_hi;
249+ sc = sc >> (8 * (is & 3));
250+ mbyte = mbyte >> (8 * (is & 3));
251+ sc &= 0x3F;
252+ mbyte &= 0x3F;
253+
254+ const float d = loadd.x * float(sc);
255+ const float m = loadd.y * float(mbyte);
256+ shAscales[is * shAscales_stride + tid_row] = vec2(d,m);
257+ }
258+
259+ barrier();
260+ }
261+ #endif
262+
263+ #endif
264+
170265float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
171266{
172267 decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);
@@ -176,8 +271,12 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
176271 const uint b = (idx & 0x20) >> 5; // 0,1
177272 const uint is = (idx & 0xE0) >> 5; // 0..7
178273
274+ #if defined(IS_MUL_MM2) && defined(DATA_A_Q4_K)
275+ vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
276+ float d = v.x;
277+ float m = v.y;
278+ #else
179279 uvec4 v = bl128.block.q4k[0];
180-
181280 const vec2 loadd = vec2(unpackFloat2x16(v.x));
182281
183282 uint32_t sc;
@@ -201,6 +300,7 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
201300
202301 const float d = loadd.x * float(sc);
203302 const float m = loadd.y * float(mbyte);
303+ #endif
204304
205305 uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
206306 qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF;
@@ -231,6 +331,11 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
231331 const uint b = (idx & 0x20) >> 5; // 0,1
232332 const uint is = (idx & 0xE0) >> 5; // 0..7
233333
334+ #if defined(IS_MUL_MM2) && defined(DATA_A_Q5_K)
335+ vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
336+ float d = v.x;
337+ float m = v.y;
338+ #else
234339 uvec4 v = bl128.block.q5k[0];
235340
236341 const f16vec2 loadd = unpackFloat2x16(v.x);
@@ -256,6 +361,7 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
256361
257362 const float16_t d = loadd.x * float16_t(sc);
258363 const float16_t m = loadd.y * float16_t(mbyte);
364+ #endif
259365
260366 uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);
261367 qh = ((qh >> is) & 0x101) << 4;
@@ -264,9 +370,9 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
264370 qs = (qs >> (b * 4)) & 0x0F0F;
265371 qs = unpack8(qs | qh)[idx & 1];
266372
267- float16_t ret = d * (float16_t(qs) ) - m;
373+ float ret = d * float(qs ) - m;
268374
269- return ret;
375+ return float16_t( ret) ;
270376}
271377
272378layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K {
@@ -564,8 +670,12 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
564670#define dequantFuncA dequantFuncQ3_K
565671#elif defined(DATA_A_Q4_K)
566672#define dequantFuncA dequantFuncQ4_K
673+ #define fetch_scales fetch_scalesQ4_K
674+ #define store_scales store_scalesQ4_K
567675#elif defined(DATA_A_Q5_K)
568676#define dequantFuncA dequantFuncQ5_K
677+ #define fetch_scales fetch_scalesQ5_K
678+ #define store_scales store_scalesQ4_K
569679#elif defined(DATA_A_Q6_K)
570680#define dequantFuncA dequantFuncQ6_K
571681#elif defined(DATA_A_IQ1_S)
0 commit comments