Commit 8723531

memory : move the recurrent state into the memory context
1 parent a126bc4 commit 8723531
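
The change: the per-graph recurrent-state values (n_rs, head, rs_z, size) are no longer read from the live llama_memory_recurrent object (or derived from an is_full flag) at query time; they are captured once, as const members of llama_memory_recurrent_context, when the context is created. A minimal sketch of this snapshot-at-construction pattern, using simplified stand-in types rather than the actual llama.cpp ones:

#include <cstdint>

// Simplified stand-ins for llama_memory_recurrent / llama_memory_recurrent_context
// (illustrative only, not the real llama.cpp types).
struct recurrent_memory {
    uint32_t size = 0;  // total number of recurrent-state cells
    uint32_t n    = 0;  // cells used by the current ubatch
    uint32_t head = 0;  // first cell of the current slot
    int32_t  rs_z = -1; // index of the zero-initialized state
};

struct recurrent_memory_context {
    // the values are copied exactly once, so later mutations of `mem`
    // cannot change what a graph built from this context sees
    explicit recurrent_memory_context(const recurrent_memory & mem)
        : n_rs(mem.n), head(mem.head), rs_z(mem.rs_z), size(mem.size) {}

    const uint32_t n_rs;
    const uint32_t head;
    const int32_t  rs_z;
    const uint32_t size;
};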

File tree

4 files changed: +26 −18 lines

    src/llama-graph.cpp
    src/llama-graph.h
    src/llama-memory-recurrent.cpp
    src/llama-memory-recurrent.h

src/llama-graph.cpp

Lines changed: 8 additions & 5 deletions
@@ -235,6 +235,12 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+llm_graph_input_rs::llm_graph_input_rs(const llama_memory_recurrent_context * mctx) :
+    mctx(mctx),
+    head(mctx->get_head()),
+    rs_z(mctx->get_rs_z()) {
+}
+
 void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
@@ -263,8 +269,8 @@ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
     res &= s_copy_main->ne[0] == params.ubatch.n_seqs;
     res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
 
-    res &= head == mctx->get_head();
-    res &= rs_z == mctx->get_rs_z();
+    res &= this->head == mctx->get_head();
+    res &= this->rs_z == mctx->get_rs_z();
 
     return res;
 }
@@ -1899,9 +1905,6 @@ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
     inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
     inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
 
-    inp->head = mctx_cur->get_head();
-    inp->rs_z = mctx_cur->get_rs_z();
-
     return inp;
 }
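
Since llm_graph_input_rs now takes its head/rs_z snapshot in the constructor, can_reuse() simply compares the values recorded at graph-build time with those reported by the new memory context, and build_rs_inp_impl no longer copies them in by hand. A stripped-down sketch of that reuse check (simplified, not the actual llm_graph_* API):

#include <cstdint>

// Illustrative only: a graph input that remembers the memory layout it was
// built against and can report whether a new layout allows reusing the graph.
struct graph_input_rs_sketch {
    const uint32_t head; // captured when the graph was built
    const int32_t  rs_z;

    bool can_reuse(uint32_t cur_head, int32_t cur_rs_z) const {
        // the offsets baked into the graph are only valid if these still match
        return head == cur_head && rs_z == cur_rs_z;
    }
};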

src/llama-graph.h

Lines changed: 4 additions & 4 deletions
@@ -219,7 +219,7 @@ class llm_graph_input_cls : public llm_graph_input_i {
 
 class llm_graph_input_rs : public llm_graph_input_i {
 public:
-    llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
+    llm_graph_input_rs(const llama_memory_recurrent_context * mctx);
     virtual ~llm_graph_input_rs() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -235,9 +235,9 @@ class llm_graph_input_rs : public llm_graph_input_i {
 
     const llama_memory_recurrent_context * mctx;
 
-    // used in view offsets, need to match for valid graph reuse
-    uint32_t head;
-    int32_t rs_z;
+    // need to match for valid graph reuse
+    const uint32_t head;
+    const int32_t rs_z;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
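
head and rs_z are now const, so they can only be set in a constructor's member initializer list; that is what the out-of-line constructor in llama-graph.cpp above does. A tiny illustration of the rule, with hypothetical names:

#include <cstdint>

// const data members cannot be assigned in a constructor body;
// they have to be initialized in the member initializer list.
struct example_input {
    explicit example_input(uint32_t h) : head(h) {} // OK
    // explicit example_input(uint32_t h) { head = h; } // would not compile

    const uint32_t head;
};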

src/llama-memory-recurrent.cpp

Lines changed: 10 additions & 7 deletions
@@ -1092,12 +1092,15 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
 llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {}
 
 llama_memory_recurrent_context::llama_memory_recurrent_context(
-        llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
+        llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem),
+            n_rs(mem->size), head(0), rs_z(0), size(mem->size) {
 }
 
 llama_memory_recurrent_context::llama_memory_recurrent_context(
         llama_memory_recurrent * mem,
-        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)),
+            n_rs(mem->n), head(mem->head), rs_z(mem->rs_z), size(mem->size) {
+}
 
 llama_memory_recurrent_context::~llama_memory_recurrent_context() = default;
 
@@ -1138,19 +1141,19 @@ const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const {
 }
 
 uint32_t llama_memory_recurrent_context::get_n_rs() const {
-    return is_full ? mem->size : mem->n;
+    return n_rs;
 }
 
 uint32_t llama_memory_recurrent_context::get_head() const {
-    return is_full ? 0 : mem->head;
+    return head;
 }
 
 int32_t llama_memory_recurrent_context::get_rs_z() const {
-    return is_full ? 0 : mem->rs_z;
+    return rs_z;
 }
 
 uint32_t llama_memory_recurrent_context::get_size() const {
-    return mem->size;
+    return size;
 }
 
 ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const {
@@ -1162,5 +1165,5 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
 }
 
 int32_t llama_memory_recurrent_context::s_copy(int i) const {
-    return mem->cells[i + mem->head].src0;
+    return mem->cells[i + head].src0;
 }
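
Both context constructors now fix the snapshot up front: the whole-memory constructor (previously the is_full = true case) records n_rs = size, head = 0, rs_z = 0, which is exactly what the removed is_full branches computed in the getters, while the per-ubatch constructor copies the live n, head, and rs_z. The getters become plain member reads. A hedged sketch of the same idea with simplified names:

#include <cstdint>

// Illustrative only: the two ways of building the same snapshot, replacing
// run-time branching on an is_full flag with values fixed at construction.
struct rs_snapshot {
    uint32_t n_rs = 0;
    uint32_t head = 0;
    int32_t  rs_z = -1;
    uint32_t size = 0;
};

// view over the whole memory (the old is_full == true case)
inline rs_snapshot make_full_snapshot(uint32_t mem_size) {
    return {mem_size, 0, 0, mem_size};
}

// view for one processed ubatch
inline rs_snapshot make_ubatch_snapshot(uint32_t n, uint32_t head, int32_t rs_z, uint32_t size) {
    return {n, head, rs_z, size};
}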

src/llama-memory-recurrent.h

Lines changed: 4 additions & 2 deletions
@@ -175,8 +175,10 @@ class llama_memory_recurrent_context : public llama_memory_context_i {
 
     //
     // data needed for building the compute graph for the current ubatch:
-    // TODO: extract all the state like `head` and `n` here
     //
 
-    const bool is_full = false;
+    const uint32_t n_rs = 0;
+    const uint32_t head = 0;
+    const int32_t  rs_z = -1;
+    const uint32_t size = 0;
 };
