@@ -27,7 +27,7 @@ llama_memory_hybrid::llama_memory_hybrid(
     bool offload,
     /* layer filters */
     layer_filter_cb && filter_attn,
-    layer_filter_cb && filter_recurrent) :
+    layer_filter_cb && filter_recr) :
     hparams(model.hparams),
     mem_attn(new llama_kv_cache_unified(
         model,
@@ -44,11 +44,11 @@ llama_memory_hybrid::llama_memory_hybrid(
         n_swa,
         swa_type
     )),
-    mem_recurrent(new llama_memory_recurrent(
+    mem_recr(new llama_memory_recurrent(
         model,
-        filter_recurrent == nullptr ?
+        filter_recr == nullptr ?
             [&](int32_t il) { return model.hparams.recurrent_layer(il); }
-            : filter_recurrent,
+            : filter_recr,
         type_r,
         type_s,
         offload,
@@ -77,7 +77,7 @@ llama_memory_state_ptr llama_memory_hybrid::init_batch(const llama_batch & batch
     }

     // prepare the recurrent batches first
-    if (!mem_recurrent->prepare(ubatches)) {
+    if (!mem_recr->prepare(ubatches)) {
         // TODO: will the recurrent cache be in an undefined state at this point?
         LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
         return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
@@ -108,82 +108,82 @@ bool llama_memory_hybrid::get_can_shift() const {
 }

 void llama_memory_hybrid::clear(bool data) {
-    mem_attn     ->clear(data);
-    mem_recurrent->clear(data);
+    mem_attn->clear(data);
+    mem_recr->clear(data);
 }

 bool llama_memory_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
     // Try removing from the recurrent cache first since it may fail. If it does
     // fail, the cache will not have been mutated.
-    if (!mem_recurrent->seq_rm(seq_id, p0, p1)) {
+    if (!mem_recr->seq_rm(seq_id, p0, p1)) {
         return false;
     }
     return mem_attn->seq_rm(seq_id, p0, p1);
 }

 void llama_memory_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
-    mem_attn     ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
-    mem_recurrent->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1);
 }

 void llama_memory_hybrid::seq_keep(llama_seq_id seq_id) {
-    mem_attn     ->seq_keep(seq_id);
-    mem_recurrent->seq_keep(seq_id);
+    mem_attn->seq_keep(seq_id);
+    mem_recr->seq_keep(seq_id);
 }

 void llama_memory_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
     mem_attn->seq_add(seq_id, p0, p1, shift);
-    mem_recurrent->seq_add(seq_id, p0, p1, shift);
+    mem_recr->seq_add(seq_id, p0, p1, shift);
 }

 void llama_memory_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
-    mem_attn     ->seq_div(seq_id, p0, p1, d);
-    mem_recurrent->seq_div(seq_id, p0, p1, d);
+    mem_attn->seq_div(seq_id, p0, p1, d);
+    mem_recr->seq_div(seq_id, p0, p1, d);
 }

 llama_pos llama_memory_hybrid::seq_pos_min(llama_seq_id seq_id) const {
     // the min of the total cache is the max of the two caches' min values
-    return std::max(mem_attn->seq_pos_min(seq_id), mem_recurrent->seq_pos_min(seq_id));
+    return std::max(mem_attn->seq_pos_min(seq_id), mem_recr->seq_pos_min(seq_id));
 }

 llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const {
     // the max of the total cache is the min of the two caches' max values
-    return std::min(mem_attn->seq_pos_max(seq_id), mem_recurrent->seq_pos_max(seq_id));
+    return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
 }

 void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
-    mem_attn     ->state_write(io, seq_id);
-    mem_recurrent->state_write(io, seq_id);
+    mem_attn->state_write(io, seq_id);
+    mem_recr->state_write(io, seq_id);
 }

 void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
-    mem_attn     ->state_read(io, seq_id);
-    mem_recurrent->state_read(io, seq_id);
+    mem_attn->state_read(io, seq_id);
+    mem_recr->state_read(io, seq_id);
 }

 llama_kv_cache_unified * llama_memory_hybrid::get_mem_attn() const {
     return mem_attn.get();
 }

-llama_memory_recurrent * llama_memory_hybrid::get_mem_recurrent() const {
-    return mem_recurrent.get();
+llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
+    return mem_recr.get();
 }

 llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_status status) : status(status) {}

 llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_hybrid * mem) :
-    state_attn     (mem->get_mem_attn     ()->init_full()),
-    state_recurrent(mem->get_mem_recurrent()->init_full()),
-    status(llama_memory_status_combine(state_attn->get_status(), state_recurrent->get_status())) {
+    state_attn(mem->get_mem_attn()->init_full()),
+    state_recr(mem->get_mem_recr()->init_full()),
+    status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
 }

 llama_memory_hybrid_state::llama_memory_hybrid_state(
         llama_memory_hybrid * mem,
         llama_context * lctx,
         bool optimize) :
-    state_attn     (mem->get_mem_attn     ()->init_update(lctx, optimize)),
-    state_recurrent(mem->get_mem_recurrent()->init_update(lctx, optimize)),
-    status(llama_memory_status_combine(state_attn->get_status(), state_recurrent->get_status())) {
+    state_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
+    state_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
+    status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
 }

 llama_memory_hybrid_state::llama_memory_hybrid_state(
@@ -194,16 +194,16 @@ llama_memory_hybrid_state::llama_memory_hybrid_state(
     sbatch(std::move(sbatch)),
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    state_attn     (new llama_kv_cache_unified_state(mem->get_mem_attn(), {}, std::move(heads_attn), this->ubatches)),
-    state_recurrent(new llama_memory_recurrent_state(mem->get_mem_recurrent(), {}, this->ubatches)),
+    state_attn(new llama_kv_cache_unified_state(mem->get_mem_attn(), {}, std::move(heads_attn), this->ubatches)),
+    state_recr(new llama_memory_recurrent_state(mem->get_mem_recr(), {}, this->ubatches)),
     status(LLAMA_MEMORY_STATUS_SUCCESS) {
 }

 bool llama_memory_hybrid_state::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

-    state_attn     ->next();
-    state_recurrent->next();
+    state_attn->next();
+    state_recr->next();

     if (++i_next >= ubatches.size()) {
         return false;
@@ -217,8 +217,8 @@ bool llama_memory_hybrid_state::apply() {

     bool res = true;

-    res = res & state_attn     ->apply();
-    res = res & state_recurrent->apply();
+    res = res & state_attn->apply();
+    res = res & state_recr->apply();

     return res;
 }
@@ -242,6 +242,6 @@ const llama_kv_cache_unified_state * llama_memory_hybrid_state::get_state_attn()
     return static_cast<const llama_kv_cache_unified_state *>(state_attn.get());
 }

-const llama_memory_recurrent_state * llama_memory_hybrid_state::get_state_recurrent() const {
-    return static_cast<const llama_memory_recurrent_state *>(state_recurrent.get());
+const llama_memory_recurrent_state * llama_memory_hybrid_state::get_state_recr() const {
+    return static_cast<const llama_memory_recurrent_state *>(state_recr.get());
 }