
Commit f2cca02

llama : add note about context size queries
1 parent 23323cd commit f2cca02

File tree: 2 files changed (+14, -2 lines)

include/llama.h

Lines changed: 3 additions & 1 deletion

@@ -461,6 +461,8 @@ extern "C" {
     LLAMA_API bool llama_supports_gpu_offload(void);
     LLAMA_API bool llama_supports_rpc        (void);
 
+    // NOTE: After creating a llama_context, it is recommended to query the actual values using these functions
+    // In some cases the requested values via llama_context_params may differ from the actual values used by the context
     LLAMA_API uint32_t llama_n_ctx      (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_ctx_seq  (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_batch    (const struct llama_context * ctx);

@@ -586,7 +588,7 @@ extern "C" {
     LLAMA_API int32_t llama_adapter_meta_val_str_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size);
 
     // Manually free a LoRA adapter
-    // Note: loaded adapters will be free when the associated model is deleted
+    // NOTE: loaded adapters will be free when the associated model is deleted
     LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);
 
     // Get the invocation tokens if the current lora is an alora
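
As an illustration of the new recommendation, the following minimal sketch (not part of this commit; it assumes a hypothetical model path and the current C API entry points such as llama_model_load_from_file and llama_init_from_model) requests a context size via llama_context_params and then reads back the values the context actually uses:

```c
#include "llama.h"
#include <stdio.h>

int main(void) {
    llama_backend_init();

    struct llama_model_params mparams = llama_model_default_params();
    struct llama_model * model = llama_model_load_from_file("model.gguf", mparams); // hypothetical path
    if (model == NULL) {
        return 1;
    }

    struct llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx     = 1048576; // requested value; the context may end up using less
    cparams.n_seq_max = 4;

    struct llama_context * ctx = llama_init_from_model(model, cparams);
    if (ctx == NULL) {
        llama_model_free(model);
        return 1;
    }

    // query the actual values - they may differ from what was requested above
    printf("n_ctx     = %u\n", llama_n_ctx    (ctx));
    printf("n_ctx_seq = %u\n", llama_n_ctx_seq(ctx));
    printf("n_batch   = %u\n", llama_n_batch  (ctx));

    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}
```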

src/llama-context.cpp

Lines changed: 11 additions & 1 deletion

@@ -112,14 +112,24 @@ llama_context::llama_context(
         }
     }
 
-    cparams.n_ctx_seq = cparams.kv_unified ? cparams.n_ctx : cparams.n_ctx / cparams.n_seq_max;
+    if (cparams.kv_unified) {
+        cparams.n_ctx_seq = cparams.n_ctx;
+    } else {
+        cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max;
+    }
 
     if (cparams.n_ctx_seq > hparams.n_ctx_train) {
         LLAMA_LOG_WARN("%s: capping n_ctx_seq (%u) to n_ctx_train (%u)\n", __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
 
         cparams.n_ctx_seq = hparams.n_ctx_train;
     }
 
+    if (cparams.kv_unified) {
+        cparams.n_ctx = cparams.n_ctx_seq;
+    } else {
+        cparams.n_ctx = cparams.n_ctx_seq * cparams.n_seq_max;
+    }
+
     LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
     LLAMA_LOG_INFO("%s: n_ctx     = %u\n", __func__, cparams.n_ctx);
     LLAMA_LOG_INFO("%s: n_ctx_seq = %u\n", __func__, cparams.n_ctx_seq);
