
Commit 6403f19

refactor: _recurrent -> _recr for brevity

It just _happens_ to have the same number of letters as _attn!

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <[email protected]>

1 parent 8e39e04 commit 6403f19

File tree

3 files changed, +50 -50 lines changed


src/llama-graph.cpp

Lines changed: 4 additions & 4 deletions
@@ -409,15 +409,15 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
         mem_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 
-    const int64_t n_rs = mem_state->get_state_recurrent()->get_n_rs();
+    const int64_t n_rs = mem_state->get_state_recr()->get_n_rs();
 
     if (s_copy) {
         GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
         int32_t * data = (int32_t *) s_copy->data;
 
         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
         for (uint32_t i = 0; i < n_rs; ++i) {
-            data[i] = mem_state->get_state_recurrent()->s_copy(i);
+            data[i] = mem_state->get_state_recr()->s_copy(i);
         }
     }
 }
@@ -1067,7 +1067,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     }
 
     {
-        const auto n_rs = mem_state->get_state_recurrent()->get_n_rs();
+        const auto n_rs = mem_state->get_state_recr()->get_n_rs();
 
         inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
         ggml_set_input(inp->s_copy);
@@ -1584,7 +1584,7 @@ ggml_tensor * llm_graph_context::build_rs(
         int32_t   state_size,
         int32_t   n_seqs,
            bool   avoid_copies) const {
-    const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_recurrent();
+    const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_recr();
 
     return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
 }
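For readers skimming the diff: the renamed accessor is how the hybrid graph input pulls recurrent-cache metadata out of the combined state. A minimal standalone sketch of the copy-index fill follows; mock_recurrent_state and fill_s_copy are hypothetical stand-ins for llama_memory_recurrent_state and the set_input() logic above, not part of the real API.

#include <cstdint>
#include <vector>

// Hypothetical stand-in for the recurrent state's copy-index API
// (get_n_rs() / s_copy(i) on the real llama_memory_recurrent_state).
struct mock_recurrent_state {
    std::vector<int32_t> copy_src; // source cell for each recurrent slot

    uint32_t get_n_rs() const { return (uint32_t) copy_src.size(); }
    int32_t  s_copy(uint32_t i) const { return copy_src[i]; }
};

// Fill an int32 input buffer with the per-slot copy sources, mirroring what
// llm_graph_input_mem_hybrid::set_input() does for the s_copy tensor above.
static void fill_s_copy(const mock_recurrent_state & st, int32_t * data) {
    const uint32_t n_rs = st.get_n_rs();
    for (uint32_t i = 0; i < n_rs; ++i) {
        data[i] = st.s_copy(i);
    }
}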

src/llama-memory-hybrid.cpp

Lines changed: 37 additions & 37 deletions
@@ -27,7 +27,7 @@ llama_memory_hybrid::llama_memory_hybrid(
                    bool   offload,
                            /* layer filters */
        layer_filter_cb && filter_attn,
-       layer_filter_cb && filter_recurrent) :
+       layer_filter_cb && filter_recr) :
     hparams(model.hparams),
     mem_attn(new llama_kv_cache_unified(
         model,
@@ -44,11 +44,11 @@ llama_memory_hybrid::llama_memory_hybrid(
         n_swa,
         swa_type
     )),
-    mem_recurrent(new llama_memory_recurrent(
+    mem_recr(new llama_memory_recurrent(
         model,
-        filter_recurrent == nullptr ?
+        filter_recr == nullptr ?
             [&](int32_t il) { return model.hparams.recurrent_layer(il); }
-            : filter_recurrent,
+            : filter_recr,
         type_r,
         type_s,
         offload,
@@ -77,7 +77,7 @@ llama_memory_state_ptr llama_memory_hybrid::init_batch(const llama_batch & batch
     }
 
     // prepare the recurrent batches first
-    if (!mem_recurrent->prepare(ubatches)) {
+    if (!mem_recr->prepare(ubatches)) {
         // TODO: will the recurrent cache be in an undefined state at this point?
         LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
         return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
@@ -108,82 +108,82 @@ bool llama_memory_hybrid::get_can_shift() const {
 }
 
 void llama_memory_hybrid::clear(bool data) {
-    mem_attn     ->clear(data);
-    mem_recurrent->clear(data);
+    mem_attn->clear(data);
+    mem_recr->clear(data);
 }
 
 bool llama_memory_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
     // Try removing from the recurrent cache first since it may fail. If it does
     // fail, the cache will not have been mutated.
-    if (!mem_recurrent->seq_rm(seq_id, p0, p1)) {
+    if (!mem_recr->seq_rm(seq_id, p0, p1)) {
         return false;
     }
     return mem_attn->seq_rm(seq_id, p0, p1);
 }
 
 void llama_memory_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
-    mem_attn     ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
-    mem_recurrent->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1);
 }
 
 void llama_memory_hybrid::seq_keep(llama_seq_id seq_id) {
-    mem_attn     ->seq_keep(seq_id);
-    mem_recurrent->seq_keep(seq_id);
+    mem_attn->seq_keep(seq_id);
+    mem_recr->seq_keep(seq_id);
 }
 
 void llama_memory_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
     mem_attn->seq_add(seq_id, p0, p1, shift);
-    mem_recurrent->seq_add(seq_id, p0, p1, shift);
+    mem_recr->seq_add(seq_id, p0, p1, shift);
 }
 
 void llama_memory_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
-    mem_attn     ->seq_div(seq_id, p0, p1, d);
-    mem_recurrent->seq_div(seq_id, p0, p1, d);
+    mem_attn->seq_div(seq_id, p0, p1, d);
+    mem_recr->seq_div(seq_id, p0, p1, d);
 }
 
 llama_pos llama_memory_hybrid::seq_pos_min(llama_seq_id seq_id) const {
     // the min of the total cache is the max of the two caches' min values
-    return std::max(mem_attn->seq_pos_min(seq_id), mem_recurrent->seq_pos_min(seq_id));
+    return std::max(mem_attn->seq_pos_min(seq_id), mem_recr->seq_pos_min(seq_id));
 }
 
 llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const {
     // the max of the total cache is the min of the two caches' max values
-    return std::min(mem_attn->seq_pos_max(seq_id), mem_recurrent->seq_pos_max(seq_id));
+    return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
 }
 
 void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
-    mem_attn     ->state_write(io, seq_id);
-    mem_recurrent->state_write(io, seq_id);
+    mem_attn->state_write(io, seq_id);
+    mem_recr->state_write(io, seq_id);
 }
 
 void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
-    mem_attn     ->state_read(io, seq_id);
-    mem_recurrent->state_read(io, seq_id);
+    mem_attn->state_read(io, seq_id);
+    mem_recr->state_read(io, seq_id);
 }
 
 llama_kv_cache_unified * llama_memory_hybrid::get_mem_attn() const {
     return mem_attn.get();
 }
 
-llama_memory_recurrent * llama_memory_hybrid::get_mem_recurrent() const {
-    return mem_recurrent.get();
+llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
+    return mem_recr.get();
 }
 
 llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_status status) : status(status) {}
 
 llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_hybrid * mem) :
-    state_attn     (mem->get_mem_attn     ()->init_full()),
-    state_recurrent(mem->get_mem_recurrent()->init_full()),
-    status(llama_memory_status_combine(state_attn->get_status(), state_recurrent->get_status())) {
+    state_attn(mem->get_mem_attn()->init_full()),
+    state_recr(mem->get_mem_recr()->init_full()),
+    status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
 }
 
 llama_memory_hybrid_state::llama_memory_hybrid_state(
         llama_memory_hybrid * mem,
               llama_context * lctx,
                        bool   optimize) :
-    state_attn     (mem->get_mem_attn     ()->init_update(lctx, optimize)),
-    state_recurrent(mem->get_mem_recurrent()->init_update(lctx, optimize)),
-    status(llama_memory_status_combine(state_attn->get_status(), state_recurrent->get_status())) {
+    state_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
+    state_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
+    status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
 }
 
 llama_memory_hybrid_state::llama_memory_hybrid_state(
@@ -194,16 +194,16 @@ llama_memory_hybrid_state::llama_memory_hybrid_state(
     sbatch(std::move(sbatch)),
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    state_attn     (new llama_kv_cache_unified_state(mem->get_mem_attn(), {}, std::move(heads_attn), this->ubatches)),
-    state_recurrent(new llama_memory_recurrent_state(mem->get_mem_recurrent(), {}, this->ubatches)),
+    state_attn(new llama_kv_cache_unified_state(mem->get_mem_attn(), {}, std::move(heads_attn), this->ubatches)),
+    state_recr(new llama_memory_recurrent_state(mem->get_mem_recr(), {}, this->ubatches)),
     status(LLAMA_MEMORY_STATUS_SUCCESS) {
 }
 
 bool llama_memory_hybrid_state::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    state_attn     ->next();
-    state_recurrent->next();
+    state_attn->next();
+    state_recr->next();
 
     if (++i_next >= ubatches.size()) {
         return false;
@@ -217,8 +217,8 @@ bool llama_memory_hybrid_state::apply() {
 
     bool res = true;
 
-    res = res & state_attn     ->apply();
-    res = res & state_recurrent->apply();
+    res = res & state_attn->apply();
+    res = res & state_recr->apply();
 
     return res;
 }
@@ -242,6 +242,6 @@ const llama_kv_cache_unified_state * llama_memory_hybrid_state::get_state_attn()
     return static_cast<const llama_kv_cache_unified_state *>(state_attn.get());
 }
 
-const llama_memory_recurrent_state * llama_memory_hybrid_state::get_state_recurrent() const {
-    return static_cast<const llama_memory_recurrent_state *>(state_recurrent.get());
+const llama_memory_recurrent_state * llama_memory_hybrid_state::get_state_recr() const {
+    return static_cast<const llama_memory_recurrent_state *>(state_recr.get());
 }
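Note that the rename does not change the delegation pattern in this file: every cache operation is forwarded to both mem_attn and mem_recr, and the usable position range of the hybrid cache is the intersection of the two sub-caches' ranges (max of the mins, min of the maxes). Below is a small self-contained sketch of that intersection rule with made-up numbers; pos_range and hybrid_range are illustrative helpers, not part of llama.cpp.

#include <algorithm>
#include <cstdio>

// Illustrative per-sequence position range reported by one cache.
struct pos_range { long long pos_min, pos_max; };

// The hybrid range is the intersection of the two caches' ranges,
// mirroring llama_memory_hybrid::seq_pos_min()/seq_pos_max() above.
static pos_range hybrid_range(const pos_range & attn, const pos_range & recr) {
    return { std::max(attn.pos_min, recr.pos_min),
             std::min(attn.pos_max, recr.pos_max) };
}

int main() {
    const pos_range attn = { 0, 120 };  // attention KV cache holds positions 0..120
    const pos_range recr = { 32, 120 }; // recurrent cache is only valid from 32 on
    const pos_range both = hybrid_range(attn, recr);
    std::printf("hybrid cache covers positions %lld..%lld\n", both.pos_min, both.pos_max);
    return 0;
}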

src/llama-memory-hybrid.h

Lines changed: 9 additions & 9 deletions
@@ -40,8 +40,8 @@ class llama_memory_hybrid : public llama_memory_i {
                    uint32_t   n_seq_max,
                        bool   offload,
                                 /* layer filters */
-            layer_filter_cb && filter_attn      = nullptr,
-            layer_filter_cb && filter_recurrent = nullptr);
+            layer_filter_cb && filter_attn = nullptr,
+            layer_filter_cb && filter_recr = nullptr);
 
     ~llama_memory_hybrid() = default;
 
@@ -80,14 +80,14 @@ class llama_memory_hybrid : public llama_memory_i {
     // llama_memory_hybrid specific API
     //
 
-    llama_kv_cache_unified * get_mem_attn     () const;
-    llama_memory_recurrent * get_mem_recurrent() const;
+    llama_kv_cache_unified * get_mem_attn() const;
+    llama_memory_recurrent * get_mem_recr() const;
 
 private:
     const llama_hparams & hparams;
 
-    const std::unique_ptr<llama_kv_cache_unified> mem_attn;
-    const std::unique_ptr<llama_memory_recurrent> mem_recurrent;
+    const std::unique_ptr<llama_kv_cache_unified> mem_attn;
+    const std::unique_ptr<llama_memory_recurrent> mem_recr;
 };
 
 class llama_memory_hybrid_state : public llama_memory_state_i {
@@ -125,8 +125,8 @@ class llama_memory_hybrid_state : public llama_memory_state_i {
     // llama_memory_hybrid_state
     //
 
-    const llama_kv_cache_unified_state * get_state_attn     () const;
-    const llama_memory_recurrent_state * get_state_recurrent() const;
+    const llama_kv_cache_unified_state * get_state_attn() const;
+    const llama_memory_recurrent_state * get_state_recr() const;
 
 private:
     llama_sbatch sbatch;
@@ -137,7 +137,7 @@ class llama_memory_hybrid_state : public llama_memory_state_i {
     std::vector<llama_ubatch> ubatches;
 
     const llama_memory_state_ptr state_attn;
-    const llama_memory_state_ptr state_recurrent;
+    const llama_memory_state_ptr state_recr;
 
     const llama_memory_status status;
 };
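For orientation, the two getters declared here are what graph code uses after downcasting the generic memory state, as in the build_rs() hunk in llama-graph.cpp above. A sketch of that access pattern follows; mstate is assumed to come from the surrounding graph context and to actually hold a llama_memory_hybrid_state, so this is an illustration of the accessors above rather than new API.

// Sketch: reaching the two sub-states through the hybrid state's accessors.
const auto * hybrid_state = static_cast<const llama_memory_hybrid_state *>(mstate);

const llama_kv_cache_unified_state * st_attn = hybrid_state->get_state_attn();
const llama_memory_recurrent_state * st_recr = hybrid_state->get_state_recr();

// the attention half drives the KQ mask; the recurrent half supplies the
// slot metadata (n_rs, head, size, rs_z) consumed by build_rs()
const auto n_rs = st_recr->get_n_rs();
(void) st_attn;
(void) n_rs;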
