Skip to content

Commit ad07a3b

Browse files
author
ochafik
committed
Merge remote-tracking branch 'origin/master' into tool-diffs
2 parents c879a57 + c753d7b commit ad07a3b

File tree

12 files changed

+295
-34
lines changed

12 files changed

+295
-34
lines changed

common/chat.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,15 @@ std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const comm
116116
return diffs;
117117
}
118118

119+
// Formats `now` as local calendar time using `format`, a strftime-style pattern interpreted by std::put_time.
//
// Returns the formatted text, or an empty string when the time point cannot be converted to local time.
static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) {
    const std::time_t time = std::chrono::system_clock::to_time_t(now);

    // std::localtime() returns a pointer to shared static storage (not thread-safe) and may return nullptr,
    // so use the re-entrant variants that fill a caller-owned struct and report failure explicitly.
    std::tm local_time{};
#if defined(_WIN32)
    if (localtime_s(&local_time, &time) != 0) {
        return std::string();
    }
#else
    if (localtime_r(&time, &local_time) == nullptr) {
        return std::string();
    }
#endif

    std::ostringstream ss;
    ss << std::put_time(&local_time, format.c_str());
    return ss.str();
}
127+
119128
typedef minja::chat_template common_chat_template;
120129

121130
struct common_chat_templates {
@@ -1381,7 +1390,6 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
13811390
common_chat_params data;
13821391

13831392
if (!inputs.tools.is_null()) {
1384-
json tools = inputs.tools.is_null() ? inputs.tools : json::array();
13851393
std::string python_code_argument_name;
13861394
auto has_raw_python = false;
13871395

ggml/src/ggml-cpu/ggml-cpu-quants.c

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8519,7 +8519,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
85198519

85208520
void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
85218521
assert(n % QK_K == 0);
8522+
#ifdef __ARM_FEATURE_MATMUL_INT8
8523+
assert((nrc == 2) || (nrc == 1));
8524+
#else
85228525
assert(nrc == 1);
8526+
#endif
85238527
UNUSED(nrc);
85248528
UNUSED(bx);
85258529
UNUSED(by);
@@ -8530,6 +8534,197 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
85308534

85318535
const int nb = n / QK_K;
85328536

8537+
#if defined(__ARM_FEATURE_MATMUL_INT8)
8538+
if (nrc == 2) {
8539+
const block_q6_K * GGML_RESTRICT x0 = x;
8540+
const block_q6_K * GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx);
8541+
const block_q8_K * GGML_RESTRICT y0 = y;
8542+
const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by);
8543+
8544+
float32x4_t vfsum = vdupq_n_f32(0.0f);
8545+
8546+
for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) {
8547+
const uint8_t * GGML_RESTRICT ql0 = x0->ql;
8548+
const uint8_t * GGML_RESTRICT ql1 = x1->ql;
8549+
const uint8_t * GGML_RESTRICT qh0 = x0->qh;
8550+
const uint8_t * GGML_RESTRICT qh1 = x1->qh;
8551+
const int8_t * GGML_RESTRICT qy0 = y0->qs;
8552+
const int8_t * GGML_RESTRICT qy1 = y1->qs;
8553+
8554+
const uint8x16_t mone = vdupq_n_u8(0x30);
8555+
const uint8x16_t m4b = vdupq_n_u8(0x0f);
8556+
8557+
int32x4_t visum = vdupq_n_s32(0);
8558+
8559+
// process 8 blocks per iteration, 16 blocks in total
8560+
for (int j = 0; j < 2; ++j, qh0 += 32, ql0 += 64, qh1 += 32, ql1 += 64) {
8561+
int8x16_t vx0[8], vx1[8];
8562+
8563+
// de-quantize vx0[8]
8564+
{
8565+
const uint8x16x2_t qh_bits = vld1q_u8_x2(qh0);
8566+
const uint8x16x4_t ql_bits = vld1q_u8_x4(ql0);
8567+
8568+
uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
8569+
uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
8570+
uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
8571+
uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
8572+
8573+
vx0[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
8574+
vx0[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
8575+
vx0[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
8576+
vx0[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
8577+
8578+
q6h_0 = vandq_u8(mone, qh_bits.val[0]);
8579+
q6h_1 = vandq_u8(mone, qh_bits.val[1]);
8580+
q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
8581+
q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
8582+
8583+
vx0[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
8584+
vx0[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
8585+
vx0[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
8586+
vx0[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
8587+
}
8588+
8589+
// de-quantize vx1[8]
8590+
{
8591+
const uint8x16x2_t qh_bits = vld1q_u8_x2(qh1);
8592+
const uint8x16x4_t ql_bits = vld1q_u8_x4(ql1);
8593+
8594+
uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
8595+
uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
8596+
uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
8597+
uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
8598+
8599+
vx1[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
8600+
vx1[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
8601+
vx1[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
8602+
vx1[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
8603+
8604+
q6h_0 = vandq_u8(mone, qh_bits.val[0]);
8605+
q6h_1 = vandq_u8(mone, qh_bits.val[1]);
8606+
q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
8607+
q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
8608+
8609+
vx1[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
8610+
vx1[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
8611+
vx1[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
8612+
vx1[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
8613+
}
8614+
8615+
// process 16 elements (one block with same scale) per iteration
8616+
// - vx = concat(ql, qh) - 32
8617+
// - r1,r2,r3,r4 = smmla(vx, vy)
8618+
for (int k = 0; k < 8; ++k) {
8619+
const int blk = j * 8 + k;
8620+
8621+
const int8x16_t vy0 = vld1q_s8(qy0);
8622+
const int8x16_t vy1 = vld1q_s8(qy1);
8623+
qy0 += 16;
8624+
qy1 += 16;
8625+
8626+
const int32x4_t block_scale = {
8627+
x0->scales[blk],
8628+
x0->scales[blk],
8629+
x1->scales[blk],
8630+
x1->scales[blk],
8631+
};
8632+
8633+
// calculate four results at once with outer product
8634+
const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
8635+
const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
8636+
const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
8637+
const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
8638+
int32x4_t vr = vdupq_n_s32(0);
8639+
vr = vmmlaq_s32(vr, vx_l, vy_l);
8640+
vr = vmmlaq_s32(vr, vx_h, vy_h);
8641+
8642+
// apply block scale, will NOT overflow
8643+
// block_scale * sum_256(int6*int8) <= 2^(8+8+6+8) = 30 bits
8644+
visum = vmlaq_s32(visum, vr, block_scale);
8645+
}
8646+
}
8647+
8648+
// adjust bias, apply superblock scale
8649+
{
8650+
int32_t bias[4];
8651+
#ifdef __ARM_FEATURE_SVE
8652+
const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
8653+
const svbool_t pg8_8 = svptrue_pat_b8(SV_VL8);
8654+
const svint16_t y0_q8sums_0 = svld1_s16(pg16_8, y0->bsums);
8655+
const svint16_t y0_q8sums_1 = svld1_s16(pg16_8, y0->bsums + 8);
8656+
const svint16_t y1_q8sums_0 = svld1_s16(pg16_8, y1->bsums);
8657+
const svint16_t y1_q8sums_1 = svld1_s16(pg16_8, y1->bsums + 8);
8658+
const svint16_t x0_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x0->scales));
8659+
const svint16_t x0_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x0->scales + 8));
8660+
const svint16_t x1_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x1->scales));
8661+
const svint16_t x1_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x1->scales + 8));
8662+
const svint64_t zero = svdup_n_s64(0);
8663+
bias[0] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x0_q6scales_0),
8664+
svdot_s64(zero, y0_q8sums_1, x0_q6scales_1)));
8665+
bias[1] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x0_q6scales_0),
8666+
svdot_s64(zero, y1_q8sums_1, x0_q6scales_1)));
8667+
bias[2] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x1_q6scales_0),
8668+
svdot_s64(zero, y0_q8sums_1, x1_q6scales_1)));
8669+
bias[3] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x1_q6scales_0),
8670+
svdot_s64(zero, y1_q8sums_1, x1_q6scales_1)));
8671+
#else
8672+
// NEON doesn't support int16 dot product, fallback to separated mul and add
8673+
const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums);
8674+
const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums);
8675+
8676+
int8x16_t scales_s8 = vld1q_s8(x0->scales);
8677+
const int16x8x2_t q6scales0 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
8678+
scales_s8 = vld1q_s8(x1->scales);
8679+
const int16x8x2_t q6scales1 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
8680+
8681+
int32x4_t prod;
8682+
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales0.val[0])),
8683+
vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales0.val[0]))),
8684+
vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales0.val[1])),
8685+
vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales0.val[1]))));
8686+
bias[0] = vaddvq_s32(prod);
8687+
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales0.val[0])),
8688+
vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales0.val[0]))),
8689+
vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales0.val[1])),
8690+
vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales0.val[1]))));
8691+
bias[1] = vaddvq_s32(prod);
8692+
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales1.val[0])),
8693+
vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales1.val[0]))),
8694+
vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales1.val[1])),
8695+
vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales1.val[1]))));
8696+
bias[2] = vaddvq_s32(prod);
8697+
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales1.val[0])),
8698+
vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales1.val[0]))),
8699+
vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales1.val[1])),
8700+
vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1]))));
8701+
bias[3] = vaddvq_s32(prod);
8702+
8703+
#endif
8704+
const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32);
8705+
8706+
const float32x4_t superblock_scale = {
8707+
GGML_FP16_TO_FP32(x0->d) * y0->d,
8708+
GGML_FP16_TO_FP32(x0->d) * y1->d,
8709+
GGML_FP16_TO_FP32(x1->d) * y0->d,
8710+
GGML_FP16_TO_FP32(x1->d) * y1->d,
8711+
};
8712+
8713+
visum = vsubq_s32(visum, vibias);
8714+
vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale);
8715+
}
8716+
}
8717+
8718+
// vfsum = ABCD -> ACBD
8719+
// AC -> s, BD -> (s+bs)
8720+
vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2));
8721+
vst1_f32(s, vget_low_f32 (vfsum));
8722+
vst1_f32(s + bs, vget_high_f32(vfsum));
8723+
8724+
return;
8725+
}
8726+
#endif
8727+
85338728
#ifdef __ARM_FEATURE_SVE
85348729
const int vector_length = ggml_cpu_get_sve_cnt()*8;
85358730
float sum = 0;

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,11 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
282282
.from_float = quantize_row_q6_K,
283283
.vec_dot = ggml_vec_dot_q6_K_q8_K,
284284
.vec_dot_type = GGML_TYPE_Q8_K,
285+
#if defined (__ARM_FEATURE_MATMUL_INT8)
286+
.nrows = 2,
287+
#else
285288
.nrows = 1,
289+
#endif
286290
},
287291
[GGML_TYPE_IQ2_XXS] = {
288292
.from_float = NULL,

ggml/src/ggml-cuda/mmq.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ void ggml_cuda_mul_mat_q(
122122
const int64_t s13 = src1->nb[3] / ts_src1;
123123
quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type,
124124
ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
125+
CUDA_CHECK(cudaGetLastError());
125126
}
126127

127128
const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
@@ -205,6 +206,7 @@ void ggml_cuda_mul_mat_q(
205206
const int64_t s13 = src1->nb[2] / ts_src1;
206207
quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type,
207208
ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
209+
CUDA_CHECK(cudaGetLastError());
208210
}
209211

210212
const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));

