Skip to content

Commit c4c4cf3

Browse files
committed
Revert "memory : move the recurrent state into the memory context"
This reverts commit 00f115f.
1 parent a0a0d28 commit c4c4cf3

File tree

4 files changed

+18
-26
lines changed

4 files changed

+18
-26
lines changed

src/llama-graph.cpp

Lines changed: 5 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -235,12 +235,6 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
235235
}
236236
}
237237

238-
llm_graph_input_rs::llm_graph_input_rs(const llama_memory_recurrent_context * mctx) :
239-
mctx(mctx),
240-
head(mctx->get_head()),
241-
rs_z(mctx->get_rs_z()) {
242-
}
243-
244238
void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
245239
GGML_UNUSED(ubatch);
246240

@@ -269,8 +263,8 @@ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
269263
res &= s_copy_main->ne[0] == params.ubatch.n_seqs;
270264
res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
271265

272-
res &= this->head == mctx->get_head();
273-
res &= this->rs_z == mctx->get_rs_z();
266+
res &= head == mctx->get_head();
267+
res &= rs_z == mctx->get_rs_z();
274268

275269
return res;
276270
}
@@ -1906,6 +1900,9 @@ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
19061900
inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
19071901
inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
19081902

1903+
inp->head = mctx_cur->get_head();
1904+
inp->rs_z = mctx_cur->get_rs_z();
1905+
19091906
return inp;
19101907
}
19111908

src/llama-graph.h

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -219,7 +219,7 @@ class llm_graph_input_cls : public llm_graph_input_i {
219219

220220
class llm_graph_input_rs : public llm_graph_input_i {
221221
public:
222-
llm_graph_input_rs(const llama_memory_recurrent_context * mctx);
222+
llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
223223
virtual ~llm_graph_input_rs() = default;
224224

225225
void set_input(const llama_ubatch * ubatch) override;
@@ -235,9 +235,9 @@ class llm_graph_input_rs : public llm_graph_input_i {
235235

236236
const llama_memory_recurrent_context * mctx;
237237

238-
// need to match for valid graph reuse
239-
const uint32_t head;
240-
const int32_t rs_z;
238+
// used in view offsets, need to match for valid graph reuse
239+
uint32_t head;
240+
int32_t rs_z;
241241
};
242242

243243
class llm_graph_input_cross_embd : public llm_graph_input_i {

src/llama-memory-recurrent.cpp

Lines changed: 7 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1093,15 +1093,12 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
10931093
llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {}
10941094

10951095
llama_memory_recurrent_context::llama_memory_recurrent_context(
1096-
llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem),
1097-
n_rs(mem->size), head(0), rs_z(0), size(mem->size) {
1096+
llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
10981097
}
10991098

11001099
llama_memory_recurrent_context::llama_memory_recurrent_context(
11011100
llama_memory_recurrent * mem,
1102-
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)),
1103-
n_rs(mem->n), head(mem->head), rs_z(mem->rs_z), size(mem->size) {
1104-
}
1101+
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}
11051102

11061103
llama_memory_recurrent_context::~llama_memory_recurrent_context() = default;
11071104

@@ -1142,19 +1139,19 @@ const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const {
11421139
}
11431140

11441141
uint32_t llama_memory_recurrent_context::get_n_rs() const {
1145-
return n_rs;
1142+
return is_full ? mem->size : mem->n;
11461143
}
11471144

11481145
uint32_t llama_memory_recurrent_context::get_head() const {
1149-
return head;
1146+
return is_full ? 0 : mem->head;
11501147
}
11511148

11521149
int32_t llama_memory_recurrent_context::get_rs_z() const {
1153-
return rs_z;
1150+
return is_full ? 0 : mem->rs_z;
11541151
}
11551152

11561153
uint32_t llama_memory_recurrent_context::get_size() const {
1157-
return size;
1154+
return mem->size;
11581155
}
11591156

11601157
ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const {
@@ -1166,5 +1163,5 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
11661163
}
11671164

11681165
int32_t llama_memory_recurrent_context::s_copy(int i) const {
1169-
return mem->cells[i + head].src0;
1166+
return mem->cells[i + mem->head].src0;
11701167
}

src/llama-memory-recurrent.h

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -175,10 +175,8 @@ class llama_memory_recurrent_context : public llama_memory_context_i {
175175

176176
//
177177
// data needed for building the compute graph for the current ubatch:
178+
// TODO: extract all the state like `head` and `n` here
178179
//
179180

180-
const uint32_t n_rs = 0;
181-
const uint32_t head = 0;
182-
const int32_t rs_z = -1;
183-
const uint32_t size = 0;
181+
const bool is_full = false;
184182
};

0 commit comments

Comments (0)