
Commit 9e50456

context : minor simplify
ggml-ci
1 parent: befe14f

File tree

4 files changed: +22 -26 lines

    src/llama-context.cpp
    src/llama-context.h
    src/llama-model.cpp
    src/llama-model.h
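
The change is the same across all four files: interfaces that previously took a
ggml_context_ptr & (a reference to the owning smart pointer) now take a plain
ggml_context *, and call sites that own a ggml_context_ptr obtained from
graph_init() pass ctx.get() instead. Below is a minimal, self-contained sketch of
that ownership pattern; the alias definition and the name graph_build_sketch are
illustrative assumptions, not the actual llama.cpp declarations (the real
ggml_context_ptr in ggml is a smart pointer with a custom deleter).

    #include <memory>

    struct ggml_context { /* opaque in the real code */ };

    // stand-in for ggml_context_ptr (assumption: plain unique_ptr, no custom deleter)
    using ggml_context_ptr = std::unique_ptr<ggml_context>;

    // after this commit: callees receive a non-owning raw pointer
    void graph_build_sketch(ggml_context * ctx) {
        (void) ctx; // a real implementation would build the compute graph in this context
    }

    int main() {
        ggml_context_ptr ctx = std::make_unique<ggml_context>(); // owner keeps the smart pointer
        graph_build_sketch(ctx.get());                           // borrower gets a raw pointer
        return 0;
    } // the context is destroyed exactly once, when ctx goes out of scope

Passing a raw pointer keeps ownership with the caller and lets the graph builders
accept any ggml_context, not only one wrapped in a ggml_context_ptr.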

src/llama-context.cpp

Lines changed: 11 additions & 13 deletions
@@ -256,7 +256,7 @@ void llama_context::init() {
     {
         llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
         auto ctx = graph_init();
-        auto res_pp = graph_build(ctx, ubatch_pp, true);
+        auto res_pp = graph_build(ctx.get(), ubatch_pp, true);
         auto & gf_pp = res_pp.gf;
         if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) {
             LLAMA_LOG_ERROR("%s: failed to allocate compute pp buffers\n", __func__);
@@ -271,7 +271,7 @@ void llama_context::init() {
     {
         llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
         auto ctx = graph_init();
-        auto res_tg = graph_build(ctx, ubatch_tg, true);
+        auto res_tg = graph_build(ctx.get(), ubatch_tg, true);
         auto & gf_tg = res_tg.gf;
         if (!ggml_backend_sched_reserve(sched.get(), gf_tg)) {
             LLAMA_LOG_ERROR("%s: failed to allocate compute tg buffers\n", __func__);
@@ -285,7 +285,7 @@ void llama_context::init() {
     {
         llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
         auto ctx = graph_init();
-        auto res_pp = graph_build(ctx, ubatch_pp, true);
+        auto res_pp = graph_build(ctx.get(), ubatch_pp, true);
         auto & gf_pp = res_pp.gf;
         if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) {
             LLAMA_LOG_ERROR("%s: failed to allocate compute pp buffers\n", __func__);
@@ -573,7 +573,7 @@ ggml_context_ptr llama_context::graph_init() {
 }
 
 llama_graph_result llama_context::graph_build(
-        ggml_context_ptr & ctx,
+        ggml_context * ctx,
         const llama_ubatch & ubatch,
         bool worst_case) {
     return model.build_graph(ctx, *this, cparams, ubatch, worst_case);
@@ -1720,7 +1720,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
     ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
     auto ctx = graph_init();
-    auto res = graph_build(ctx, ubatch, false);
+    auto res = graph_build(ctx.get(), ubatch, false);
 
     auto * gf = res.gf;
 
@@ -2000,7 +2000,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
         llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
 
         auto ctx = graph_init();
-        auto res = graph_build(ctx, ubatch, true);
+        auto res = graph_build(ctx.get(), ubatch, true);
 
         // initialize scheduler with the worst-case graph
         ggml_backend_sched_reset(sched.get());
@@ -2015,7 +2015,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
     ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
     auto ctx = graph_init();
-    auto res = graph_build(ctx, ubatch, false);
+    auto res = graph_build(ctx.get(), ubatch, false);
 
     auto * gf = res.gf;
 
@@ -2483,11 +2483,10 @@ void llama_context_kv_self::kv_self_update() {
         ggml_backend_sched_reset(sched.get());
 
         auto ctx = graph_init();
-        auto * ctx0 = ctx.get();
 
-        ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+        ggml_cgraph * gf = ggml_new_graph_custom(ctx.get(), model.max_nodes(), false);
 
-        build_kv_self_shift(ctx0, gf);
+        build_kv_self_shift(ctx.get(), gf);
 
         ggml_backend_sched_alloc_graph(sched.get(), gf);
 
@@ -2512,11 +2511,10 @@ void llama_context_kv_self::kv_self_update() {
         ggml_backend_sched_reset(sched.get());
 
         auto ctx = graph_init();
-        auto * ctx0 = ctx.get();
 
-        ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+        ggml_cgraph * gf = ggml_new_graph_custom(ctx.get(), model.max_nodes(), false);
 
-        build_kv_self_defrag(ctx0, gf);
+        build_kv_self_defrag(ctx.get(), gf);
 
         ggml_backend_sched_alloc_graph(sched.get(), gf);
 
src/llama-context.h

Lines changed: 1 addition & 1 deletion
@@ -97,7 +97,7 @@ struct llama_context : public llama_graph_i {
 
     // TODO: add encode/decode graphs
     virtual llama_graph_result graph_build(
-            ggml_context_ptr & ctx,
+            ggml_context * ctx,
            const llama_ubatch & ubatch,
            bool worst_case);
 

src/llama-model.cpp

Lines changed: 9 additions & 11 deletions
@@ -3841,19 +3841,18 @@ struct llm_build_context {
     const enum llama_pooling_type pooling_type;
     const enum llama_rope_type rope_type;
 
-    ggml_context_ptr & ctx;
-    ggml_context * ctx0 = nullptr;
+    ggml_context * ctx0 = nullptr;
 
     llama_graph_result res;
 
     // TODO: consider making the entire interface noexcept
     llm_build_context(
-            ggml_context_ptr & ctx,
-            llama_graph_i & lgf,
-            const llama_model & model,
-            const llama_cparams & cparams,
-            const llama_ubatch & ubatch,
-            bool worst_case) :
+            ggml_context * ctx,
+            llama_graph_i & lgf,
+            const llama_model & model,
+            const llama_cparams & cparams,
+            const llama_ubatch & ubatch,
+            bool worst_case) :
         lgf (lgf),
         model (model),
         hparams (model.hparams),
@@ -3885,8 +3884,7 @@ struct llm_build_context {
         flash_attn (cparams.flash_attn),
         pooling_type (cparams.pooling_type),
         rope_type (hparams.rope_type),
-        ctx (ctx),
-        ctx0 (this->ctx.get()) {
+        ctx0 (ctx) {
     }
 
     // TODO: tmp
@@ -10937,7 +10935,7 @@ struct llm_build_context {
 };
 
 llama_graph_result llama_model::build_graph(
-        ggml_context_ptr & ctx,
+        ggml_context * ctx,
         llama_graph_i & lgf,
         const llama_cparams & cparams,
         const llama_ubatch & ubatch,
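
In llm_build_context the stored reference to the owning pointer is dropped: the
struct no longer keeps two names (ctx and ctx0) for the same context, and the raw
working pointer ctx0 is initialized directly from the constructor argument.
Roughly, as a sketch with a hypothetical stand-in name rather than the real struct:

    struct ggml_context; // opaque, as in ggml

    // hypothetical stand-in for the simplified builder
    struct build_context_sketch {
        ggml_context * ctx0 = nullptr; // non-owning working pointer

        explicit build_context_sketch(ggml_context * ctx) : ctx0(ctx) {}
    };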

src/llama-model.h

Lines changed: 1 addition & 1 deletion
@@ -370,7 +370,7 @@ struct llama_model {
 
     // TODO: add encode/decode graphs
     llama_graph_result build_graph(
-            ggml_context_ptr & ctx,
+            ggml_context * ctx,
             llama_graph_i & lgf,
             const llama_cparams & cparams,
             const llama_ubatch & ubatch,
