@@ -163,39 +163,47 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4
163163   block_q4_K_packed16 block;
164164};
165165
166+ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed128 {
167+    block_q4_K_packed128 block;
168+ };
169+ 
166170float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
167171{
168172    decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);
173+     decodeBufQ4_K_packed128 bl128 = decodeBufQ4_K_packed128(bl);
169174    const uint idx = coordInBlock[1];
170175
171176    const uint b = (idx & 0x20) >> 5;            // 0,1
172177    const uint is = (idx & 0xE0) >> 5;         // 0..7
173178
174-     const f16vec2 loadd = bl.block.d;
179+     uvec4 v = bl128.block.q4k[0];
180+ 
181+     const f16vec2 loadd = unpackFloat2x16(v.x);
175182
176183    uint32_t sc;
177184    uint32_t mbyte;
178185
179-     uint32_t scidx0 = (is < 4) ? is : (is + 4);
180-     uint32_t scidx1 = (is < 4) ? is : (is - 4);
181-     uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0;
182-     uint32_t scidxshift1 = (is < 4) ? 0 : 2;
183-     uint32_t mbidx0 = is + 4;
184-     uint32_t mbidx1 = (is < 4) ? is + 4 : is;
185-     uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0;
186-     uint32_t mbidxshift0 = (is < 4) ? 0 : 4;
187-     uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
188-     uint32_t mbidxshift1 = (is < 4) ? 0 : 2;
186+     uint32_t scale0 = v.y;
187+     uint32_t scale4 = v.z;
188+     uint32_t scale8 = v.w;
189189
190-     sc    = uint8_t((bl.block.scales[scidx0] & 0xF)                         | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1));
191-     mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1));
190+     uint32_t sc_lo = scale0;
191+     uint32_t mb_lo = scale4;
192+     uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
193+     uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
194+ 
195+     sc = is < 4 ? sc_lo : sc_hi;
196+     mbyte = is < 4 ? mb_lo : mb_hi;
197+     sc = sc >> (8 * (is & 3));
198+     mbyte = mbyte >> (8 * (is & 3));
199+     sc &= 0x3F;
200+     mbyte &= 0x3F;
192201
193202    const float16_t d = loadd.x * float16_t(sc);
194203    const float16_t m = loadd.y * float16_t(mbyte);
195204
196205    uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
197-     qs = (qs >> (b * 4)) & 0x0F0F;
198-     qs = unpack8(qs)[idx & 1];
206+     qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF;
199207
200208    float16_t ret = d * float16_t(qs) - m;
201209
@@ -210,47 +218,53 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5
210218   block_q5_K_packed16 block;
211219};
212220
221+ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed128 {
222+    block_q5_K_packed128 block;
223+ };
224+ 
213225float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
214226{
215227    decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl);
228+     decodeBufQ5_K_packed128 bl128 = decodeBufQ5_K_packed128(bl);
216229    const uint idx = coordInBlock[1];
217230
218231    const uint b = (idx & 0x20) >> 5;          // 0,1
219232    const uint is = (idx & 0xE0) >> 5;         // 0..7
220233
221-     const uint32_t hm = 0x0101 << is ;
234+     uvec4 v = bl128.block.q5k[0] ;
222235
223-     const f16vec2 loadd = bl.block.d ;
236+     const f16vec2 loadd = unpackFloat2x16(v.x) ;
224237
225238    uint32_t sc;
226239    uint32_t mbyte;
227240
228-     uint32_t scidx0 = (is < 4) ? is : (is + 4);
229-     uint32_t scidx1 = (is < 4) ? is : (is - 4);
230-     uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0;
231-     uint32_t scidxshift1 = (is < 4) ? 0 : 2;
232-     uint32_t mbidx0 = is + 4;
233-     uint32_t mbidx1 = (is < 4) ? is + 4 : is;
234-     uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0;
235-     uint32_t mbidxshift0 = (is < 4) ? 0 : 4;
236-     uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
237-     uint32_t mbidxshift1 = (is < 4) ? 0 : 2;
241+     uint32_t scale0 = v.y;
242+     uint32_t scale4 = v.z;
243+     uint32_t scale8 = v.w;
238244
239-     sc    = uint8_t((bl.block.scales[scidx0] & 0xF)                         | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1));
240-     mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1));
245+     uint32_t sc_lo = scale0;
246+     uint32_t mb_lo = scale4;
247+     uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
248+     uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
249+ 
250+     sc = is < 4 ? sc_lo : sc_hi;
251+     mbyte = is < 4 ? mb_lo : mb_hi;
252+     sc = sc >> (8 * (is & 3));
253+     mbyte = mbyte >> (8 * (is & 3));
254+     sc &= 0x3F;
255+     mbyte &= 0x3F;
241256
242257    const float16_t d = loadd.x * float16_t(sc);
243258    const float16_t m = loadd.y * float16_t(mbyte);
244259
245260    uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);
246-     qh = qh & hm;
247-     qh = unpack8(qh)[idx & 1];
261+     qh = ((qh >> is) & 0x101) << 4;
248262
249263    uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
250264    qs = (qs >> (b * 4)) & 0x0F0F;
251-     qs = unpack8(qs)[idx & 1];
265+     qs = unpack8(qs | qh )[idx & 1];
252266
253-     float16_t ret = d * (float16_t(qs) + (qh != 0 ? float16_t(16) : float16_t(0)) ) - m;
267+     float16_t ret = d * (float16_t(qs)) - m;
254268
255269    return ret;
256270}
0 commit comments