@@ -167,6 +167,101 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4
167167   block_q4_K_packed128 block;
168168};
169169
170+ #if defined(IS_MUL_MM2)
171+ 
172+ // For Q4_K and Q5_K in the mat-mul shader, we decode a tile's worth of scales
173+ // into shared memory and then process the whole tile using those scales.
174+ // There is a fetch function that loads into private variables and then a store
175+ // function that stores into shared memory.
176+ // Q4_K and Q5_K have the same encoding of scales, so everything is shared except
177+ // the part that fetches from the structure (which has a different block layout).
178+ #if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
179+ const uint shAscales_stride = (BM + 2); // pitch, in vec2 entries, between consecutive scale indices in shAscales; NOTE(review): the +2 is presumably padding - confirm intent
180+ // 1 scale per 32 elements -> 8 scales per block, per row
181+ shared vec2 shAscales[8 * shAscales_stride]; // indexed [is * shAscales_stride + row_in_tile]; each entry is the vec2(d, m) computed by store_scalesQ4_K
182+ uvec4 row_v; // raw q4k[0]/q5k[0] words of one block: written by fetch_scalesQ4_K/fetch_scalesQ5_K, decoded by store_scalesQ4_K
183+ #endif
184+ 
185+ #if defined(DATA_A_Q4_K)
// View of binding 0 as an array of Q4_K blocks in their packed128 (uvec4) layout.
layout (binding = 0) readonly buffer A_Q4_K_128 {block_q4_K_packed128 data_a_q4_k_packed128[];};

// Loads element 0 of the q4k array of one Q4_K block into the private variable
// row_v, so that store_scalesQ4_K() can decode the scales from it afterwards.
//
//   ir_BM     - added to this invocation's row-in-tile to form the matrix row
//   pos_a     - base offset (in blocks) into the A buffer
//   stride_a  - number of blocks between consecutive rows
//   block_k   - K coordinate in elements; divided by QUANT_K to get the block column
//   tid       - invocation index used to pick the row (BLOCK_SIZE / BM invocations per row)
//   in_bounds - when true, the row < p.M check is bypassed
//
// If the row is out of range (and in_bounds is false) nothing is loaded and
// row_v keeps whatever value it held before the call.
void fetch_scalesQ4_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
{
    const uint tids_per_row = BLOCK_SIZE / BM;
    // All invocations that share a row compute the same tid_row and therefore
    // fetch the same block.
    const uint tid_row = tid / tids_per_row;

    const uint row = ir_BM + tid_row;
    const uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
    if (in_bounds || row < p.M) {
        row_v = data_a_q4_k_packed128[block_index].q4k[0];
    }
}
201+ #endif
202+ #if defined(DATA_A_Q5_K)
// View of binding 0 as an array of Q5_K blocks in their packed128 (uvec4) layout.
layout (binding = 0) readonly buffer A_Q5_K_128 {block_q5_K_packed128 data_a_q5_k_packed128[];};

