|
32 | 32 | layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; |
33 | 33 |
|
34 | 34 | layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; |
35 | | -#if defined(A_TYPE_PACKED16) |
36 | | -layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; |
37 | | -#endif |
38 | | -#if defined(A_TYPE_PACKED32) |
39 | | -layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; |
40 | | -#endif |
41 | | - |
42 | 35 | layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; |
43 | 36 | layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; |
44 | 37 |
|
@@ -250,100 +243,74 @@ void main() { |
250 | 243 | #endif |
251 | 244 | #elif defined(DATA_A_Q4_0) |
252 | 245 | const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; |
253 | | - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a; |
254 | | - |
255 | | - const uint ib = idx / 4; |
256 | | - const uint iqs = idx & 0x03; |
257 | | - |
258 | | - const float d = float(data_a_packed16[ib].d); |
259 | | - const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16); |
260 | | - const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d; |
261 | | - const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d; |
262 | | - |
263 | | - buf_a[buf_idx ] = FLOAT_TYPE(v0.x); |
264 | | - buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y); |
265 | | - buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z); |
266 | | - buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w); |
267 | | - buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x); |
268 | | - buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y); |
269 | | - buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z); |
270 | | - buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w); |
| 246 | + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; |
| 247 | + |
| 248 | + const uint ib = idx / 16; |
| 249 | + const uint iqs = idx & 0xF; |
| 250 | + |
| 251 | + const float d = float(data_a[ib].d); |
| 252 | + const uint vui = uint(data_a[ib].qs[iqs]); |
| 253 | + const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d; |
| 254 | + |
| 255 | + buf_a[buf_idx ] = FLOAT_TYPE(v.x); |
| 256 | + buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); |
271 | 257 | #elif defined(DATA_A_Q4_1) |
272 | 258 | const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; |
273 | | - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a; |
274 | | - |
275 | | - const uint ib = idx / 4; |
276 | | - const uint iqs = idx & 0x03; |
277 | | - |
278 | | - const float d = float(data_a_packed16[ib].d); |
279 | | - const float m = float(data_a_packed16[ib].m); |
280 | | - const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16); |
281 | | - const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m; |
282 | | - const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m; |
283 | | - |
284 | | - buf_a[buf_idx ] = FLOAT_TYPE(v0.x); |
285 | | - buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y); |
286 | | - buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z); |
287 | | - buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w); |
288 | | - buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x); |
289 | | - buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y); |
290 | | - buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z); |
291 | | - buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w); |
| 259 | + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; |
| 260 | + |
| 261 | + const uint ib = idx / 16; |
| 262 | + const uint iqs = idx & 0xF; |
| 263 | + |
| 264 | + const float d = float(data_a[ib].d); |
| 265 | + const float m = float(data_a[ib].m); |
| 266 | + const uint vui = uint(data_a[ib].qs[iqs]); |
| 267 | + const vec2 v = vec2(vui & 0xF, vui >> 4) * d + m; |
| 268 | + |
| 269 | + buf_a[buf_idx ] = FLOAT_TYPE(v.x); |
| 270 | + buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); |
292 | 271 | #elif defined(DATA_A_Q5_0) |
293 | 272 | const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; |
294 | | - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; |
| 273 | + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; |
295 | 274 |
|
296 | | - const uint ib = idx / 8; |
297 | | - const uint iqs = idx & 0x07; |
| 275 | + const uint ib = idx / 16; |
| 276 | + const uint iqs = idx & 0xF; |
298 | 277 |
|
299 | | - const float d = float(data_a_packed16[ib].d); |
300 | | - const uint uint_qh = uint(data_a_packed16[ib].qh[1]) << 16 | uint(data_a_packed16[ib].qh[0]); |
301 | | - const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10); |
302 | | - const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10); |
303 | | - |
304 | | - const uint vui = uint(data_a_packed16[ib].qs[iqs]); |
305 | | - const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d; |
| 278 | + const float d = float(data_a[ib].d); |
| 279 | + const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]; |
| 280 | + const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); |
| 281 | + const uint vui = uint(data_a[ib].qs[iqs]); |
| 282 | + const vec2 v = (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d; |
306 | 283 |
|
307 | 284 | buf_a[buf_idx ] = FLOAT_TYPE(v.x); |
308 | | - buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z); |
309 | 285 | buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); |
310 | | - buf_a[buf_idx + 17] = FLOAT_TYPE(v.w); |
311 | 286 | #elif defined(DATA_A_Q5_1) |
312 | 287 | const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; |
313 | | - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; |
314 | | - |
315 | | - const uint ib = idx / 8; |
316 | | - const uint iqs = idx & 0x07; |
| 288 | + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; |
317 | 289 |
|
318 | | - const float d = float(data_a_packed16[ib].d); |
319 | | - const float m = float(data_a_packed16[ib].m); |
320 | | - const uint uint_qh = data_a_packed16[ib].qh; |
321 | | - const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10); |
322 | | - const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10); |
| 290 | + const uint ib = idx / 16; |
| 291 | + const uint iqs = idx & 0xF; |
323 | 292 |
|
324 | | - const uint vui = uint(data_a_packed16[ib].qs[iqs]); |
325 | | - const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m; |
| 293 | + const float d = float(data_a[ib].d); |
| 294 | + const float m = float(data_a[ib].m); |
| 295 | + const uint uint_qh = data_a[ib].qh; |
| 296 | + const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); |
| 297 | + const uint vui = uint(data_a[ib].qs[iqs]); |
| 298 | + const vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m; |
326 | 299 |
|
327 | 300 | buf_a[buf_idx ] = FLOAT_TYPE(v.x); |
328 | | - buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z); |
329 | 301 | buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); |
330 | | - buf_a[buf_idx + 17] = FLOAT_TYPE(v.w); |
331 | 302 | #elif defined(DATA_A_Q8_0) |
332 | 303 | const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; |
333 | 304 | const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; |
334 | 305 |
|
335 | | - const uint ib = idx / 8; |
336 | | - const uint iqs = idx & 0x07; |
| 306 | + const uint ib = idx / 16; |
| 307 | + const uint iqs = (idx & 0xF) * 2; |
337 | 308 |
|
338 | | - const float d = float(data_a_packed16[ib].d); |
339 | | - const i8vec2 v0 = unpack8(data_a_packed16[ib].qs[2*iqs]); |
340 | | - const i8vec2 v1 = unpack8(data_a_packed16[ib].qs[2*iqs + 1]); |
341 | | - const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d; |
| 309 | + const float d = float(data_a[ib].d); |
| 310 | + const vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])) * d; |
342 | 311 |
|
343 | 312 | buf_a[buf_idx ] = FLOAT_TYPE(v.x); |
344 | 313 | buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); |
345 | | - buf_a[buf_idx + 2] = FLOAT_TYPE(v.z); |
346 | | - buf_a[buf_idx + 3] = FLOAT_TYPE(v.w); |
347 | 314 | #elif defined(DATA_A_Q2_K) |
348 | 315 | const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; |
349 | 316 | const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; |
@@ -656,18 +623,17 @@ void main() { |
656 | 623 | buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); |
657 | 624 | #elif defined(DATA_A_IQ4_NL) |
658 | 625 | const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; |
659 | | - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; |
| 626 | + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; |
660 | 627 |
|
661 | | - const uint ib = idx / 8; |
662 | | - const uint iqs = idx & 0x07; |
| 628 | + const uint ib = idx / 16; |
| 629 | + const uint iqs = idx & 0xF; |
663 | 630 |
|
664 | | - const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d); |
665 | | - const uint vui = uint(data_a_packed16[ib].qs[iqs]); |
| 631 | + const float d = float(data_a[ib].d); |
| 632 | + const uint vui = uint(data_a[ib].qs[iqs]); |
| 633 | + const vec2 v = vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d; |
666 | 634 |
|
667 | | - buf_a[buf_idx ] = FLOAT_TYPE(kvalues_iq4nl[vui & 0xF]) * d; |
668 | | - buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d; |
669 | | - buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d; |
670 | | - buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d; |
| 635 | + buf_a[buf_idx ] = FLOAT_TYPE(v.x); |
| 636 | + buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); |
671 | 637 | #endif |
672 | 638 | } |
673 | 639 | [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) { |
|
0 commit comments