Skip to content

Commit 05cd699

Browse files
ikawrakowIwan Kawrakow
andauthored
CUDA: faster IQ3_K, IQ3_KS, IQ3_K_R4 (#714)
* Use bperm trick for iq3_ks - 5% PP performance gain * Use bperm trick for iq3_k -> 5% PP performance gain * Use bperm trick for iq3_k -> 8% PP performance gain * Use bperm trick for iq3_k_r4 gemv -> ~5% faster * Use bperm trick for iq3_k gemv -> ~3% faster * Use bperm trick for iq3_k gemv -> 4.5% gain --------- Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 78de773 commit 05cd699

File tree

3 files changed

+88
-130
lines changed

3 files changed

+88
-130
lines changed

ggml/src/ggml-cuda/iqk_mmvq.cu

Lines changed: 38 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -776,28 +776,21 @@ __device__ __forceinline__ void vec_dot_iq3_k_r4_q8_1(
776776
//scales[1] = __vcmpeq4((scales_h[is] >> ib32) & 0x01010101, 0x01010101);
777777
//scales[0] = __vsub4(scales[0] ^ scales[1], scales[1]);
778778
const int8_t * s8 = (const int8_t *)scales;
779-
int2 val1;
780-
const int * q2 = (const int *)bq3->qs + 8*ib32 + 4*is;
781-
const int * qh = (const int *)bq3->qh + 4*ib32;
782-
int aux32[2];
783-
const uint8_t * aux8 = (const uint8_t *)aux32;
779+
const uint32_t * q2 = (const uint32_t *)bq3->qs + 8*ib32 + 4*is;
780+
const uint32_t * qh = (const uint32_t *)bq3->qh + 4*ib32;
784781
for (int i = 0; i < 4; ++i) {
785-
auto values1 = iq3nl_values + (((bq3->extra[i+4*is] >> ib32) & 1) << 3);
782+
uint32_t extra32 = uint32_t((bq3->extra[i+4*is] >> ib32) & 1) * 0x88888888;
783+
786784
int sumi1 = 0;
787-
int h = qh[i] >> 4*is;
788-
aux32[0] = ((q2[i] >> 0) & 0x03030303) | ((h << 2) & 0x04040404);
789-
aux32[1] = ((q2[i] >> 2) & 0x03030303) | ((h << 1) & 0x04040404);
790-
val1.x = int_from_table(aux8+0, (const uint8_t *)values1);
791-
val1.y = int_from_table(aux8+4, (const uint8_t *)values1);
792-
sumi1 = ggml_cuda_dp4a(val1.x, q8[0], ggml_cuda_dp4a(val1.y, q8[1], sumi1));
793-
aux32[0] = ((q2[i] >> 4) & 0x03030303) | ((h >> 0) & 0x04040404);
794-
aux32[1] = ((q2[i] >> 6) & 0x03030303) | ((h >> 1) & 0x04040404);
795-
val1.x = int_from_table(aux8+0, (const uint8_t *)values1);
796-
val1.y = int_from_table(aux8+4, (const uint8_t *)values1);
797-
sumi1 = ggml_cuda_dp4a(val1.x, q8[2], ggml_cuda_dp4a(val1.y, q8[3], sumi1));
785+
uint32_t h = qh[i] >> 4*is;
786+
uint32_t val1 = ((q2[i] >> 0) & 0x33333333) | extra32 | ((h << 2) & 0x04040404) | ((h << 4) & 0x40404040);
787+
uint32_t val2 = ((q2[i] >> 2) & 0x33333333) | extra32 | ((h << 1) & 0x04040404) | ((h << 3) & 0x40404040);
788+
int2 v1 = get_int_from_table_16(val1, iq3nl_values);
789+
int2 v2 = get_int_from_table_16(val2, iq3nl_values);
790+
sumi1 = ggml_cuda_dp4a(v1.x, q8[0], ggml_cuda_dp4a(v2.x, q8[1], sumi1));
791+
sumi1 = ggml_cuda_dp4a(v1.y, q8[2], ggml_cuda_dp4a(v2.y, q8[3], sumi1));
798792
const float d = __half2float(bq3->d[i]) * d8;
799793
result[i] += d * sumi1 * s8[i] * (s8[i+4] ? -1 : 1);
800-
//result[i] += d * sumi1 * s8[i];
801794
}
802795
}
803796

@@ -1021,41 +1014,32 @@ __device__ __forceinline__ void vec_dot_iq3_k_q8_1(
10211014
const uint16_t sh = bq3->scales_h >> (8*ib128 + il8/2);
10221015

10231016
const uint8_t extra = bq3->extra >> (8*ib128 + il8/2);
1024-
const uint16_t * values1 = iq3k_table + ((extra << 6) & 0x40);
1025-
const uint16_t * values2 = iq3k_table + ((extra << 5) & 0x40);
1026-
const uint16_t * values3 = iq3k_table + ((extra << 4) & 0x40);
1027-
const uint16_t * values4 = iq3k_table + ((extra << 3) & 0x40);
1017+
uint32_t extra32 = uint32_t(extra) * 0x01010101;
1018+
uint32_t extra32_1 = ((extra32 << 3) & 0x08080808) | ((extra32 << 5) & 0x80808080);
1019+
uint32_t extra32_2 = ((extra32 << 2) & 0x08080808) | ((extra32 << 4) & 0x80808080);
10281020

10291021
const int * q8;
10301022
int sumi[4] = {0, 0, 0, 0};
1031-
int v;
10321023
for (int i = 0; i < 2; ++i) {
10331024
uint32_t vl = ql[2*i+0] | (ql[2*i+1] << 16);
1034-
uint32_t vh = ((qh[2*i+0] | (qh[2*i+1] << 16)) << hshift) >> 2;
1025+
uint32_t vh = ((qh[2*i+0] | (qh[2*i+1] << 16)) << hshift);
1026+
1027+
uint32_t val1 = ((vl >> 0) & 0x33333333) | extra32_1 | ((vh >> 2) & 0x04040404) | ((vh >> 0) & 0x40404040);
1028+
uint32_t val2 = ((vl >> 2) & 0x33333333) | extra32_2 | ((vh >> 3) & 0x04040404) | ((vh >> 1) & 0x40404040);
1029+
int2 v1 = get_int_from_table_16(val1, iq3nl_values);
1030+
int2 v2 = get_int_from_table_16(val2, iq3nl_values);
10351031

10361032
q8 = (const int *)bq8_1[4*ib128+0].qs + 2*il8;
1037-
aux32 = (vl & 0x03030303) | (vh & 0x04040404);
1038-
v = int_from_table_2(aux8, values1);
1039-
sumi[0] = ggml_cuda_dp4a(v, q8[i], sumi[0]);
1040-
vl >>= 2; vh >>= 1;
1033+
sumi[0] = ggml_cuda_dp4a(v1.x, q8[i], sumi[0]);
10411034

10421035
q8 += sizeof(block_q8_1)/4;
1043-
aux32 = (vl & 0x03030303) | (vh & 0x04040404);
1044-
v = int_from_table_2(aux8, values2);
1045-
sumi[1] = ggml_cuda_dp4a(v, q8[i], sumi[1]);
1046-
vl >>= 2; vh >>= 1;
1036+
sumi[1] = ggml_cuda_dp4a(v2.x, q8[i], sumi[1]);
10471037

10481038
q8 += sizeof(block_q8_1)/4;
1049-
aux32 = (vl & 0x03030303) | (vh & 0x04040404);
1050-
v = int_from_table_2(aux8, values3);
1051-
sumi[2] = ggml_cuda_dp4a(v, q8[i], sumi[2]);
1052-
vl >>= 2; vh >>= 1;
1039+
sumi[2] = ggml_cuda_dp4a(v1.y, q8[i], sumi[2]);
10531040

10541041
q8 += sizeof(block_q8_1)/4;
1055-
aux32 = (vl & 0x03030303) | (vh & 0x04040404);
1056-
v = int_from_table_2(aux8, values4);
1057-
sumi[3] = ggml_cuda_dp4a(v, q8[i], sumi[3]);
1058-
1042+
sumi[3] = ggml_cuda_dp4a(v2.y, q8[i], sumi[3]);
10591043
}
10601044
const float d = __half2float(bq3->d);
10611045
const uint16_t * sl16 = (const uint16_t *)bq3->scales_l + 2*ib128;
@@ -1127,50 +1111,37 @@ __device__ __forceinline__ void vec_dot_iq3_ks_q8_1(
11271111
const uint16_t * ql = (const uint16_t *)bq3->qs + 16*ib128 + 4*il8;
11281112
const uint16_t * qh = (const uint16_t *)bq3->qh + 4*il8;
11291113

1130-
int32_t aux32;
1131-
const uint8_t * aux8 = (const uint8_t *)&aux32;
1132-
11331114
uint16_t extra = bq3->extra >> 4*ib128;
1134-
uint16_t extra_v = extra >> 8;
1115+
uint32_t extra_v = uint32_t(extra >> 8) * 0x01010101;
11351116

1136-
const uint16_t * values1 = iq3k_table + ((extra_v << 6) & 0x40);
1137-
const uint16_t * values2 = iq3k_table + ((extra_v << 5) & 0x40);
1138-
const uint16_t * values3 = iq3k_table + ((extra_v << 4) & 0x40);
1139-
const uint16_t * values4 = iq3k_table + ((extra_v << 3) & 0x40);
1117+
uint32_t extra32_1 = ((extra_v << 3) & 0x08080808) | ((extra_v << 5) & 0x80808080);
1118+
uint32_t extra32_2 = ((extra_v << 2) & 0x08080808) | ((extra_v << 4) & 0x80808080);
11401119

11411120
const int * q8;
11421121
int sumi[4] = {0, 0, 0, 0};
1143-
int v;
11441122
for (int i = 0; i < 2; ++i) {
11451123
uint32_t vl = ql[2*i+0] | (ql[2*i+1] << 16);
1146-
uint32_t vh = ((qh[2*i+0] | (qh[2*i+1] << 16)) >> 4*ib128) << 2;
1124+
uint32_t vh = ((qh[2*i+0] | (qh[2*i+1] << 16)) >> 4*ib128);
1125+
1126+
uint32_t val1 = ((vl >> 0) & 0x33333333) | extra32_1 | ((vh << 2) & 0x04040404) | ((vh << 4) & 0x40404040);
1127+
uint32_t val2 = ((vl >> 2) & 0x33333333) | extra32_2 | ((vh << 1) & 0x04040404) | ((vh << 3) & 0x40404040);
1128+
int2 v1 = get_int_from_table_16(val1, iq3nl_values);
1129+
int2 v2 = get_int_from_table_16(val2, iq3nl_values);
11471130

11481131
q8 = (const int *)bq8_1[4*ib128+0].qs + 2*il8;
1149-
aux32 = (vl & 0x03030303) | (vh & 0x04040404);
1150-
v = int_from_table_2(aux8, values1);
1151-
sumi[0] = ggml_cuda_dp4a(v, q8[i], sumi[0]);
1152-
vl >>= 2; vh >>= 1;
1132+
sumi[0] = ggml_cuda_dp4a(v1.x, q8[i], sumi[0]);
11531133

11541134
q8 += sizeof(block_q8_1)/4;
1155-
aux32 = (vl & 0x03030303) | (vh & 0x04040404);
1156-
v = int_from_table_2(aux8, values2);
1157-
sumi[1] = ggml_cuda_dp4a(v, q8[i], sumi[1]);
1158-
vl >>= 2; vh >>= 1;
1135+
sumi[1] = ggml_cuda_dp4a(v2.x, q8[i], sumi[1]);
11591136

11601137
q8 += sizeof(block_q8_1)/4;
1161-
aux32 = (vl & 0x03030303) | (vh & 0x04040404);
1162-
v = int_from_table_2(aux8, values3);
1163-
sumi[2] = ggml_cuda_dp4a(v, q8[i], sumi[2]);
1164-
vl >>= 2; vh >>= 1;
1138+
sumi[2] = ggml_cuda_dp4a(v1.y, q8[i], sumi[2]);
11651139

11661140
q8 += sizeof(block_q8_1)/4;
1167-
aux32 = (vl & 0x03030303) | (vh & 0x04040404);
1168-
v = int_from_table_2(aux8, values4);
1169-
sumi[3] = ggml_cuda_dp4a(v, q8[i], sumi[3]);
1170-
1141+
sumi[3] = ggml_cuda_dp4a(v2.y, q8[i], sumi[3]);
11711142
}
11721143
const uint16_t * sl16 = (const uint16_t *)bq3->scales;
1173-
aux32 = __vsub4(((sl16[0] | (sl16[1] << 16)) >> 4*ib128) & 0x0f0f0f0f, 0x10101010);
1144+
int32_t aux32 = __vsub4(((sl16[0] | (sl16[1] << 16)) >> 4*ib128) & 0x0f0f0f0f, 0x10101010);
11741145
const int8_t * a8 = (const int8_t *)&aux32;
11751146
*result += d * (__low2float(bq8_1[4*ib128+0].ds) * (a8[0] + ((extra << 4) & 0x10)) * sumi[0] +
11761147
__low2float(bq8_1[4*ib128+1].ds) * (a8[1] + ((extra << 3) & 0x10)) * sumi[1] +

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 36 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2696,8 +2696,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
26962696
constexpr int qstep = 8;
26972697
const int kqsx = threadIdx.x % qstep;
26982698

2699-
uint32_t aux32[4];
2700-
const uint8_t * aux8 = (const uint8_t *)aux32;
27012699
#pragma unroll
27022700
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) {
27032701
int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep;
@@ -2711,35 +2709,33 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
27112709
const float d = bxi->d;
27122710

27132711
uint16_t extra = bxi->extra >> (kqsx/4);
2712+
uint32_t extra32[2] = { uint32_t(extra & 0xff) * 0x01010101, uint32_t(extra >> 8) * 0x01010101 };
27142713
int qh = get_int_b2(bxi->qh, kqsx);
27152714

27162715
#pragma unroll
27172716
for (int l = 0; l < qstep/4; ++l) {
27182717

2718+
//extra << 3, extra << 1, extra >> 1, extra >> 3
27192719
const int ql = get_int_b2(bxi->qs, kqsx + qstep*l);
2720-
aux32[0] = ((ql >> 0) & 0x03030303) | ((qh << 2) & 0x04040404);
2721-
aux32[1] = ((ql >> 2) & 0x03030303) | ((qh << 1) & 0x04040404);
2722-
aux32[2] = ((ql >> 4) & 0x03030303) | ((qh >> 0) & 0x04040404);
2723-
aux32[3] = ((ql >> 6) & 0x03030303) | ((qh >> 1) & 0x04040404);
2724-
2725-
const int val0 = int_from_table_2(aux8+ 0, iq3k_table + ((extra << 6) & 0x40));
2726-
const int val1 = int_from_table_2(aux8+ 4, iq3k_table + ((extra << 4) & 0x40));
2727-
const int val2 = int_from_table_2(aux8+ 8, iq3k_table + ((extra << 2) & 0x40));
2728-
const int val3 = int_from_table_2(aux8+12, iq3k_table + ((extra << 0) & 0x40));
2720+
uint32_t val1 = ((ql >> 0) & 0x33333333) | ((extra32[l] << 3) & 0x88888888)
2721+
| ((qh << 2) & 0x04040404) | ((qh << 4) & 0x40404040);
2722+
uint32_t val2 = ((ql >> 2) & 0x33333333) | ((extra32[l] << 1) & 0x88888888)
2723+
| ((qh << 1) & 0x04040404) | ((qh << 3) & 0x40404040);
2724+
int2 v1 = get_int_from_table_16(val1, iq3nl_values);
2725+
int2 v2 = get_int_from_table_16(val2, iq3nl_values);
27292726

2730-
extra >>= 8;
2731-
qh >>= 4;
2727+
qh >>= 4;
27322728

27332729
#ifdef INT8_MMA_AVAILABLE
2734-
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 0] = val0;
2735-
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 8] = val1;
2736-
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 16] = val2;
2737-
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 24] = val3;
2730+
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 0] = v1.x;
2731+
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 8] = v2.x;
2732+
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 16] = v1.y;
2733+
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 24] = v2.y;
27382734
#else
2739-
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = val0;
2740-
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = val1;
2741-
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = val2;
2742-
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = val3;
2735+
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = v1.x;
2736+
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = v2.x;
2737+
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = v1.y;
2738+
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = v2.y;
27432739
#endif // INT8_MMA_AVAILABLE
27442740
}
27452741

