
Commit 5bb8a26

context : reduce virtuals + remove test function
ggml-ci
1 parent 2664a3d · commit 5bb8a26

7 files changed: 37 additions, 52 deletions


examples/quantize-stats/quantize-stats.cpp

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 #include "ggml.h"
 #include "llama.h"
-#include "llama-context.h"
+#include "llama-model.h"
 #include "common.h"
 
 #include <algorithm>
@@ -328,7 +328,7 @@ int main(int argc, char ** argv) {
         }
     }
 
-    const auto & tensors = llama_internal_get_tensor_map(ctx);
+    const auto & tensors = llama_internal_get_tensor_map(model);
 
     // check layer tensors
     int included_layers = 0;
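On the caller side, quantize-stats already loads a llama_model and creates its llama_context from it, so the only change is which handle the tensor map is requested from. A minimal sketch, assuming a model pointer obtained the usual way; the helper report_tensor_names is illustrative and not part of the tool:

#include <cstdio>

#include "llama-model.h"

// Hypothetical helper: walk the (name, tensor) pairs exposed by the model.
static void report_tensor_names(const llama_model * model) {
    // as of this commit, the map is keyed off the model rather than the context
    const auto & tensors = llama_internal_get_tensor_map(model);
    for (const auto & it : tensors) {
        printf("%s\n", it.first.c_str()); // it.second is the ggml_tensor *
    }
}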

src/llama-adapter.h

Lines changed: 1 addition & 1 deletion
@@ -73,4 +73,4 @@ struct llama_adapter_lora {
     llama_adapter_lora_weight * get_weight(struct ggml_tensor * w);
 };
 
-using llama_loras = std::unordered_map<struct llama_adapter_lora *, float>;
+using llama_adapter_loras = std::unordered_map<struct llama_adapter_lora *, float>;
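The alias itself is unchanged apart from the name, so its use looks the same as before. A sketch for illustration only; the adapter pointers and scales below are placeholders:

#include "llama-adapter.h"

// Illustrative only: map each LoRA adapter to the scale it should be applied with.
static void sketch_lora_blend(llama_adapter_lora * a, llama_adapter_lora * b) {
    llama_adapter_loras loras;   // std::unordered_map<struct llama_adapter_lora *, float>
    loras[a] = 1.0f;             // apply adapter a at full strength
    loras[b] = 0.5f;             // blend adapter b at half strength
    (void) loras;
}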

src/llama-context.cpp

Lines changed: 3 additions & 20 deletions
@@ -1,8 +1,10 @@
 #include "llama-context.h"
 
 #include "llama-impl.h"
-#include "llama-mmap.h"
 #include "llama-io.h"
+#include "llama-mmap.h"
+#include "llama-model.h"
+#include "llama-kv-cache.h"
 
 #include <cstring>
 #include <stdexcept>
@@ -2288,10 +2290,6 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
     return 0;
 }
 
-ggml_cgraph * llama_context_kv_self::graph_init() {
-    return llama_context_base::graph_init();
-}
-
 llm_graph_result_ptr llama_context_kv_self::graph_build(
             ggml_context * ctx,
             ggml_cgraph * gf,
@@ -2735,10 +2733,6 @@ int llama_context_recurrent::decode(llama_batch & inp_batch) {
     return 0;
 }
 
-ggml_cgraph * llama_context_recurrent::graph_init() {
-    return llama_context_base::graph_init();
-}
-
 llm_graph_result_ptr llama_context_recurrent::graph_build(
             ggml_context * ctx,
             ggml_cgraph * gf,
@@ -2956,10 +2950,6 @@ void llama_context_dec::reserve() {
     llama_context_kv_self::reserve();
 }
 
-ggml_cgraph * llama_context_dec::graph_init() {
-    return llama_context_kv_self::graph_init();
-}
-
 llm_graph_result_ptr llama_context_dec::graph_build(
             ggml_context * ctx,
             ggml_cgraph * gf,
@@ -3663,10 +3653,3 @@ int32_t llama_decode(
 
     return ret;
 }
-
-
-const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
-        struct llama_context * ctx
-) {
-    return ctx->get_model().tensors_by_name;
-}
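The removed graph_init() overrides only forwarded to their parent class, so deleting them does not change which function virtual dispatch ends up calling. A self-contained illustration of that point, unrelated to the llama.cpp classes themselves:

#include <cstdio>

struct base {
    virtual ~base() = default;
    virtual void graph_init() { puts("base::graph_init"); }
};

// before: the override added nothing beyond calling the base implementation
struct derived_before : base {
    void graph_init() override { base::graph_init(); }
};

// after: with the pass-through override removed, dispatch still lands on
// base::graph_init, so callers observe identical behavior
struct derived_after : base {};

int main() {
    derived_before x;
    derived_after  y;
    static_cast<base &>(x).graph_init(); // prints "base::graph_init"
    static_cast<base &>(y).graph_init(); // prints "base::graph_init"
    return 0;
}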

src/llama-context.h

Lines changed: 15 additions & 21 deletions
@@ -4,15 +4,16 @@
 #include "llama-batch.h"
 #include "llama-cparams.h"
 #include "llama-graph.h"
-#include "llama-model.h"
-#include "llama-kv-cache.h"
 #include "llama-adapter.h"
 
 #include "ggml-cpp.h"
 
 #include <map>
 #include <vector>
 
+struct llama_model;
+struct llama_kv_cache;
+
 class llama_io_read_i;
 class llama_io_write_i;
 
@@ -244,28 +245,29 @@ class llama_context_base : public llama_context {
 
     // Make sure enough space is available for outputs.
     // Returns max number of outputs for which space was reserved.
-    virtual int32_t output_reserve(int32_t n_outputs);
+    int32_t output_reserve(int32_t n_outputs);
 
     // make the outputs have the same order they had in the user-provided batch
     // TODO: maybe remove this
-    virtual void output_reorder();
+    void output_reorder();
 
     //
     // graph
     //
 
-    virtual int32_t graph_max_nodes() const;
+    int32_t graph_max_nodes() const;
 
     // zero-out inputs and create the ctx_compute for the compute graph
-    virtual ggml_cgraph * graph_init();
+    ggml_cgraph * graph_init();
 
+    // override this method in order to pass custom set of parameters to the llm_graph_context
     virtual llm_graph_result_ptr graph_build(
             ggml_context * ctx,
             ggml_cgraph * gf,
             const llama_ubatch & ubatch);
 
     // returns the result of ggml_backend_sched_graph_compute_async execution
-    virtual enum ggml_status graph_compute(
+    enum ggml_status graph_compute(
             ggml_cgraph * gf,
             bool batched);
 
@@ -330,6 +332,8 @@ class llama_context_base : public llama_context {
             size_t n_token_count) override;
 
 protected:
+    // override these to store all relevant state for the specific context
+    // TODO: read/write adapters
     virtual size_t state_write_data(llama_io_write_i & io);
    virtual size_t state_read_data (llama_io_read_i & io);
 
@@ -345,10 +349,10 @@ class llama_context_base : public llama_context {
 
     const llm_graph_type gtype;
 
-    llama_cparams       cparams;
-    llama_adapter_cvec  cvec;
-    llama_loras         loras;
-    llama_sbatch        sbatch;
+    llama_cparams        cparams;
+    llama_adapter_cvec   cvec;
+    llama_adapter_loras  loras;
+    llama_sbatch         sbatch;
 
     ggml_backend_sched_ptr sched;
 
@@ -431,8 +435,6 @@ class llama_context_kv_self : public llama_context_base {
     // graph
     //
 
-    ggml_cgraph * graph_init() override;
-
     llm_graph_result_ptr graph_build(
             ggml_context * ctx,
             ggml_cgraph * gf,
@@ -482,8 +484,6 @@ class llama_context_recurrent : public llama_context_base {
     // graph
     //
 
-    ggml_cgraph * graph_init() override;
-
     llm_graph_result_ptr graph_build(
             ggml_context * ctx,
             ggml_cgraph * gf,
@@ -532,8 +532,6 @@ class llama_context_dec : public llama_context_kv_self {
     // graph
     //
 
-    ggml_cgraph * graph_init() override;
-
     llm_graph_result_ptr graph_build(
             ggml_context * ctx,
             ggml_cgraph * gf,
@@ -677,7 +675,3 @@ class llama_context_enc_dec : public llama_context {
 
     llama_cross cross;
 };
-
-// For internal test use
-// TODO: remove
-const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(struct llama_context * ctx);
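Swapping the llama-model.h and llama-kv-cache.h includes for forward declarations works because the header only refers to these types through pointers and references; the full definitions are now pulled in by llama-context.cpp instead. A generic sketch of the idea, not the actual class layout:

// Forward declarations are sufficient for pointer/reference members,
// which lets the heavy headers move out of the .h and into the .cpp.
struct llama_model;
struct llama_kv_cache;

class context_sketch {
    const llama_model * model = nullptr; // ok: pointer to an incomplete type
    llama_kv_cache    * kv    = nullptr; // ok: pointer to an incomplete type
    // llama_kv_cache by_value;          // would not compile without the full definition
};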

src/llama-graph.h

Lines changed: 8 additions & 8 deletions
@@ -366,10 +366,10 @@ struct llm_graph_params {
     ggml_backend * backend_cpu;
     const std::vector<ggml_backend_ptr> & backends;
 
-    const llama_adapter_cvec * cvec;
-    const llama_loras        * loras;
-    const llama_memory_i     * memory;
-    const llama_cross        * cross;
+    const llama_adapter_cvec  * cvec;
+    const llama_adapter_loras * loras;
+    const llama_memory_i      * memory;
+    const llama_cross         * cross;
 
     int32_t n_outputs;
 };
@@ -420,10 +420,10 @@ struct llm_graph_context {
     ggml_backend * backend_cpu;
     const std::vector<ggml_backend_ptr> & backends;
 
-    const llama_adapter_cvec * cvec;
-    const llama_loras        * loras;
-    const llama_memory_i     * memory;
-    const llama_cross        * cross;
+    const llama_adapter_cvec  * cvec;
+    const llama_adapter_loras * loras;
+    const llama_memory_i      * memory;
+    const llama_cross         * cross;
 
     std::unique_ptr<llm_graph_result> res;
 

src/llama-model.cpp

Lines changed: 4 additions & 0 deletions
@@ -10845,3 +10845,7 @@ bool llama_model_is_recurrent(const struct llama_model * model) {
         default: return false;
     }
 }
+
+const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model) {
+    return model->tensors_by_name;
+}

src/llama-model.h

Lines changed: 4 additions & 0 deletions
@@ -385,3 +385,7 @@ struct llama_model {
 };
 
 const char * llm_type_name(llm_type type);
+
+// For internal test use
+// TODO: remove
+const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model);