ggml/src/ggml-cuda/quantize.cu

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,13 @@ static __global__ void quantize_mmq_q8_1(
5656
constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
5757
constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
5858

59-
const int64_t i0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
59+
const int64_t i0 = ((int64_t)blockDim.x*blockIdx.y + threadIdx.x)*4;
6060

6161
if (i0 >= ne0) {
6262
return;
6363
}
6464

65-
const int64_t i1 = blockIdx.y;
65+
const int64_t i1 = blockIdx.x;
6666
const int64_t i2 = blockIdx.z % ne2;
6767
const int64_t i3 = blockIdx.z / ne2;
6868

@@ -75,8 +75,8 @@ static __global__ void quantize_mmq_q8_1(
7575

7676
block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
7777

78-
const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
79-
const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.y; // block index in channel
78+
const int64_t ib0 = blockIdx.z*((int64_t)gridDim.x*gridDim.y*blockDim.x/QK8_1); // first block of channel
79+
const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.x; // block index in channel
8080
const int64_t iqs = i0 % (4*QK8_1); // quant index in block
8181

8282
// Load 4 floats per thread and calculate max. abs. value between them:
@@ -166,8 +166,9 @@ void quantize_mmq_q8_1_cuda(
166166
GGML_ASSERT(ne00 % 4 == 0);
167167
GGML_ASSERT(ne0 % (4*QK8_1) == 0);
168168

169-
const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
170-
const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
169+
// ne1 tends to assume the highest values, therefore use it as the "x" dimension of the CUDA grid:
170+
const int64_t block_num_y = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
171+
const dim3 num_blocks(ne1, block_num_y, ne2*ne3);
171172
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
172173
switch (mmq_get_q8_1_ds_layout(type_src0)) {
173174
case MMQ_Q8_1_DS_LAYOUT_D4:

src/llama-context.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1704,10 +1704,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
17041704
}
17051705
}
17061706

1707-
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
17081707
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
17091708

1710-
kv_self->state_write(io);
1709+
if (kv_self != nullptr) {
1710+
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
1711+
kv_self->state_write(io);
1712+
}
17111713

17121714
return io.n_bytes();
17131715
}

src/llama-kv-cache.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,13 @@ void llama_kv_cache_unified::defrag_sched(float thold) {
441441

442442
void llama_kv_cache_unified::set_full() {
443443
n = size;
444+
445+
// when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not
446+
// affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views.
447+
// we should only guarantee that the head position won't cause out-of-bounds view of the K, V tensors, so
448+
// setting it to 0 is the simplest way to achieve that
449+
// ref: https://github.com/ggml-org/llama.cpp/issues/13359
450+
head = 0;
444451
}
445452

446453
llama_sbatch llama_kv_cache_unified::sbatch_init(
@@ -1712,6 +1719,7 @@ void llama_kv_cache_recurrent::defrag_sched(float thold) {
17121719

17131720
void llama_kv_cache_recurrent::set_full() {
17141721
n = size;
1722+
head = 0;
17151723
}
17161724

17171725
llama_sbatch llama_kv_cache_recurrent::sbatch_init(

src/llama-kv-cache.h

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -171,11 +171,8 @@ class llama_kv_cache_unified : public llama_kv_cache {
171171
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
172172
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
173173

174-
// Note: The value of head isn't only used to optimize searching
175-
// for a free KV slot. llama_decode_impl also uses it, so it
176-
// cannot be freely changed after a slot has been allocated.
177-
uint32_t head = 0;
178-
uint32_t size = 0;
174+
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
175+
uint32_t size = 0; // total number of cells, shared across all sequences
179176
uint32_t used = 0; // used cells (i.e. at least one seq_id)
180177

181178
// computed before each graph build
@@ -343,11 +340,8 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
343340
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
344341
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
345342

346-
// Note: The value of head isn't only used to optimize searching
347-
// for a free KV slot. llama_decode_impl also uses it, so it
348-
// cannot be freely changed after a slot has been allocated.
349-
uint32_t head = 0;
350-
uint32_t size = 0;
343+
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
344+
uint32_t size = 0; // total number of cells, shared across all sequences
351345
uint32_t used = 0; // used cells (i.e. at least one seq_id)
352346

353347
// computed before each graph build

0 commit comments

Comments
 (0)