// Loads element 0 of the q5k array of one Q5_K block into the private variable
// row_v, so that store_scalesQ4_K() can decode the scales from it afterwards.
//
//   ir_BM     - added to this invocation's row-in-tile to form the matrix row
//   pos_a     - base offset (in blocks) into the A buffer
//   stride_a  - number of blocks between consecutive rows
//   block_k   - K coordinate in elements; divided by QUANT_K to get the block column
//   tid       - invocation index used to pick the row (BLOCK_SIZE / BM invocations per row)
//   in_bounds - when true, the row < p.M check is bypassed
//
// If the row is out of range (and in_bounds is false) nothing is loaded and
// row_v keeps whatever value it held before the call.
void fetch_scalesQ5_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
{
    const uint tids_per_row = BLOCK_SIZE / BM;
    // All invocations that share a row compute the same tid_row and therefore
    // fetch the same block.
    const uint tid_row = tid / tids_per_row;

    const uint row = ir_BM + tid_row;
    const uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
    if (in_bounds || row < p.M) {
        row_v = data_a_q5_k_packed128[block_index].q5k[0];
    }
}
218+ #endif
219+ 
220+ #if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
// Decodes the 6-bit scale/min values held in row_v and writes the resulting
// vec2(d, m) factors into shAscales, where the dequant functions read them.
// Used for both Q4_K and Q5_K (the store_scales macro maps to this function
// for either type).
//
//   tid - invocation index; BLOCK_SIZE / BM invocations cooperate on one row,
//         each handling 8 / (BLOCK_SIZE / BM) consecutive scale indices.
//
// The function is bracketed by barrier() calls, so it must be reached by all
// invocations of the workgroup.
void store_scalesQ4_K(uint tid)
{
    // NOTE(review): presumably ensures no invocation is still reading the
    // previous contents of shAscales before they are overwritten - confirm.
    barrier();

    const uint tids_per_row = BLOCK_SIZE / BM;
    const uint is_per_tid = 8 / tids_per_row;
    const uint is_start = is_per_tid * (tid % tids_per_row);
    const uint tid_row = tid / tids_per_row;

    // Everything up to the loop depends only on row_v, so it is decoded once
    // instead of once per scale index.
    const vec2 loadd = vec2(unpackFloat2x16(row_v.x));

    const uint32_t scale0 = row_v.y;
    const uint32_t scale4 = row_v.z;
    const uint32_t scale8 = row_v.w;

    // Scale indices 0..3: the value sits in the low 6 bits of each byte
    // (the final & 0x3F below strips the top two bits).
    const uint32_t sc_lo = scale0;
    const uint32_t mb_lo = scale4;
    // Scale indices 4..7: a nibble from scale8 supplies the low 4 bits and the
    // top two bits of the matching scale0/scale4 byte supply bits 4..5.
    const uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
    const uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);

    [[unroll]] for (uint idx = 0; idx < is_per_tid; ++idx) {
        const uint is = idx + is_start;
        // Select byte (is & 3) of the low or high word
        const uint shift = 8 * (is & 3);
        const uint32_t sc = ((is < 4 ? sc_lo : sc_hi) >> shift) & 0x3F;
        const uint32_t mbyte = ((is < 4 ? mb_lo : mb_hi) >> shift) & 0x3F;

        const float d = loadd.x * float(sc);
        const float m = loadd.y * float(mbyte);
        shAscales[is * shAscales_stride + tid_row] = vec2(d, m);
    }

    // Make the freshly written scales visible to the whole workgroup before use
    barrier();
}
261+ #endif
262+ 
263+ #endif
264+ 
170265float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
171266{
172267    decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);
@@ -176,8 +271,12 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
176271    const uint b = (idx & 0x20) >> 5;            // 0,1
177272    const uint is = (idx & 0xE0) >> 5;         // 0..7
178273
274+ #if defined(IS_MUL_MM2) && defined(DATA_A_Q4_K)
275+     vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
276+     float d = v.x;
277+     float m = v.y;
278+ #else
179279    uvec4 v = bl128.block.q4k[0];
180- 
181280    const vec2 loadd = vec2(unpackFloat2x16(v.x));
182281
183282    uint32_t sc;
@@ -201,6 +300,7 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
201300
202301    const float d = loadd.x * float(sc);
203302    const float m = loadd.y * float(mbyte);
303+ #endif
204304
205305    uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
206306    qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF;
@@ -231,6 +331,11 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
231331    const uint b = (idx & 0x20) >> 5;          // 0,1
232332    const uint is = (idx & 0xE0) >> 5;         // 0..7
233333
334+ #if defined(IS_MUL_MM2) && defined(DATA_A_Q5_K)
335+     vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
336+     float d = v.x;
337+     float m = v.y;
338+ #else
234339    uvec4 v = bl128.block.q5k[0];
235340
236341    const f16vec2 loadd = unpackFloat2x16(v.x);
@@ -256,6 +361,7 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
256361
257362    const float16_t d = loadd.x * float16_t(sc);
258363    const float16_t m = loadd.y * float16_t(mbyte);
364+ #endif
259365
260366    uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);
261367    qh = ((qh >> is) & 0x101) << 4;
@@ -264,9 +370,9 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
264370    qs = (qs >> (b * 4)) & 0x0F0F;
265371    qs = unpack8(qs | qh)[idx & 1];
266372
267-     float16_t  ret = d * (float16_t(qs) ) - m;
373+     float  ret = d * float(qs ) - m;
268374
269-     return ret;
375+     return float16_t( ret) ;
270376}
271377
272378layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K {
@@ -564,8 +670,12 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
564670#define dequantFuncA dequantFuncQ3_K
565671#elif defined(DATA_A_Q4_K)
566672#define dequantFuncA dequantFuncQ4_K
673+ #define fetch_scales fetch_scalesQ4_K
674+ #define store_scales store_scalesQ4_K
567675#elif defined(DATA_A_Q5_K)
568676#define dequantFuncA dequantFuncQ5_K
677+ #define fetch_scales fetch_scalesQ5_K
678+ #define store_scales store_scalesQ4_K
569679#elif defined(DATA_A_Q6_K)
570680#define dequantFuncA dequantFuncQ6_K
571681#elif defined(DATA_A_IQ1_S)
0 commit comments