Skip to content

Commit 8488f5e

Browse files
committed
refactor: Make status and child states const in hybrid and iswa
Branch: HybridRecurrentCache
Signed-off-by: Gabe Goodhart <[email protected]>
1 parent c80e68c commit 8488f5e

File tree

4 files changed

+35
-42
lines changed

4 files changed

+35
-42
lines changed

src/llama-kv-cache-hybrid-recurrent.cpp

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -171,35 +171,32 @@ llama_kv_cache_recurrent * llama_kv_cache_hybrid_recurrent::get_kv_recurrent() c
171171

172172
llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_memory_status status) : status(status) {}
173173

174-
llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv)
175-
: status(LLAMA_MEMORY_STATUS_SUCCESS) {
176-
state_attn = kv->get_kv_attn ()->init_full();
177-
state_recurrent = kv->get_kv_recurrent()->init_full();
178-
179-
status = llama_memory_status_combine(state_attn->get_status(), state_recurrent->get_status());
174+
llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv) :
175+
state_attn (kv->get_kv_attn ()->init_full()),
176+
state_recurrent(kv->get_kv_recurrent()->init_full()),
177+
status(llama_memory_status_combine(state_attn->get_status(), state_recurrent->get_status())) {
180178
}
181179

182180
llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
183181
llama_kv_cache_hybrid_recurrent * kv,
184182
llama_context * lctx,
185-
bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
186-
state_attn = kv->get_kv_attn ()->init_update(lctx, optimize);
187-
state_recurrent = kv->get_kv_recurrent()->init_update(lctx, optimize);
188-
189-
status = llama_memory_status_combine(state_attn->get_status(), state_recurrent->get_status());
183+
bool optimize) :
184+
state_attn (kv->get_kv_attn ()->init_update(lctx, optimize)),
185+
state_recurrent(kv->get_kv_recurrent()->init_update(lctx, optimize)),
186+
status(llama_memory_status_combine(state_attn->get_status(), state_recurrent->get_status())) {
190187
}
191188

192189
llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
193190
llama_kv_cache_hybrid_recurrent * kv,
194191
llama_sbatch sbatch,
195192
std::vector<uint32_t> heads_attn,
196-
std::vector<llama_ubatch> ubatches)
197-
: status(LLAMA_MEMORY_STATUS_SUCCESS),
193+
std::vector<llama_ubatch> ubatches) :
198194
sbatch(std::move(sbatch)),
199-
ubatches(std::move(ubatches)) {
195+
ubatches(std::move(ubatches)),
200196
// note: here we copy the ubatches. not sure if this is ideal
201-
state_attn .reset(new llama_kv_cache_unified_state (kv->get_kv_attn(), {}, std::move(heads_attn), this->ubatches));
202-
state_recurrent.reset(new llama_kv_cache_recurrent_state(kv->get_kv_recurrent(), {}, this->ubatches));
197+
state_attn (new llama_kv_cache_unified_state (kv->get_kv_attn(), {}, std::move(heads_attn), this->ubatches)),
198+
state_recurrent(new llama_kv_cache_recurrent_state(kv->get_kv_recurrent(), {}, this->ubatches)),
199+
status(LLAMA_MEMORY_STATUS_SUCCESS) {
203200
}
204201

205202
bool llama_kv_cache_hybrid_recurrent_state::next() {

src/llama-kv-cache-hybrid-recurrent.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,15 +130,15 @@ class llama_kv_cache_hybrid_recurrent_state : public llama_memory_state_i {
130130
const llama_kv_cache_recurrent_state * get_state_recurrent() const;
131131

132132
private:
133-
llama_memory_status status;
134-
135133
llama_sbatch sbatch;
136134

137135
// the index of the next ubatch to process
138136
size_t i_next = 0;
139137

140138
std::vector<llama_ubatch> ubatches;
141139

142-
llama_memory_state_ptr state_attn;
143-
llama_memory_state_ptr state_recurrent;
140+
const llama_memory_state_ptr state_attn;
141+
const llama_memory_state_ptr state_recurrent;
142+
143+
const llama_memory_status status;
144144
};

src/llama-kv-cache-unified-iswa.cpp

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -197,37 +197,33 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
197197
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
198198

199199
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
200-
llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
201-
state_base = kv->get_base()->init_full();
202-
state_swa = kv->get_swa ()->init_full();
203-
204-
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
200+
llama_kv_cache_unified_iswa * kv) :
201+
state_base(kv->get_base()->init_full()),
202+
state_swa (kv->get_swa ()->init_full()),
203+
status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
205204
}
206205

207206
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
208207
llama_kv_cache_unified_iswa * kv,
209208
llama_context * lctx,
210-
bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
211-
state_base = kv->get_base()->init_update(lctx, optimize);
212-
state_swa = kv->get_swa ()->init_update(lctx, optimize);
213-
214-
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
209+
bool optimize) :
210+
state_base(kv->get_base()->init_update(lctx, optimize)),
211+
state_swa (kv->get_swa ()->init_update(lctx, optimize)),
212+
status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
215213
}
216214

217215
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
218216
llama_kv_cache_unified_iswa * kv,
219217
llama_sbatch sbatch,
220218
std::vector<uint32_t> heads_base,
221219
std::vector<uint32_t> heads_swa,
222-
std::vector<llama_ubatch> ubatches)
223-
: status(LLAMA_MEMORY_STATUS_SUCCESS),
224-
sbatch(std::move(sbatch)),
225-
ubatches(std::move(ubatches)) {
220+
std::vector<llama_ubatch> ubatches) :
221+
sbatch(std::move(sbatch)),
222+
ubatches(std::move(ubatches)),
226223
// note: here we copy the ubatches. not sure if this is ideal
227-
state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
228-
state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
229-
230-
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
224+
state_base(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches)),
225+
state_swa (new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches)),
226+
status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
231227
}
232228

233229
llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;

src/llama-kv-cache-unified-iswa.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,6 @@ class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
117117
const llama_kv_cache_unified_state * get_swa() const;
118118

119119
private:
120-
llama_memory_status status;
121-
122120
//llama_kv_cache_unified_iswa * kv;
123121

124122
llama_sbatch sbatch;
@@ -128,6 +126,8 @@ class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
128126

129127
std::vector<llama_ubatch> ubatches;
130128

131-
llama_memory_state_ptr state_base;
132-
llama_memory_state_ptr state_swa;
129+
const llama_memory_state_ptr state_base;
130+
const llama_memory_state_ptr state_swa;
131+
132+
const llama_memory_status status;
133133
};

0 commit comments

Comments (0)