@@ -157,39 +157,47 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4
157157 block_q4_K_packed16 block;
158158};
159159
160+ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed128 {
161+ block_q4_K_packed128 block;
162+ };
163+
160164float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
161165{
162166 decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);
167+ decodeBufQ4_K_packed128 bl128 = decodeBufQ4_K_packed128(bl);
163168 const uint idx = coordInBlock[1];
164169
165170 const uint b = (idx & 0x20) >> 5; // 0,1
166171 const uint is = (idx & 0xE0) >> 5; // 0..7
167172
168- const f16vec2 loadd = bl.block.d;
173+ uvec4 v = bl128.block.q4k[0];
174+
175+ const f16vec2 loadd = unpackFloat2x16(v.x);
169176
170177 uint32_t sc;
171178 uint32_t mbyte;
172179
173- uint32_t scidx0 = (is < 4) ? is : (is + 4);
174- uint32_t scidx1 = (is < 4) ? is : (is - 4);
175- uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0;
176- uint32_t scidxshift1 = (is < 4) ? 0 : 2;
177- uint32_t mbidx0 = is + 4;
178- uint32_t mbidx1 = (is < 4) ? is + 4 : is;
179- uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0;
180- uint32_t mbidxshift0 = (is < 4) ? 0 : 4;
181- uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
182- uint32_t mbidxshift1 = (is < 4) ? 0 : 2;
180+ uint32_t scale0 = v.y;
181+ uint32_t scale4 = v.z;
182+ uint32_t scale8 = v.w;
183183
184- sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1));
185- mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1));
184+ uint32_t sc_lo = scale0;
185+ uint32_t mb_lo = scale4;
186+ uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
187+ uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
188+
189+ sc = is < 4 ? sc_lo : sc_hi;
190+ mbyte = is < 4 ? mb_lo : mb_hi;
191+ sc = sc >> (8 * (is & 3));
192+ mbyte = mbyte >> (8 * (is & 3));
193+ sc &= 0x3F;
194+ mbyte &= 0x3F;
186195
187196 const float16_t d = loadd.x * float16_t(sc);
188197 const float16_t m = loadd.y * float16_t(mbyte);
189198
190199 uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
191- qs = (qs >> (b * 4)) & 0x0F0F;
192- qs = unpack8(qs)[idx & 1];
200+ qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF;
193201
194202 float16_t ret = d * float16_t(qs) - m;
195203
@@ -204,47 +212,53 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5
204212 block_q5_K_packed16 block;
205213};
206214
215+ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed128 {
216+ block_q5_K_packed128 block;
217+ };
218+
207219float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
208220{
209221 decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl);
222+ decodeBufQ5_K_packed128 bl128 = decodeBufQ5_K_packed128(bl);
210223 const uint idx = coordInBlock[1];
211224
212225 const uint b = (idx & 0x20) >> 5; // 0,1
213226 const uint is = (idx & 0xE0) >> 5; // 0..7
214227
215- const uint32_t hm = 0x0101 << is ;
228+ uvec4 v = bl128.block.q5k[0] ;
216229
217- const f16vec2 loadd = bl.block.d ;
230+ const f16vec2 loadd = unpackFloat2x16(v.x) ;
218231
219232 uint32_t sc;
220233 uint32_t mbyte;
221234
222- uint32_t scidx0 = (is < 4) ? is : (is + 4);
223- uint32_t scidx1 = (is < 4) ? is : (is - 4);
224- uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0;
225- uint32_t scidxshift1 = (is < 4) ? 0 : 2;
226- uint32_t mbidx0 = is + 4;
227- uint32_t mbidx1 = (is < 4) ? is + 4 : is;
228- uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0;
229- uint32_t mbidxshift0 = (is < 4) ? 0 : 4;
230- uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
231- uint32_t mbidxshift1 = (is < 4) ? 0 : 2;
235+ uint32_t scale0 = v.y;
236+ uint32_t scale4 = v.z;
237+ uint32_t scale8 = v.w;
232238
233- sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1));
234- mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1));
239+ uint32_t sc_lo = scale0;
240+ uint32_t mb_lo = scale4;
241+ uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
242+ uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
243+
244+ sc = is < 4 ? sc_lo : sc_hi;
245+ mbyte = is < 4 ? mb_lo : mb_hi;
246+ sc = sc >> (8 * (is & 3));
247+ mbyte = mbyte >> (8 * (is & 3));
248+ sc &= 0x3F;
249+ mbyte &= 0x3F;
235250
236251 const float16_t d = loadd.x * float16_t(sc);
237252 const float16_t m = loadd.y * float16_t(mbyte);
238253
239254 uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);
240- qh = qh & hm;
241- qh = unpack8(qh)[idx & 1];
255+ qh = ((qh >> is) & 0x101) << 4;
242256
243257 uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
244258 qs = (qs >> (b * 4)) & 0x0F0F;
245- qs = unpack8(qs)[idx & 1];
259+ qs = unpack8(qs | qh )[idx & 1];
246260
247- float16_t ret = d * (float16_t(qs) + (qh != 0 ? float16_t(16) : float16_t(0)) ) - m;
261+ float16_t ret = d * (float16_t(qs)) - m;
248262
249263 return ret;
250264}
0 commit comments