@@ -2769,8 +2765,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
27692765
constexpr int qstep = 8;
27702766
const int kqsx = threadIdx.x % qstep;
27712767

2772-
uint32_t aux32[4];
2773-
const uint8_t * aux8 = (const uint8_t *)aux32;
27742768
#pragma unroll
27752769
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) {
27762770
int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep;
@@ -2783,36 +2777,35 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
27832777
const float d = __half2float(dptr[0]);
27842778
const block_iq3_ks * bxi = (const block_iq3_ks *)(dptr + 1) + kbx0;
27852779

2786-
uint16_t extra = bxi->extra >> 8;
2780+
//uint16_t extra = bxi->extra >> 8;
27872781
int qh = get_int_b2(bxi->qh, kqsx);
27882782

2783+
uint32_t extra32 = uint32_t(bxi->extra >> 8) * 0x01010101;
2784+
27892785
#pragma unroll
27902786
for (int l = 0; l < qstep/4; ++l) {
27912787

27922788
const int ql = get_int_b2(bxi->qs, kqsx + qstep*l);
2793-
aux32[0] = ((ql >> 0) & 0x03030303) | ((qh << 2) & 0x04040404);
2794-
aux32[1] = ((ql >> 2) & 0x03030303) | ((qh << 1) & 0x04040404);
2795-
aux32[2] = ((ql >> 4) & 0x03030303) | ((qh >> 0) & 0x04040404);
2796-
aux32[3] = ((ql >> 6) & 0x03030303) | ((qh >> 1) & 0x04040404);
2797-
2798-
const int val0 = int_from_table_2(aux8+ 0, iq3k_table + ((extra << 6) & 0x40));
2799-
const int val1 = int_from_table_2(aux8+ 4, iq3k_table + ((extra << 5) & 0x40));
2800-
const int val2 = int_from_table_2(aux8+ 8, iq3k_table + ((extra << 4) & 0x40));
2801-
const int val3 = int_from_table_2(aux8+12, iq3k_table + ((extra << 3) & 0x40));
2789+
uint32_t val1 = ((ql >> 0) & 0x33333333) | ((qh << 2) & 0x04040404) | ((extra32 << 3) & 0x08080808)
2790+
| ((qh << 4) & 0x40404040) | ((extra32 << 5) & 0x80808080);
2791+
uint32_t val2 = ((ql >> 2) & 0x33333333) | ((qh << 1) & 0x04040404) | ((extra32 << 2) & 0x08080808)
2792+
| ((qh << 3) & 0x40404040) | ((extra32 << 4) & 0x80808080);
2793+
int2 v1 = get_int_from_table_16(val1, iq3nl_values);
2794+
int2 v2 = get_int_from_table_16(val2, iq3nl_values);
28022795

2803-
extra >>= 4;
2804-
qh >>= 4;
2796+
extra32 >>= 4;
2797+
qh >>= 4;
28052798

28062799
#ifdef INT8_MMA_AVAILABLE
2807-
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 0] = val0;
2808-
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 8] = val1;
2809-
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 16] = val2;
2810-
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 24] = val3;
2800+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 0] = v1.x;
2801+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 8] = v2.x;
2802+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 16] = v1.y;
2803+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 24] = v2.y;
28112804
#else
2812-
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = val0;
2813-
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = val1;
2814-
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = val2;
2815-
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = val3;
2805+
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = v1.x;
2806+
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = v2.x;
2807+
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = v1.y;
2808+
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = v2.y;
28162809
#endif // INT8_MMA_AVAILABLE
28172810
}
28182811

ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_k_r4.cu

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1616

