@@ -112,9 +112,17 @@ llama_context::llama_context(
         }
     }
 
+    cparams.n_ctx_seq = cparams.kv_unified ? cparams.n_ctx : cparams.n_ctx / cparams.n_seq_max;
+
+    if (cparams.n_ctx_seq > hparams.n_ctx_train) {
+        LLAMA_LOG_WARN("%s: capping n_ctx_seq (%u) to n_ctx_train (%u)\n", __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
+
+        cparams.n_ctx_seq = hparams.n_ctx_train;
+    }
+
     LLAMA_LOG_INFO("%s: n_seq_max     = %u\n",   __func__, cparams.n_seq_max);
     LLAMA_LOG_INFO("%s: n_ctx         = %u\n",   __func__, cparams.n_ctx);
-    LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n",   __func__, n_ctx_per_seq());
+    LLAMA_LOG_INFO("%s: n_ctx_seq     = %u\n",   __func__, cparams.n_ctx_seq);
     LLAMA_LOG_INFO("%s: n_batch       = %u\n",   __func__, cparams.n_batch);
     LLAMA_LOG_INFO("%s: n_ubatch      = %u\n",   __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: causal_attn   = %d\n",   __func__, cparams.causal_attn);
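Note on the hunk above (illustrative values, not part of the diff): with a split KV cache (kv_unified == false), n_ctx = 8192 and n_seq_max = 4 give n_ctx_seq = 8192 / 4 = 2048, while a unified cache keeps the full 8192. If the result still exceeds hparams.n_ctx_train (say 4096), the new branch caps n_ctx_seq to 4096 and logs the warning.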
@@ -123,14 +131,14 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: freq_base     = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale    = %g\n",   __func__, cparams.rope_freq_scale);
 
-    if (n_ctx_per_seq() < hparams.n_ctx_train) {
-        LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
-                __func__, n_ctx_per_seq(), hparams.n_ctx_train);
+    if (cparams.n_ctx_seq < hparams.n_ctx_train) {
+        LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
+                __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
     }
 
-    if (n_ctx_per_seq() > hparams.n_ctx_train) {
-        LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
-                __func__, n_ctx_per_seq(), hparams.n_ctx_train);
+    if (cparams.n_ctx_seq > hparams.n_ctx_train) {
+        LLAMA_LOG_WARN("%s: n_ctx_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
+                __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
     }
 
     if (!hparams.vocab_only) {
@@ -451,8 +459,8 @@ uint32_t llama_context::n_ctx() const {
     return cparams.n_ctx;
 }
 
-uint32_t llama_context::n_ctx_per_seq() const {
-    return cparams.kv_unified ? cparams.n_ctx : cparams.n_ctx / cparams.n_seq_max;
+uint32_t llama_context::n_ctx_seq() const {
+    return cparams.n_ctx_seq;
 }
 
 uint32_t llama_context::n_batch() const {
@@ -2381,6 +2389,10 @@ uint32_t llama_n_ctx(const llama_context * ctx) {
     return ctx->n_ctx();
 }
 
+uint32_t llama_n_ctx_seq(const llama_context * ctx) {
+    return ctx->n_ctx_seq();
+}
+
 uint32_t llama_n_batch(const llama_context * ctx) {
     return ctx->n_batch();
 }
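A minimal caller-side sketch of the new accessor added in the last hunk. The context-setup calls and the concrete parameter values are assumptions for illustration, not part of this diff; only llama_n_ctx_seq() is introduced here.

    // model is assumed to be an already loaded llama_model *.
    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx     = 8192;   // total context shared across sequences (assumed value)
    cparams.n_seq_max = 4;      // four parallel sequences (assumed value)

    llama_context * ctx = llama_init_from_model(model, cparams);

    const uint32_t n_ctx     = llama_n_ctx(ctx);      // 8192
    const uint32_t n_ctx_seq = llama_n_ctx_seq(ctx);  // 8192 / 4 = 2048 with a split KV cache,
                                                      // possibly capped to the model's n_ctx_train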