Skip to content

Commit 4dca015

Browse files
authored
vulkan: Replace 16-bit unpack8 calls to work around legacy Windows AMD driver bug (ggml-org#17285)
1 parent 9a8860c commit 4dca015

File tree

1 file changed

+17
-16
lines changed

1 file changed

+17
-16
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
300300

301301
if (iqs == 0) {
302302
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
303-
buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8]);
303+
buf_a[buf_ib].scales = unpack8(uint32_t(data_a_packed16[ib_k].scales[iqs_k / 8])).xy; // vec4 used due to #12147
304304
}
305305
}
306306

@@ -345,21 +345,22 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
345345

346346
// Repack 2x4 quants into one int
347347
// Add the 3rd bit instead of subtracting it to allow packing the quants
348-
const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) |
349-
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2));
350-
const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1 ] >> qs_shift) & uint16_t(0x0303))) |
351-
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2));
352-
const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2 ] >> qs_shift) & uint16_t(0x0303))) |
353-
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2));
354-
const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3 ] >> qs_shift) & uint16_t(0x0303))) |
355-
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2));
348+
// vec4 for unpack8 used due to #12147
349+
const i8vec2 vals00 = unpack8(int32_t(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303)))).xy |
350+
unpack8(int32_t(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 ] >> hm_shift) & uint16_t(0x0101))) << 2)).xy;
351+
const i8vec2 vals01 = unpack8(int32_t(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1 ] >> qs_shift) & uint16_t(0x0303)))).xy |
352+
unpack8(int32_t(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 1] >> hm_shift) & uint16_t(0x0101))) << 2)).xy;
353+
const i8vec2 vals10 = unpack8(int32_t(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2 ] >> qs_shift) & uint16_t(0x0303)))).xy |
354+
unpack8(int32_t(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 2] >> hm_shift) & uint16_t(0x0101))) << 2)).xy;
355+
const i8vec2 vals11 = unpack8(int32_t(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3 ] >> qs_shift) & uint16_t(0x0303)))).xy |
356+
unpack8(int32_t(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 3] >> hm_shift) & uint16_t(0x0101))) << 2)).xy;
356357
buf_a[buf_ib].qs[iqs] = pack32(u8vec4(vals00.x, vals00.y, vals01.x, vals01.y)) |
357358
(pack32(u8vec4(vals10.x, vals10.y, vals11.x, vals11.y)) << 4);
358359

359360
if (iqs == 0) {
360361
const uint is = iqs_k / 4;
361-
const i8vec2 scales = i8vec2(unpack8(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) |
362-
(((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4)));
362+
const i8vec2 scales = i8vec2(unpack8(uint32_t(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) |
363+
(((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4))).xy); // vec4 used due to #12147
363364

364365
buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales - 32);
365366
}
@@ -516,15 +517,15 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
516517
const uint qh_idx = (iqs_k / 32) * 8 + iqs;
517518
const uint qh_shift = ((iqs_k % 32) / 8) * 2;
518519

519-
const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) |
520-
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
521-
const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) |
522-
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
520+
const i8vec2 vals00 = (unpack8(int32_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))).xy |
521+
unpack8(int32_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4)).xy) - int8_t(32);
522+
const i8vec2 vals01 = (unpack8(int32_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))).xy |
523+
unpack8(int32_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4)).xy) - int8_t(32);
523524
buf_a[buf_ib].qs[iqs] = pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y));
524525

525526
if (iqs == 0) {
526527
const uint is = iqs_k / 4;
527-
const i8vec2 scales = unpack8(data_a_packed16[ib_k].scales[is / 2]);
528+
const i8vec2 scales = unpack8(int32_t(data_a_packed16[ib_k].scales[is / 2])).xy;
528529

529530
buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales);
530531
}

0 commit comments

Comments
 (0)