@@ -300,7 +300,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
300300
301301 if (iqs == 0 ) {
302302 buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
303- buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8 ]);
303+ buf_a[buf_ib].scales = unpack8(uint32_t( data_a_packed16[ib_k].scales[iqs_k / 8 ])).xy; // vec4 used due to #12147
304304 }
305305}
306306
@@ -345,21 +345,22 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
345345
346346 // Repack 2x4 quants into one int
347347 // Add the 3rd bit instead of subtracting it to allow packing the quants
348- const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) |
349- unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2 ));
350- const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1 ] >> qs_shift) & uint16_t(0x0303))) |
351- unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 1 ] >> hm_shift) & uint16_t(0x0101)) << 2 ));
352- const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2 ] >> qs_shift) & uint16_t(0x0303))) |
353- unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 2 ] >> hm_shift) & uint16_t(0x0101)) << 2 ));
354- const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3 ] >> qs_shift) & uint16_t(0x0303))) |
355- unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 3 ] >> hm_shift) & uint16_t(0x0101)) << 2 ));
348+ // vec4 for unpack8 used due to #12147
349+ const i8vec2 vals00 = unpack8(int32_t(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303)))).xy |
350+ unpack8(int32_t(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 ] >> hm_shift) & uint16_t(0x0101))) << 2 )).xy;
351+ const i8vec2 vals01 = unpack8(int32_t(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1 ] >> qs_shift) & uint16_t(0x0303)))).xy |
352+ unpack8(int32_t(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 1 ] >> hm_shift) & uint16_t(0x0101))) << 2 )).xy;
353+ const i8vec2 vals10 = unpack8(int32_t(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2 ] >> qs_shift) & uint16_t(0x0303)))).xy |
354+ unpack8(int32_t(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 2 ] >> hm_shift) & uint16_t(0x0101))) << 2 )).xy;
355+ const i8vec2 vals11 = unpack8(int32_t(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3 ] >> qs_shift) & uint16_t(0x0303)))).xy |
356+ unpack8(int32_t(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 3 ] >> hm_shift) & uint16_t(0x0101))) << 2 )).xy;
356357 buf_a[buf_ib].qs[iqs] = pack32(u8vec4(vals00.x, vals00.y, vals01.x, vals01.y)) |
357358 (pack32(u8vec4(vals10.x, vals10.y, vals11.x, vals11.y)) << 4 );
358359
359360 if (iqs == 0 ) {
360361 const uint is = iqs_k / 4 ;
361- const i8vec2 scales = i8vec2(unpack8(((data_a_packed16[ib_k].scales[(is % 8 ) / 2 ] >> (4 * (is / 8 ))) & 0x0F0F) |
362- (((data_a_packed16[ib_k].scales[(8 + (is % 4 )) / 2 ] >> (2 * (is / 4 ))) & 0x0303) << 4 )));
362+ const i8vec2 scales = i8vec2(unpack8(uint32_t( ((data_a_packed16[ib_k].scales[(is % 8 ) / 2 ] >> (4 * (is / 8 ))) & 0x0F0F) |
363+ (((data_a_packed16[ib_k].scales[(8 + (is % 4 )) / 2 ] >> (2 * (is / 4 ))) & 0x0303) << 4 ))).xy); // vec4 used due to #12147
363364
364365 buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales - 32 );
365366 }
@@ -516,15 +517,15 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
516517 const uint qh_idx = (iqs_k / 32 ) * 8 + iqs;
517518 const uint qh_shift = ((iqs_k % 32 ) / 8 ) * 2 ;
518519
519- const i8vec2 vals00 = (unpack8(int16_t ((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) |
520- unpack8(int16_t (((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4 ))) - int8_t(32 );
521- const i8vec2 vals01 = (unpack8(int16_t ((data_a_packed16[ib_k].ql[ql_idx * 2 + 1 ] >> ql_shift) & uint16_t(0x0F0F))) |
522- unpack8(int16_t (((data_a_packed16[ib_k].qh[qh_idx * 2 + 1 ] >> qh_shift) & uint16_t(0x0303)) << 4 ))) - int8_t(32 );
520+ const i8vec2 vals00 = (unpack8(int32_t ((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))).xy |
521+ unpack8(int32_t (((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4 )).xy ) - int8_t(32 );
522+ const i8vec2 vals01 = (unpack8(int32_t ((data_a_packed16[ib_k].ql[ql_idx * 2 + 1 ] >> ql_shift) & uint16_t(0x0F0F))).xy |
523+ unpack8(int32_t (((data_a_packed16[ib_k].qh[qh_idx * 2 + 1 ] >> qh_shift) & uint16_t(0x0303)) << 4 )).xy ) - int8_t(32 );
523524 buf_a[buf_ib].qs[iqs] = pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y));
524525
525526 if (iqs == 0 ) {
526527 const uint is = iqs_k / 4 ;
527- const i8vec2 scales = unpack8(data_a_packed16[ib_k].scales[is / 2 ]);
528+ const i8vec2 scales = unpack8(int32_t( data_a_packed16[ib_k].scales[is / 2 ])).xy ;
528529
529530 buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales);
530531 }
0 commit comments