File tree Expand file tree Collapse file tree 3 files changed +14
-16
lines changed Expand file tree Collapse file tree 3 files changed +14
-16
lines changed Original file line number Diff line number Diff line change @@ -286,27 +286,21 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
286286 for (uint32_t i = 0 ; i < n_kv; ++i) {
287287 const uint32_t cell_id = i + kv_self->head ;
288288
289- // ////////////////////////////////////////////
290- // TODO: this should not mutate the KV cache !
291- llama_kv_cell & kv_cell = const_cast < class llama_kv_cache_unified *>(kv_self)-> cells [i] ;
289+ const llama_kv_cell & kv_cell = kv_self-> cells [cell_id];
290+
291+ int32_t src = kv_cell. src0 ;
292292
293293 // prevent out-of-bound sources
294- if (kv_cell. src < 0 ) {
294+ if (src < 0 ) {
295295 GGML_ASSERT (kv_self->rs_z >= 0 ); // Need a valid zero-ed cell as a source
296- kv_cell. src = kv_self->rs_z ;
296+ src = kv_self->rs_z ;
297297 }
298- if ((uint32_t ) kv_cell. src >= kv_self->size ) {
298+ if ((uint32_t ) src >= kv_self->size ) {
299299 // ignore out-of-bound sources
300- kv_cell. src = cell_id;
300+ src = cell_id;
301301 }
302302
303- data[i] = kv_cell.src ;
304-
305- // TODO: do not mutate the KV cache
306- // ensure copy only happens once
307- if (kv_cell.src != (int32_t ) cell_id) {
308- kv_cell.src = cell_id;
309- }
303+ data[i] = src;
310304 }
311305 }
312306}
Original file line number Diff line number Diff line change @@ -665,10 +665,13 @@ bool llama_kv_cache_unified::find_slot(
665665 // Find first to-be-cleared cell
666666 rs_z = -1 ;
667667 for (int i = min; i <= max; ++i) {
668- if (cells[i].src == -1 ) {
668+ if (rs_z < 0 && cells[i].src == -1 ) {
669669 rs_z = i;
670- break ;
671670 }
671+ // Stage the source ids for all used cells to allow correct seq_* behavior
672+ // and still make these values available when setting the inputs
673+ cells[i].src0 = cells[i].src ;
674+ cells[i].src = i;
672675 }
673676
674677 // allow getting the range of used cells, from head to head + n
Original file line number Diff line number Diff line change @@ -47,6 +47,7 @@ struct llama_kv_cell {
4747 llama_pos pos = -1 ;
4848 llama_pos delta = 0 ;
4949 int32_t src = -1 ; // used by recurrent state models to copy states
50+ int32_t src0 = -1 ; // staged copy of src, read when setting the inputs (so that the copy happens only once)
5051 int32_t tail = -1 ;
5152
5253 std::set<llama_seq_id> seq_id;
You can’t perform that action at this time.
0 commit comments