Commit 2503fc7

cont : make impl more private
ggml-ci
1 parent 2be10dc commit 2503fc7

File tree

3 files changed: +50 -50 lines changed

  src/llama-graph.cpp
  src/llama-kv-cache.cpp
  src/llama-kv-cache.h

src/llama-graph.cpp

Lines changed: 3 additions & 3 deletions
@@ -1023,7 +1023,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {

     auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);

-    const auto n_kv = kv_self->n;
+    const auto n_kv = kv_self->get_n();

     auto & cur = inp->pos_bucket;
@@ -1230,7 +1230,7 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);

     {
-        const auto n_kv = kv_self->n;
+        const auto n_kv = kv_self->get_n();

         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1242,7 +1242,7 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
     if (hparams.n_swa_pattern > 1) {
         GGML_ASSERT(hparams.n_swa > 0);

-        const auto n_kv = kv_self->n;
+        const auto n_kv = kv_self->get_n();

         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
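
All three call sites change in the same way: the graph-building code stops reading the KV cell count `n` directly from the cache object and uses the new `get_n()` accessor instead. A minimal, self-contained C++ sketch of the pattern (the class below and its `set_n()` setter are illustrative stand-ins, not the actual llama.cpp API):

    #include <cstdint>
    #include <iostream>

    // Stand-in for llama_kv_cache_unified: the per-graph cell count `n` is an
    // implementation detail, so it lives in the private section and is exposed
    // read-only through get_n().
    class kv_cache_sketch {
    public:
        uint32_t get_n() const { return n; }

        void set_n(uint32_t n_cur) { n = n_cur; } // hypothetical setter, for the sketch only

    private:
        uint32_t n = 0; // computed before each graph build
    };

    int main() {
        kv_cache_sketch kv_self;
        kv_self.set_n(32);

        const auto n_kv = kv_self.get_n(); // what the graph code does after this commit
        // const auto n_kv = kv_self.n;    // would no longer compile: `n` is private

        std::cout << "n_kv = " << n_kv << "\n";
        return 0;
    }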

src/llama-kv-cache.cpp

Lines changed: 3 additions & 3 deletions
@@ -151,7 +151,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 }

 void llama_kv_cache_unified::clear() {
-    for (int32_t i = 0; i < (int32_t) size; ++i) {
+    for (uint32_t i = 0; i < size; ++i) {
         cells[i].pos = -1;
         cells[i].seq_id.clear();
     }
@@ -561,8 +561,8 @@ bool llama_kv_cache_unified::get_can_shift() const {
     return can_shift;
 }

-const llama_kv_cache_unified::kv_layer & llama_kv_cache_unified::get_layer(int32_t il) const {
-    return layers[il];
+uint32_t llama_kv_cache_unified::get_n() const {
+    return n;
 }

 ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) const {

src/llama-kv-cache.h

Lines changed: 44 additions & 44 deletions
@@ -90,36 +90,6 @@ struct llama_kv_cache_guard {
 // TODO: add notion of max sequences
 class llama_kv_cache_unified : public llama_kv_cache {
 public:
-    // commit/restore cache
-    struct slot_range {
-        uint32_t c0 = 0; // note: these are cell indices, not sequence positions
-        uint32_t c1 = 0;
-    };
-
-    struct kv_cell {
-        llama_pos pos   = -1;
-        llama_pos delta =  0;
-
-        std::set<llama_seq_id> seq_id;
-
-        bool has_seq_id(const llama_seq_id & id) const {
-            return seq_id.find(id) != seq_id.end();
-        }
-
-        bool is_empty() const {
-            return seq_id.empty();
-        }
-
-        bool is_same_seq(const kv_cell & other) const {
-            return seq_id == other.seq_id;
-        }
-    };
-
-    struct kv_layer {
-        ggml_tensor * k = nullptr;
-        ggml_tensor * v = nullptr;
-    };
-
     static uint32_t get_padding(const llama_cparams & cparams);

     llama_kv_cache_unified(
@@ -133,16 +103,6 @@ class llama_kv_cache_unified : public llama_kv_cache {

     ~llama_kv_cache_unified() = default;

-    // Note: The value of head isn't only used to optimize searching
-    // for a free KV slot. llama_decode_impl also uses it, so it
-    // cannot be freely changed after a slot has been allocated.
-    uint32_t head = 0;
-    uint32_t size = 0;
-    uint32_t used = 0; // used cells (i.e. at least one seq_id)
-
-    // computed before each graph build
-    uint32_t n = 0;
-
     //
     // llama_memory_i
     //
@@ -187,7 +147,7 @@ class llama_kv_cache_unified : public llama_kv_cache {

     bool get_can_shift() const override;

-    const kv_layer & get_layer(int32_t il) const;
+    uint32_t get_n() const;

     ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
     ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
@@ -210,12 +170,52 @@ class llama_kv_cache_unified : public llama_kv_cache {
     const llama_model & model;
     const llama_hparams & hparams;

+    // commit/restore cache
+    struct slot_range {
+        uint32_t c0 = 0; // note: these are cell indices, not sequence positions
+        uint32_t c1 = 0;
+    };
+
+    struct kv_cell {
+        llama_pos pos   = -1;
+        llama_pos delta =  0;
+
+        std::set<llama_seq_id> seq_id;
+
+        bool has_seq_id(const llama_seq_id & id) const {
+            return seq_id.find(id) != seq_id.end();
+        }
+
+        bool is_empty() const {
+            return seq_id.empty();
+        }
+
+        bool is_same_seq(const kv_cell & other) const {
+            return seq_id == other.seq_id;
+        }
+    };
+
+    struct kv_layer {
+        ggml_tensor * k = nullptr;
+        ggml_tensor * v = nullptr;
+    };
+
     bool has_shift = false;
     bool do_defrag = false;

     bool v_trans = true; // the value tensor is transposed
     bool can_shift = false;

+    // Note: The value of head isn't only used to optimize searching
+    // for a free KV slot. llama_decode_impl also uses it, so it
+    // cannot be freely changed after a slot has been allocated.
+    uint32_t head = 0;
+    uint32_t size = 0;
+    uint32_t used = 0; // used cells (i.e. at least one seq_id)
+
+    // computed before each graph build
+    uint32_t n = 0;
+
     // required padding
     uint32_t padding = 1;
@@ -279,9 +279,9 @@ class llama_kv_cache_unified : public llama_kv_cache {
 // llama_kv_cache_unified_swa
 //

-//class llama_kv_cache_unified_swa : public llama_kv_cache {
-//public:
-//};
+class llama_kv_cache_unified_swa : public llama_kv_cache {
+public:
+};

 //
 // llama_kv_cache_recurrent
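
Net effect on the header: the `slot_range` / `kv_cell` / `kv_layer` structs and the `head` / `size` / `used` / `n` counters move out of the public section, the public surface shrinks to narrow accessors, and the previously commented-out `llama_kv_cache_unified_swa` class is enabled as an empty stub. A condensed, self-contained sketch of the resulting shape (the stub declarations replace the real ggml/llama headers, the class is renamed `..._outline`, and the access specifier for the moved members is inferred from the commit title, since the hunks do not show it):

    #include <cstdint>
    #include <set>

    // Stub declarations so this outline compiles on its own; in llama.cpp these
    // come from ggml.h, llama.h and the surrounding headers.
    struct ggml_tensor;
    struct ggml_context;
    struct llama_cparams;
    using llama_pos    = int32_t;
    using llama_seq_id = int32_t;

    class llama_kv_cache_unified_outline {
    public:
        static uint32_t get_padding(const llama_cparams & cparams);

        // narrow read-only accessors used by the graph-building code
        uint32_t get_n() const { return n; }

        ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
        ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;

    private:
        // commit/restore cache bookkeeping, now hidden from callers
        struct slot_range {
            uint32_t c0 = 0; // cell indices, not sequence positions
            uint32_t c1 = 0;
        };

        struct kv_cell {
            llama_pos pos   = -1;
            llama_pos delta =  0;

            std::set<llama_seq_id> seq_id;
        };

        struct kv_layer {
            ggml_tensor * k = nullptr;
            ggml_tensor * v = nullptr;
        };

        // Note: head is also used by llama_decode_impl, so it cannot be freely
        // changed after a slot has been allocated.
        uint32_t head = 0;
        uint32_t size = 0;
        uint32_t used = 0; // used cells (i.e. at least one seq_id)

        // computed before each graph build
        uint32_t n = 0;
    };

In the real class, `kv_cell` also keeps the `has_seq_id` / `is_empty` / `is_same_seq` helpers shown in the hunk above; they are omitted here only to keep the outline short.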
