@@ -2583,7 +2583,7 @@ struct llama_hparams {
25832583 return n_embd_head_v * n_head_kv;
25842584 }
25852585
2586- uint32_t n_embd_r(uint32_t il) const { // dimension of the rolling state embeddings
2586+ uint32_t n_embd_r(uint32_t il) const { // dimension of the rolling state embeddings
25872587 // TODO: support using an SSM in place of the MLP of a Transformer
25882588 if (n_head_kv(il) != 0) { return 0; }
25892589 // corresponds to Mamba's conv_states size or RWKV's token_shift states size
@@ -2597,7 +2597,7 @@ struct llama_hparams {
25972597 }
25982598 }
25992599
2600- uint32_t n_embd_s(uint32_t il) const { // dimension of the recurrent state embeddings
2600+ uint32_t n_embd_s(uint32_t il) const { // dimension of the recurrent state embeddings
26012601 // TODO: support using an SSM in place of the MLP of a Transformer
26022602 if (n_head_kv(il) != 0) { return 0; }
26032603
@@ -2875,17 +2875,13 @@ struct llama_kv_self_cache {
28752875
28762876struct llama_rs_cell {
28772877 llama_pos pos = -1;
2878- int32_t src = -1; // copy source id (cleared next when -1)
2878+ int32_t src = -1; // copy source id (cleared next when -1)
28792879
28802880 std::set<llama_seq_id> seq_id;
28812881
2882- bool has_seq_id(const llama_seq_id & id) const {
2883- return seq_id.find(id) != seq_id.end();
2884- }
2882+ bool has_seq_id(const llama_seq_id & id) const { return seq_id.find(id) != seq_id.end(); }
28852883
2886- bool is_empty() const {
2887- return seq_id.empty();
2888- }
2884+ bool is_empty() const { return seq_id.empty(); }
28892885};
28902886
28912887struct llama_rs_seq_meta {
@@ -2895,46 +2891,45 @@ struct llama_rs_seq_meta {
28952891
28962892// ring-buffered tree of cached recurrent state data
28972893struct llama_rs_self_cache {
2898-
2899- uint32_t head = 0; // first state used for the last slot
2894+ uint32_t head = 0; // first state used for the last slot
29002895 uint32_t size = 0;
29012896 uint32_t used = 0;
29022897
29032898 // computed when finding a slot
2904- uint32_t n = 0; // range of states used for the last slot
2899+ uint32_t n = 0; // range of states used for the last slot
29052900
29062901 // with state models, a cell can hold the state for more than one past token
29072902 std::vector<llama_rs_cell> cells;
29082903
29092904 // find tail cells faster
2910- std::vector<llama_rs_seq_meta> seq_tails; // map seq_ids to cell ids
2905+ std::vector<llama_rs_seq_meta> seq_tails; // map seq_ids to cell ids
29112906
29122907 // per layer
29132908 // NOTE: the naming of r and s is arbitrary
2914- std::vector<struct ggml_tensor *> r_l; // rolling/shift states
2915- std::vector<struct ggml_tensor *> s_l; // ssm (recurrent) states
2909+ std::vector<struct ggml_tensor *> r_l; // rolling/shift states
2910+ std::vector<struct ggml_tensor *> s_l; // ssm (recurrent) states
29162911
29172912 // Inefficient, but thorough verification and rebuilding of the rs cache
29182913 // from only the cells list with `pos` and seq_ids.
29192914 // Should not be called in a hot loop except when desperate and/or debugging.
29202915 bool rebuild(bool debug) {
29212916 bool was_valid = true;
29222917 // skip for non-recurrent models
2923- if (size == 0) { return true; }
2918+ if (size == 0) {
2919+ return true;
2920+ }
29242921 // the source of truth is the cells list
29252922 // buffer sizes
29262923 if (size != cells.size()) {
29272924 if (debug) {
2928- LLAMA_LOG_ERROR("%s: cells has wrong size (%zu instead of %u)\n",
2929- __func__, cells.size(), size);
2925+ LLAMA_LOG_ERROR("%s: cells has wrong size (%zu instead of %u)\n", __func__, cells.size(), size);
29302926 }
29312927 cells.resize(size);
29322928 was_valid = false;
29332929 }
29342930 if (size != seq_tails.size()) {
29352931 if (debug) {
2936- LLAMA_LOG_ERROR("%s: seq_tails has wrong size (%zu instead of %u)\n",
2937- __func__, seq_tails.size(), size);
2932+ LLAMA_LOG_ERROR("%s: seq_tails has wrong size (%zu instead of %u)\n", __func__, seq_tails.size(), size);
29382933 }
29392934 seq_tails.resize(size);
29402935 was_valid = false;
@@ -2994,7 +2989,7 @@ struct llama_rs_self_cache {
29942989 for (uint32_t cell_id = 0; cell_id < size; ++cell_id) {
29952990 llama_rs_cell & cell = cells[cell_id];
29962991 if (cell.has_seq_id(seq_id)) {
2997- seq_cells.push_back({cell.pos, cell_id});
2992+ seq_cells.push_back({ cell.pos, cell_id });
29982993 }
29992994 }
30002995 // sort by pos and then by cell_id
@@ -3718,16 +3713,16 @@ static bool llama_kv_cache_init(
37183713 }
37193714
37203715 if (has_kv) {
3721- ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, hparams.n_embd_k_gqa(i)* kv_size);
3722- ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, hparams.n_embd_v_gqa(i)* kv_size);
3716+ ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, hparams.n_embd_k_gqa(i) * kv_size);
3717+ ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, hparams.n_embd_v_gqa(i) * kv_size);
37233718 ggml_format_name(k, "cache_k_l%d", i);
37243719 ggml_format_name(v, "cache_v_l%d", i);
37253720 cache.kv.k_l.push_back(k);
37263721 cache.kv.v_l.push_back(v);
37273722 }
37283723 if (has_rs) {
3729- ggml_tensor * r = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_r(i)* rs_size);
3730- ggml_tensor * s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_s(i)* rs_size);
3724+ ggml_tensor * r = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_r(i) * rs_size);
3725+ ggml_tensor * s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_s(i) * rs_size);
37313726 ggml_format_name(r, "cache_r_l%d", i);
37323727 ggml_format_name(s, "cache_s_l%d", i);
37333728 cache.rs.r_l.push_back(r);
@@ -4370,8 +4365,8 @@ struct llama_kv_slot_restorer {
43704365 bool do_restore = false;
43714366
43724367 explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) {
4373- old_state.head = cache.kv.head;
4374- old_state.n = cache.kv.n;
4368+ old_state.head = cache.kv.head;
4369+ old_state.n = cache.kv.n;
43754370 }
43764371
43774372 // saves a slot information for future restoration
@@ -4388,10 +4383,10 @@ struct llama_kv_slot_restorer {
43884383 // and rollback changes from all llama_kv_cache_find_slot calls
43894384 void restore(struct llama_kv_cache & cache) {
43904385 if (do_restore) {
4391- cache.kv.head = old_state.head;
4392- cache.kv.n = old_state.n;
4386+ cache.kv.head = old_state.head;
4387+ cache.kv.n = old_state.n;
43934388
4394- if (cache.rs.size > 0) { // recurrent models like Mamba or RWKV can't have a state partially erased
4389+ if (cache.rs.size > 0) { // recurrent models like Mamba or RWKV can't have a state partially erased
43954390 llama_kv_cache_seq_rm(cache, -1, -1, -1);
43964391 } else {
43974392 for (auto & slot : slot_boundaries) {
0 commit comments