1717
const int kqsx = threadIdx.x/4; // 0...7 -> block of 32
1818

19-
uint32_t aux32[4];
20-
const uint8_t * aux8 = (const uint8_t *)aux32;
2119
#pragma unroll
2220
for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
2321
int i = i0 + 4*threadIdx.y + threadIdx.x%4;
@@ -37,29 +35,25 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
3735
#pragma unroll
3836
for (int l = 0; l < 2; ++l) {
3937

40-
auto values_l = iq3k_table + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 6);
38+
//auto values_l = iq3k_table + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 6);
39+
uint32_t extra32 = uint32_t((bxi->extra[ir+4*l] >> kqsx) & 1) * 0x88888888;
4140

4241
const int ql = get_int_b4(bxi->qs, 8*kqsx + ir + 4*l);
43-
aux32[0] = ((ql >> 0) & 0x03030303) | ((qh << 2) & 0x04040404);
44-
aux32[1] = ((ql >> 2) & 0x03030303) | ((qh << 1) & 0x04040404);
45-
aux32[2] = ((ql >> 4) & 0x03030303) | ((qh >> 0) & 0x04040404);
46-
aux32[3] = ((ql >> 6) & 0x03030303) | ((qh >> 1) & 0x04040404);
47-
48-
int val0 = int_from_table_2(aux8+ 0, values_l);
49-
int val1 = int_from_table_2(aux8+ 4, values_l);
50-
int val2 = int_from_table_2(aux8+ 8, values_l);
51-
int val3 = int_from_table_2(aux8+12, values_l);
42+
uint32_t val1 = ((ql >> 0) & 0x33333333) | extra32 | ((qh << 2) & 0x04040404) | ((qh << 4) & 0x40404040);
43+
uint32_t val2 = ((ql >> 2) & 0x33333333) | extra32 | ((qh << 1) & 0x04040404) | ((qh << 3) & 0x40404040);
44+
int2 v1 = get_int_from_table_16(val1, iq3nl_values);
45+
int2 v2 = get_int_from_table_16(val2, iq3nl_values);
5246

5347
#ifdef INT8_MMA_AVAILABLE
54-
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = val0;
55-
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = val1;
56-
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = val2;
57-
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = val3;
48+
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = v1.x;
49+
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = v2.x;
50+
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = v1.y;
51+
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = v2.y;
5852
#else
59-
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = val0;
60-
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = val1;
61-
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = val2;
62-
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = val3;
53+
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = v1.x;
54+
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = v2.x;
55+
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = v1.y;
56+
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = v2.y;
6357
#endif // INT8_MMA_AVAILABLE
6458

6559
qh >>= 4;

0 commit comments

Comments
 (0)