Skip to content

Commit cf74744

Browse files
committed
context : move interface implementation to source file + factory
ggml-ci
1 parent 5bb8a26 commit cf74744

File tree

3 files changed

+146
-143
lines changed

3 files changed

+146
-143
lines changed

src/llama-context.cpp

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,39 @@
1010
#include <stdexcept>
1111
#include <cinttypes>
1212

13+
llama_context * llama_context::create(const llama_model & model, llama_context_params params) {
14+
llama_context * ctx = nullptr;
15+
16+
try {
17+
switch (model.arch) {
18+
case LLM_ARCH_BERT:
19+
case LLM_ARCH_JINA_BERT_V2:
20+
case LLM_ARCH_NOMIC_BERT:
21+
ctx = new llama_context_enc(model, params, LLM_GRAPH_TYPE_DEFAULT);
22+
break;
23+
case LLM_ARCH_T5:
24+
ctx = new llama_context_enc_dec(model, params);
25+
break;
26+
case LLM_ARCH_RWKV6:
27+
case LLM_ARCH_RWKV6QWEN2:
28+
case LLM_ARCH_MAMBA:
29+
GGML_ASSERT(llama_model_is_recurrent(&model));
30+
ctx = new llama_context_recurrent(model, params, LLM_GRAPH_TYPE_DEFAULT);
31+
break;
32+
default:
33+
GGML_ASSERT(!llama_model_is_recurrent(&model));
34+
ctx = new llama_context_kv_self(model, params, LLM_GRAPH_TYPE_DEFAULT);
35+
}
36+
} catch (const std::exception & e) {
37+
LLAMA_LOG_ERROR("%s: failed to initialize context: %s\n", __func__, e.what());
38+
return nullptr;
39+
}
40+
41+
ctx->init();
42+
43+
return ctx;
44+
}
45+
1346
//
1447
// llama_context_base
1548
//
@@ -3212,6 +3245,84 @@ size_t llama_context_enc_dec::state_seq_save_file(
32123245
// interface implementation
32133246
//
32143247

3248+
llama_context_params llama_context_default_params() {
3249+
llama_context_params result = {
3250+
/*.n_ctx =*/ 512,
3251+
/*.n_batch =*/ 2048,
3252+
/*.n_ubatch =*/ 512,
3253+
/*.n_seq_max =*/ 1,
3254+
/*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
3255+
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
3256+
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
3257+
/*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
3258+
/*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
3259+
/*.rope_freq_base =*/ 0.0f,
3260+
/*.rope_freq_scale =*/ 0.0f,
3261+
/*.yarn_ext_factor =*/ -1.0f,
3262+
/*.yarn_attn_factor =*/ 1.0f,
3263+
/*.yarn_beta_fast =*/ 32.0f,
3264+
/*.yarn_beta_slow =*/ 1.0f,
3265+
/*.yarn_orig_ctx =*/ 0,
3266+
/*.defrag_thold =*/ -1.0f,
3267+
/*.cb_eval =*/ nullptr,
3268+
/*.cb_eval_user_data =*/ nullptr,
3269+
/*.type_k =*/ GGML_TYPE_F16,
3270+
/*.type_v =*/ GGML_TYPE_F16,
3271+
/*.logits_all =*/ false,
3272+
/*.embeddings =*/ false,
3273+
/*.offload_kqv =*/ true,
3274+
/*.flash_attn =*/ false,
3275+
/*.no_perf =*/ true,
3276+
/*.abort_callback =*/ nullptr,
3277+
/*.abort_callback_data =*/ nullptr,
3278+
};
3279+
3280+
return result;
3281+
}
3282+
3283+
llama_context * llama_init_from_model(
3284+
llama_model * model,
3285+
llama_context_params params) {
3286+
if (!model) {
3287+
LLAMA_LOG_ERROR("%s: model cannot be NULL\n", __func__);
3288+
return nullptr;
3289+
}
3290+
3291+
if (params.n_batch == 0 && params.n_ubatch == 0) {
3292+
LLAMA_LOG_ERROR("%s: n_batch and n_ubatch cannot both be zero\n", __func__);
3293+
return nullptr;
3294+
}
3295+
3296+
if (params.n_ctx == 0 && model->hparams.n_ctx_train == 0) {
3297+
LLAMA_LOG_ERROR("%s: n_ctx and model->hparams.n_ctx_train cannot both be zero\n", __func__);
3298+
return nullptr;
3299+
}
3300+
3301+
if (params.flash_attn && model->arch == LLM_ARCH_GROK) {
3302+
LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
3303+
params.flash_attn = false;
3304+
}
3305+
3306+
if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
3307+
LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
3308+
params.flash_attn = false;
3309+
}
3310+
3311+
if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
3312+
LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
3313+
return nullptr;
3314+
}
3315+
3316+
return llama_context::create(*model, params);
3317+
}
3318+
3319+
// deprecated
3320+
struct llama_context * llama_new_context_with_model(
3321+
struct llama_model * model,
3322+
struct llama_context_params params) {
3323+
return llama_init_from_model(model, params);
3324+
}
3325+
32153326
void llama_free(struct llama_context * ctx) {
32163327
delete ctx;
32173328
}
@@ -3653,3 +3764,36 @@ int32_t llama_decode(
36533764

36543765
return ret;
36553766
}
3767+
3768+
//
3769+
// perf
3770+
//
3771+
3772+
llama_perf_context_data llama_perf_context(const llama_context * ctx) {
3773+
llama_perf_context_data data = {};
3774+
3775+
if (ctx == nullptr) {
3776+
return data;
3777+
}
3778+
3779+
data = ctx->perf_get_data();
3780+
3781+
return data;
3782+
}
3783+
3784+
void llama_perf_context_print(const llama_context * ctx) {
3785+
const auto data = llama_perf_context(ctx);
3786+
3787+
const double t_end_ms = 1e-3 * ggml_time_us();
3788+
3789+
LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, data.t_load_ms);
3790+
LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
3791+
__func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
3792+
LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
3793+
__func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
3794+
LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
3795+
}
3796+
3797+
void llama_perf_context_reset(llama_context * ctx) {
3798+
ctx->perf_reset();
3799+
}

