@@ -500,10 +500,9 @@ void main() {
500500 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
501501 const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
502502
503- const uint ib = idx / 128; // 2 values per idx
504- const uint ib32 = (idx % 128) / 16; // 0..7
505- const uint ib8 = (idx % 128) / 4;
506- const int i8 = 2 * int(idx % 4);
503+ const uint ib = idx / 32; // 8 values per idx
504+ const uint ib32 = (idx % 32) / 4; // 0..7
505+ const uint ib8 = idx % 32;
507506
508507 const float d = float(data_a[ib].d);
509508 const uint qh = data_a[ib].qh[ib32];
@@ -512,22 +511,16 @@ void main() {
512511 const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
513512 const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
514513
515- const ivec2 gvec = ivec2(
516- bitfieldExtract(grid, 2 * (i8), 2),
517- bitfieldExtract(grid, 2 * (i8 + 1), 2)
518- );
519- const vec2 v = dl * (vec2(gvec) + delta);
520-
521- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
522- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
514+ [[unroll]] for (int k = 0; k < 8; ++k) {
515+ buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
516+ }
523517#elif defined(DATA_A_IQ1_M)
524518 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
525519 const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
526520
527- const uint ib = idx / 128 ; // 2 values per idx
528- const uint ib8 = ( idx % 128) / 4 ;
521+ const uint ib = idx / 32 ; // 8 values per idx
522+ const uint ib8 = idx % 32 ;
529523 const uint ib16 = ib8 / 2;
530- const int i8 = 2 * int(idx % 4);
531524
532525 const uint16_t[4] scales = data_a[ib].scales;
533526 const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
@@ -538,21 +531,17 @@ void main() {
538531 const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
539532 const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
540533 const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
541- const ivec2 gvec = ivec2(
542- bitfieldExtract(grid, 2 * (i8), 2),
543- bitfieldExtract(grid, 2 * (i8 + 1), 2)
544- );
545- const vec2 v = dl * (vec2(gvec) + delta);
546534
547- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
548- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
535+ [[unroll]] for (int k = 0; k < 8; ++k) {
536+ buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
537+ }
549538#elif defined(DATA_A_IQ2_XXS)
550539 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
551540 const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
552541
553- const uint ib = idx / 128 ; // 2 values per idx
554- const uint ib32 = (idx % 128 ) / 16 ; // 0..7
555- const uint ib8 = ( idx / 4) % 4;
542+ const uint ib = idx / 32 ; // 8 values per idx
543+ const uint ib32 = (idx % 32 ) / 4 ; // 0..7
544+ const uint ib8 = idx % 4;
556545
557546 const float d = float(data_a[ib].d);
558547 const uint qs = data_a[ib].qs[8 * ib32 + ib8];
@@ -562,63 +551,81 @@ void main() {
562551 data_a[ib].qs[8*ib32 + 6],
563552 data_a[ib].qs[8*ib32 + 7]
564553 ));
565- const float db = d * 0.25 * (0.5 + (signs >> 28));
554+ const FLOAT_TYPE db = FLOAT_TYPE( d * 0.25 * (0.5 + (signs >> 28) ));
566555 const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
567- const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
568- const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
569- const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1));
570- const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
571-
572- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
573- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
556+ const uint sign = sign7 | (bitCount(sign7) << 7);
557+ const uvec2 grid = iq2xxs_grid[qs];
558+ const vec4 grid0 = vec4(unpack8(grid.x));
559+ const vec4 grid1 = vec4(unpack8(grid.y));
560+
561+ buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
562+ buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
563+ buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
564+ buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
565+ buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
566+ buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
567+ buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
568+ buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
574569#elif defined(DATA_A_IQ2_XS)
575570 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
576571 const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
577572
578- const uint ib = idx / 128 ; // 2 values per idx
579- const uint ib32 = (idx % 128 ) / 16; // 0..7
580- const uint ib8 = ( idx / 4) % 4; // 0..3
573+ const uint ib = idx / 32 ; // 8 values per idx
574+ const uint ib32 = (idx % 32 ) / 4; // 0..7
575+ const uint ib8 = idx % 4; // 0..3
581576
582577 const float d = float(data_a[ib].d);
583578 const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
584- const float db = d * 0.25 * (0.5 + scale);
579+ const FLOAT_TYPE db = FLOAT_TYPE( d * 0.25 * (0.5 + scale) );
585580 const uint qs = data_a[ib].qs[4 * ib32 + ib8];
586581 const uint sign7 = qs >> 9;
587- const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
588- const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
589- const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1));
590- const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
591-
592- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
593- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
582+ const uint sign = sign7 | (bitCount(sign7) << 7);
583+ const uvec2 grid = iq2xs_grid[qs & 511];
584+ const vec4 grid0 = vec4(unpack8(grid.x));
585+ const vec4 grid1 = vec4(unpack8(grid.y));
586+
587+ buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
588+ buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
589+ buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
590+ buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
591+ buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
592+ buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
593+ buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
594+ buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
594595#elif defined(DATA_A_IQ2_S)
595596 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
596597 const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
597598
598- const uint ib = idx / 128 ; // 2 values per idx
599- const uint ib8 = ( idx % 128) / 4 ; // 0..31
600- const uint ib32 = ib8 / 4; // 0..7
599+ const uint ib = idx / 32 ; // 8 values per idx
600+ const uint ib8 = idx % 32 ; // 0..31
601+ const uint ib32 = ib8 / 4; // 0..7
601602
602603 const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
603604 const uint qs = data_a[ib].qs[ib8];
604605 const uint qh = data_a[ib].qh[ib32];
605606 const uint qhshift = 2 * (ib8 % 4);
606- const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4)) ;
607+ const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8];
607608
608609 const float d = float(data_a[ib].d);
609- const float db = d * 0.25 * (0.5 + scale);
610- const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
611- const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
612- const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147
613-
614- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
615- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
610+ const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
611+ const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)];
612+ const vec4 grid0 = vec4(unpack8(grid.x));
613+ const vec4 grid1 = vec4(unpack8(grid.y));
614+
615+ buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
616+ buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
617+ buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
618+ buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
619+ buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
620+ buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
621+ buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
622+ buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
616623#elif defined(DATA_A_IQ3_XXS)
617624 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
618625 const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
619626
620- const uint ib = idx / 128 ; // 2 values per idx
621- const uint iqs = ( idx % 128) / 2 ; // 0..63
627+ const uint ib = idx / 64 ; // 4 values per idx
628+ const uint iqs = idx % 64 ; // 0..63
622629 const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
623630
624631 const float d = float(data_a[ib].d);
@@ -631,33 +638,36 @@ void main() {
631638 ));
632639 const float db = d * 0.5 * (0.5 + (signs >> 28));
633640 const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
634- const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
635- const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
636- const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
637- const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
638-
639- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
640- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
641+ const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2));
642+ const uint grid = iq3xxs_grid[qs];
643+ const vec4 v = db * vec4(unpack8(grid));
644+
645+ buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
646+ buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
647+ buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
648+ buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
641649#elif defined(DATA_A_IQ3_S)
642650 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
643651 const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
644652
645- const uint ib = idx / 128 ; // 2 values per idx
646- const uint iqs = ( idx % 128) / 2 ; // 0..63
653+ const uint ib = idx / 64 ; // 4 values per idx
654+ const uint iqs = idx % 64 ; // 0..63
647655 const uint iqh = iqs / 8;
648656
649657 const float d = float(data_a[ib].d);
650658 const uint qs = data_a[ib].qs[iqs];
651659 const uint qh = data_a[ib].qh[iqh];
652- const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4 )));
660+ const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2 )));
653661 const uint scale = data_a[ib].scales[iqs / 16];
654662 const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
655663 const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
656- const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2)) ;
657- const vec2 v = db * vec2(sign01) * vec2( unpack8(grid).xy); // vec4 used due to #12147
664+ const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
665+ const vec4 v = db * vec4( unpack8(grid));
658666
659- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
660- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
667+ buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
668+ buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
669+ buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
670+ buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
661671#elif defined(DATA_A_IQ4_XS)
662672 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
663673 const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
0 commit comments