diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index c596f00ec..744c16378 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -776,28 +776,21 @@ __device__ __forceinline__ void vec_dot_iq3_k_r4_q8_1( //scales[1] = __vcmpeq4((scales_h[is] >> ib32) & 0x01010101, 0x01010101); //scales[0] = __vsub4(scales[0] ^ scales[1], scales[1]); const int8_t * s8 = (const int8_t *)scales; - int2 val1; - const int * q2 = (const int *)bq3->qs + 8*ib32 + 4*is; - const int * qh = (const int *)bq3->qh + 4*ib32; - int aux32[2]; - const uint8_t * aux8 = (const uint8_t *)aux32; + const uint32_t * q2 = (const uint32_t *)bq3->qs + 8*ib32 + 4*is; + const uint32_t * qh = (const uint32_t *)bq3->qh + 4*ib32; for (int i = 0; i < 4; ++i) { - auto values1 = iq3nl_values + (((bq3->extra[i+4*is] >> ib32) & 1) << 3); + uint32_t extra32 = uint32_t((bq3->extra[i+4*is] >> ib32) & 1) * 0x88888888; + int sumi1 = 0; - int h = qh[i] >> 4*is; - aux32[0] = ((q2[i] >> 0) & 0x03030303) | ((h << 2) & 0x04040404); - aux32[1] = ((q2[i] >> 2) & 0x03030303) | ((h << 1) & 0x04040404); - val1.x = int_from_table(aux8+0, (const uint8_t *)values1); - val1.y = int_from_table(aux8+4, (const uint8_t *)values1); - sumi1 = ggml_cuda_dp4a(val1.x, q8[0], ggml_cuda_dp4a(val1.y, q8[1], sumi1)); - aux32[0] = ((q2[i] >> 4) & 0x03030303) | ((h >> 0) & 0x04040404); - aux32[1] = ((q2[i] >> 6) & 0x03030303) | ((h >> 1) & 0x04040404); - val1.x = int_from_table(aux8+0, (const uint8_t *)values1); - val1.y = int_from_table(aux8+4, (const uint8_t *)values1); - sumi1 = ggml_cuda_dp4a(val1.x, q8[2], ggml_cuda_dp4a(val1.y, q8[3], sumi1)); + uint32_t h = qh[i] >> 4*is; + uint32_t val1 = ((q2[i] >> 0) & 0x33333333) | extra32 | ((h << 2) & 0x04040404) | ((h << 4) & 0x40404040); + uint32_t val2 = ((q2[i] >> 2) & 0x33333333) | extra32 | ((h << 1) & 0x04040404) | ((h << 3) & 0x40404040); + int2 v1 = get_int_from_table_16(val1, iq3nl_values); + int2 v2 = get_int_from_table_16(val2, iq3nl_values); + sumi1 = ggml_cuda_dp4a(v1.x, q8[0], ggml_cuda_dp4a(v2.x, q8[1], sumi1)); + sumi1 = ggml_cuda_dp4a(v1.y, q8[2], ggml_cuda_dp4a(v2.y, q8[3], sumi1)); const float d = __half2float(bq3->d[i]) * d8; result[i] += d * sumi1 * s8[i] * (s8[i+4] ? -1 : 1); - //result[i] += d * sumi1 * s8[i]; } } @@ -1021,41 +1014,32 @@ __device__ __forceinline__ void vec_dot_iq3_k_q8_1( const uint16_t sh = bq3->scales_h >> (8*ib128 + il8/2); const uint8_t extra = bq3->extra >> (8*ib128 + il8/2); - const uint16_t * values1 = iq3k_table + ((extra << 6) & 0x40); - const uint16_t * values2 = iq3k_table + ((extra << 5) & 0x40); - const uint16_t * values3 = iq3k_table + ((extra << 4) & 0x40); - const uint16_t * values4 = iq3k_table + ((extra << 3) & 0x40); + uint32_t extra32 = uint32_t(extra) * 0x01010101; + uint32_t extra32_1 = ((extra32 << 3) & 0x08080808) | ((extra32 << 5) & 0x80808080); + uint32_t extra32_2 = ((extra32 << 2) & 0x08080808) | ((extra32 << 4) & 0x80808080); const int * q8; int sumi[4] = {0, 0, 0, 0}; - int v; for (int i = 0; i < 2; ++i) { uint32_t vl = ql[2*i+0] | (ql[2*i+1] << 16); - uint32_t vh = ((qh[2*i+0] | (qh[2*i+1] << 16)) << hshift) >> 2; + uint32_t vh = ((qh[2*i+0] | (qh[2*i+1] << 16)) << hshift); + + uint32_t val1 = ((vl >> 0) & 0x33333333) | extra32_1 | ((vh >> 2) & 0x04040404) | ((vh >> 0) & 0x40404040); + uint32_t val2 = ((vl >> 2) & 0x33333333) | extra32_2 | ((vh >> 3) & 0x04040404) | ((vh >> 1) & 0x40404040); + int2 v1 = get_int_from_table_16(val1, iq3nl_values); + int2 v2 = get_int_from_table_16(val2, iq3nl_values); q8 = (const int *)bq8_1[4*ib128+0].qs + 2*il8; - aux32 = (vl & 0x03030303) | (vh & 0x04040404); - v = int_from_table_2(aux8, values1); - sumi[0] = ggml_cuda_dp4a(v, q8[i], sumi[0]); - vl >>= 2; vh >>= 1; + sumi[0] = ggml_cuda_dp4a(v1.x, q8[i], sumi[0]); q8 += sizeof(block_q8_1)/4; - aux32 = (vl & 0x03030303) | (vh & 0x04040404); - v = int_from_table_2(aux8, values2); - sumi[1] = ggml_cuda_dp4a(v, q8[i], sumi[1]); - vl >>= 2; vh >>= 1; + sumi[1] = ggml_cuda_dp4a(v2.x, q8[i], sumi[1]); q8 += sizeof(block_q8_1)/4; - aux32 = (vl & 0x03030303) | (vh & 0x04040404); - v = int_from_table_2(aux8, values3); - sumi[2] = ggml_cuda_dp4a(v, q8[i], sumi[2]); - vl >>= 2; vh >>= 1; + sumi[2] = ggml_cuda_dp4a(v1.y, q8[i], sumi[2]); q8 += sizeof(block_q8_1)/4; - aux32 = (vl & 0x03030303) | (vh & 0x04040404); - v = int_from_table_2(aux8, values4); - sumi[3] = ggml_cuda_dp4a(v, q8[i], sumi[3]); - + sumi[3] = ggml_cuda_dp4a(v2.y, q8[i], sumi[3]); } const float d = __half2float(bq3->d); const uint16_t * sl16 = (const uint16_t *)bq3->scales_l + 2*ib128; @@ -1127,50 +1111,37 @@ __device__ __forceinline__ void vec_dot_iq3_ks_q8_1( const uint16_t * ql = (const uint16_t *)bq3->qs + 16*ib128 + 4*il8; const uint16_t * qh = (const uint16_t *)bq3->qh + 4*il8; - int32_t aux32; - const uint8_t * aux8 = (const uint8_t *)&aux32; - uint16_t extra = bq3->extra >> 4*ib128; - uint16_t extra_v = extra >> 8; + uint32_t extra_v = uint32_t(extra >> 8) * 0x01010101; - const uint16_t * values1 = iq3k_table + ((extra_v << 6) & 0x40); - const uint16_t * values2 = iq3k_table + ((extra_v << 5) & 0x40); - const uint16_t * values3 = iq3k_table + ((extra_v << 4) & 0x40); - const uint16_t * values4 = iq3k_table + ((extra_v << 3) & 0x40); + uint32_t extra32_1 = ((extra_v << 3) & 0x08080808) | ((extra_v << 5) & 0x80808080); + uint32_t extra32_2 = ((extra_v << 2) & 0x08080808) | ((extra_v << 4) & 0x80808080); const int * q8; int sumi[4] = {0, 0, 0, 0}; - int v; for (int i = 0; i < 2; ++i) { uint32_t vl = ql[2*i+0] | (ql[2*i+1] << 16); - uint32_t vh = ((qh[2*i+0] | (qh[2*i+1] << 16)) >> 4*ib128) << 2; + uint32_t vh = ((qh[2*i+0] | (qh[2*i+1] << 16)) >> 4*ib128); + + uint32_t val1 = ((vl >> 0) & 0x33333333) | extra32_1 | ((vh << 2) & 0x04040404) | ((vh << 4) & 0x40404040); + uint32_t val2 = ((vl >> 2) & 0x33333333) | extra32_2 | ((vh << 1) & 0x04040404) | ((vh << 3) & 0x40404040); + int2 v1 = get_int_from_table_16(val1, iq3nl_values); + int2 v2 = get_int_from_table_16(val2, iq3nl_values); q8 = (const int *)bq8_1[4*ib128+0].qs + 2*il8; - aux32 = (vl & 0x03030303) | (vh & 0x04040404); - v = int_from_table_2(aux8, values1); - sumi[0] = ggml_cuda_dp4a(v, q8[i], sumi[0]); - vl >>= 2; vh >>= 1; + sumi[0] = ggml_cuda_dp4a(v1.x, q8[i], sumi[0]); q8 += sizeof(block_q8_1)/4; - aux32 = (vl & 0x03030303) | (vh & 0x04040404); - v = int_from_table_2(aux8, values2); - sumi[1] = ggml_cuda_dp4a(v, q8[i], sumi[1]); - vl >>= 2; vh >>= 1; + sumi[1] = ggml_cuda_dp4a(v2.x, q8[i], sumi[1]); q8 += sizeof(block_q8_1)/4; - aux32 = (vl & 0x03030303) | (vh & 0x04040404); - v = int_from_table_2(aux8, values3); - sumi[2] = ggml_cuda_dp4a(v, q8[i], sumi[2]); - vl >>= 2; vh >>= 1; + sumi[2] = ggml_cuda_dp4a(v1.y, q8[i], sumi[2]); q8 += sizeof(block_q8_1)/4; - aux32 = (vl & 0x03030303) | (vh & 0x04040404); - v = int_from_table_2(aux8, values4); - sumi[3] = ggml_cuda_dp4a(v, q8[i], sumi[3]); - + sumi[3] = ggml_cuda_dp4a(v2.y, q8[i], sumi[3]); } const uint16_t * sl16 = (const uint16_t *)bq3->scales; - aux32 = __vsub4(((sl16[0] | (sl16[1] << 16)) >> 4*ib128) & 0x0f0f0f0f, 0x10101010); + int32_t aux32 = __vsub4(((sl16[0] | (sl16[1] << 16)) >> 4*ib128) & 0x0f0f0f0f, 0x10101010); const int8_t * a8 = (const int8_t *)&aux32; *result += d * (__low2float(bq8_1[4*ib128+0].ds) * (a8[0] + ((extra << 4) & 0x10)) * sumi[0] + __low2float(bq8_1[4*ib128+1].ds) * (a8[1] + ((extra << 3) & 0x10)) * sumi[1] + diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index b877545c6..304942eb0 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -2696,8 +2696,6 @@ template static __device__ __forceinlin constexpr int qstep = 8; const int kqsx = threadIdx.x % qstep; - uint32_t aux32[4]; - const uint8_t * aux8 = (const uint8_t *)aux32; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) { int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep; @@ -2711,35 +2709,33 @@ template static __device__ __forceinlin const float d = bxi->d; uint16_t extra = bxi->extra >> (kqsx/4); + uint32_t extra32[2] = { uint32_t(extra & 0xff) * 0x01010101, uint32_t(extra >> 8) * 0x01010101 }; int qh = get_int_b2(bxi->qh, kqsx); #pragma unroll for (int l = 0; l < qstep/4; ++l) { + //extra << 3, extra << 1, extra >> 1, extra >> 3 const int ql = get_int_b2(bxi->qs, kqsx + qstep*l); - aux32[0] = ((ql >> 0) & 0x03030303) | ((qh << 2) & 0x04040404); - aux32[1] = ((ql >> 2) & 0x03030303) | ((qh << 1) & 0x04040404); - aux32[2] = ((ql >> 4) & 0x03030303) | ((qh >> 0) & 0x04040404); - aux32[3] = ((ql >> 6) & 0x03030303) | ((qh >> 1) & 0x04040404); - - const int val0 = int_from_table_2(aux8+ 0, iq3k_table + ((extra << 6) & 0x40)); - const int val1 = int_from_table_2(aux8+ 4, iq3k_table + ((extra << 4) & 0x40)); - const int val2 = int_from_table_2(aux8+ 8, iq3k_table + ((extra << 2) & 0x40)); - const int val3 = int_from_table_2(aux8+12, iq3k_table + ((extra << 0) & 0x40)); + uint32_t val1 = ((ql >> 0) & 0x33333333) | ((extra32[l] << 3) & 0x88888888) + | ((qh << 2) & 0x04040404) | ((qh << 4) & 0x40404040); + uint32_t val2 = ((ql >> 2) & 0x33333333) | ((extra32[l] << 1) & 0x88888888) + | ((qh << 1) & 0x04040404) | ((qh << 3) & 0x40404040); + int2 v1 = get_int_from_table_16(val1, iq3nl_values); + int2 v2 = get_int_from_table_16(val2, iq3nl_values); - extra >>= 8; - qh >>= 4; + qh >>= 4; #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 0] = val0; - x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 8] = val1; - x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 16] = val2; - x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 24] = val3; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 0] = v1.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 8] = v2.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 16] = v1.y; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 24] = v2.y; #else - x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = val0; - x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = val1; - x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = val2; - x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = val3; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = v1.x; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = v2.x; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = v1.y; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = v2.y; #endif // INT8_MMA_AVAILABLE } @@ -2769,8 +2765,6 @@ template static __device__ __forceinlin constexpr int qstep = 8; const int kqsx = threadIdx.x % qstep; - uint32_t aux32[4]; - const uint8_t * aux8 = (const uint8_t *)aux32; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) { int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep; @@ -2783,36 +2777,35 @@ template static __device__ __forceinlin const float d = __half2float(dptr[0]); const block_iq3_ks * bxi = (const block_iq3_ks *)(dptr + 1) + kbx0; - uint16_t extra = bxi->extra >> 8; + //uint16_t extra = bxi->extra >> 8; int qh = get_int_b2(bxi->qh, kqsx); + uint32_t extra32 = uint32_t(bxi->extra >> 8) * 0x01010101; + #pragma unroll for (int l = 0; l < qstep/4; ++l) { const int ql = get_int_b2(bxi->qs, kqsx + qstep*l); - aux32[0] = ((ql >> 0) & 0x03030303) | ((qh << 2) & 0x04040404); - aux32[1] = ((ql >> 2) & 0x03030303) | ((qh << 1) & 0x04040404); - aux32[2] = ((ql >> 4) & 0x03030303) | ((qh >> 0) & 0x04040404); - aux32[3] = ((ql >> 6) & 0x03030303) | ((qh >> 1) & 0x04040404); - - const int val0 = int_from_table_2(aux8+ 0, iq3k_table + ((extra << 6) & 0x40)); - const int val1 = int_from_table_2(aux8+ 4, iq3k_table + ((extra << 5) & 0x40)); - const int val2 = int_from_table_2(aux8+ 8, iq3k_table + ((extra << 4) & 0x40)); - const int val3 = int_from_table_2(aux8+12, iq3k_table + ((extra << 3) & 0x40)); + uint32_t val1 = ((ql >> 0) & 0x33333333) | ((qh << 2) & 0x04040404) | ((extra32 << 3) & 0x08080808) + | ((qh << 4) & 0x40404040) | ((extra32 << 5) & 0x80808080); + uint32_t val2 = ((ql >> 2) & 0x33333333) | ((qh << 1) & 0x04040404) | ((extra32 << 2) & 0x08080808) + | ((qh << 3) & 0x40404040) | ((extra32 << 4) & 0x80808080); + int2 v1 = get_int_from_table_16(val1, iq3nl_values); + int2 v2 = get_int_from_table_16(val2, iq3nl_values); - extra >>= 4; - qh >>= 4; + extra32 >>= 4; + qh >>= 4; #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 0] = val0; - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 8] = val1; - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 16] = val2; - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 24] = val3; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 0] = v1.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 8] = v2.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 16] = v1.y; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 24] = v2.y; #else - x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = val0; - x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = val1; - x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = val2; - x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = val3; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = v1.x; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = v2.x; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = v1.y; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = v2.y; #endif // INT8_MMA_AVAILABLE } diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_k_r4.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_k_r4.cu index a588969fb..b828055c9 100644 --- a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_k_r4.cu +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_k_r4.cu @@ -16,8 +16,6 @@ template static __device__ __forceinlin const int kqsx = threadIdx.x/4; // 0...7 -> block of 32 - uint32_t aux32[4]; - const uint8_t * aux8 = (const uint8_t *)aux32; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) { int i = i0 + 4*threadIdx.y + threadIdx.x%4; @@ -37,29 +35,25 @@ template static __device__ __forceinlin #pragma unroll for (int l = 0; l < 2; ++l) { - auto values_l = iq3k_table + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 6); + //auto values_l = iq3k_table + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 6); + uint32_t extra32 = uint32_t((bxi->extra[ir+4*l] >> kqsx) & 1) * 0x88888888; const int ql = get_int_b4(bxi->qs, 8*kqsx + ir + 4*l); - aux32[0] = ((ql >> 0) & 0x03030303) | ((qh << 2) & 0x04040404); - aux32[1] = ((ql >> 2) & 0x03030303) | ((qh << 1) & 0x04040404); - aux32[2] = ((ql >> 4) & 0x03030303) | ((qh >> 0) & 0x04040404); - aux32[3] = ((ql >> 6) & 0x03030303) | ((qh >> 1) & 0x04040404); - - int val0 = int_from_table_2(aux8+ 0, values_l); - int val1 = int_from_table_2(aux8+ 4, values_l); - int val2 = int_from_table_2(aux8+ 8, values_l); - int val3 = int_from_table_2(aux8+12, values_l); + uint32_t val1 = ((ql >> 0) & 0x33333333) | extra32 | ((qh << 2) & 0x04040404) | ((qh << 4) & 0x40404040); + uint32_t val2 = ((ql >> 2) & 0x33333333) | extra32 | ((qh << 1) & 0x04040404) | ((qh << 3) & 0x40404040); + int2 v1 = get_int_from_table_16(val1, iq3nl_values); + int2 v2 = get_int_from_table_16(val2, iq3nl_values); #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = val0; - x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = val1; - x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = val2; - x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = val3; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = v1.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = v2.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = v1.y; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = v2.y; #else - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = val0; - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = val1; - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = val2; - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = val3; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = v1.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = v2.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = v1.y; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = v2.y; #endif // INT8_MMA_AVAILABLE qh >>= 4;