src/llama-context.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@ class llama_context_i {
157157
// C alias
158158
struct llama_context : public llama_context_i {
159159
using llama_context_i::llama_context_i;
160+
161+
static llama_context * create(const llama_model & model, llama_context_params params);
160162
};
161163

162164
// basic transformer without KV cache

src/llama.cpp

Lines changed: 0 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
#include "llama-chat.h"
44
#include "llama-mmap.h"
5-
#include "llama-context.h"
65
#include "llama-vocab.h"
76
#include "llama-model-loader.h"
87
#include "llama-model.h"
@@ -25,41 +24,6 @@
2524
// interface implementation
2625
//
2726

28-
struct llama_context_params llama_context_default_params() {
29-
struct llama_context_params result = {
30-
/*.n_ctx =*/ 512,
31-
/*.n_batch =*/ 2048,
32-
/*.n_ubatch =*/ 512,
33-
/*.n_seq_max =*/ 1,
34-
/*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
35-
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
36-
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
37-
/*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
38-
/*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
39-
/*.rope_freq_base =*/ 0.0f,
40-
/*.rope_freq_scale =*/ 0.0f,
41-
/*.yarn_ext_factor =*/ -1.0f,
42-
/*.yarn_attn_factor =*/ 1.0f,
43-
/*.yarn_beta_fast =*/ 32.0f,
44-
/*.yarn_beta_slow =*/ 1.0f,
45-
/*.yarn_orig_ctx =*/ 0,
46-
/*.defrag_thold =*/ -1.0f,
47-
/*.cb_eval =*/ nullptr,
48-
/*.cb_eval_user_data =*/ nullptr,
49-
/*.type_k =*/ GGML_TYPE_F16,
50-
/*.type_v =*/ GGML_TYPE_F16,
51-
/*.logits_all =*/ false,
52-
/*.embeddings =*/ false,
53-
/*.offload_kqv =*/ true,
54-
/*.flash_attn =*/ false,
55-
/*.no_perf =*/ true,
56-
/*.abort_callback =*/ nullptr,
57-
/*.abort_callback_data =*/ nullptr,
58-
};
59-
60-
return result;
61-
}
62-
6327
struct llama_sampler_chain_params llama_sampler_chain_default_params() {
6428
struct llama_sampler_chain_params result = {
6529
/*.no_perf =*/ true,
@@ -289,80 +253,6 @@ struct llama_model * llama_model_load_from_splits(
289253
return llama_model_load_from_file_impl(splits.front(), splits, params);
290254
}
291255

292-
struct llama_context * llama_init_from_model(
293-
struct llama_model * model,
294-
struct llama_context_params params) {
295-
296-
if (!model) {
297-
LLAMA_LOG_ERROR("%s: model cannot be NULL\n", __func__);
298-
return nullptr;
299-
}
300-
301-
if (params.n_batch == 0 && params.n_ubatch == 0) {
302-
LLAMA_LOG_ERROR("%s: n_batch and n_ubatch cannot both be zero\n", __func__);
303-
return nullptr;
304-
}
305-
306-
if (params.n_ctx == 0 && model->hparams.n_ctx_train == 0) {
307-
LLAMA_LOG_ERROR("%s: n_ctx and model->hparams.n_ctx_train cannot both be zero\n", __func__);
308-
return nullptr;
309-
}
310-
311-
if (params.flash_attn && model->arch == LLM_ARCH_GROK) {
312-
LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
313-
params.flash_attn = false;
314-
}
315-
316-
if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
317-
LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
318-
params.flash_attn = false;
319-
}
320-
321-
if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
322-
LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
323-
return nullptr;
324-
}
325-
326-
llama_context * ctx = nullptr;
327-
328-
try {
329-
// TODO: make static method of llama_context
330-
switch (model->arch) {
331-
case LLM_ARCH_BERT:
332-
case LLM_ARCH_JINA_BERT_V2:
333-
case LLM_ARCH_NOMIC_BERT:
334-
ctx = new llama_context_enc(*model, params, LLM_GRAPH_TYPE_DEFAULT);
335-
break;
336-
case LLM_ARCH_T5:
337-
ctx = new llama_context_enc_dec(*model, params);
338-
break;
339-
case LLM_ARCH_RWKV6:
340-
case LLM_ARCH_RWKV6QWEN2:
341-
case LLM_ARCH_MAMBA:
342-
GGML_ASSERT(llama_model_is_recurrent(model));
343-
ctx = new llama_context_recurrent(*model, params, LLM_GRAPH_TYPE_DEFAULT);
344-
break;
345-
default:
346-
GGML_ASSERT(!llama_model_is_recurrent(model));
347-
ctx = new llama_context_kv_self(*model, params, LLM_GRAPH_TYPE_DEFAULT);
348-
};
349-
350-
ctx->init();
351-
} catch (const std::exception & e) {
352-
LLAMA_LOG_ERROR("%s: failed to initialize context: %s\n", __func__, e.what());
353-
return nullptr;
354-
}
355-
356-
return ctx;
357-
}
358-
359-
// deprecated
360-
struct llama_context * llama_new_context_with_model(
361-
struct llama_model * model,
362-
struct llama_context_params params) {
363-
return llama_init_from_model(model, params);
364-
}
365-
366256
//
367257
// chat templates
368258
//
@@ -448,36 +338,3 @@ const char * llama_print_system_info(void) {
448338

449339
return s.c_str();
450340
}
451-
452-
//
453-
// perf
454-
//
455-
456-
struct llama_perf_context_data llama_perf_context(const struct llama_context * ctx) {
457-
struct llama_perf_context_data data = {};
458-
459-
if (ctx == nullptr) {
460-
return data;
461-
}
462-
463-
data = ctx->perf_get_data();
464-
465-
return data;
466-
}
467-
468-
void llama_perf_context_print(const struct llama_context * ctx) {
469-
const auto data = llama_perf_context(ctx);
470-
471-
const double t_end_ms = 1e-3 * ggml_time_us();
472-
473-
LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, data.t_load_ms);
474-
LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
475-
__func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
476-
LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
477-
__func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
478-
LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
479-
}
480-
481-
void llama_perf_context_reset(struct llama_context * ctx) {
482-
ctx->perf_reset();
483-
}

0 commit comments

Comments
 (0)