@@ -112,11 +112,9 @@ llama_context::llama_context(
         }
     }
 
-    const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
-
     LLAMA_LOG_INFO("%s: n_seq_max     = %u\n",   __func__, cparams.n_seq_max);
     LLAMA_LOG_INFO("%s: n_ctx         = %u\n",   __func__, cparams.n_ctx);
-    LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n",   __func__, n_ctx_per_seq);
+    LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n",   __func__, n_ctx_per_seq());
     LLAMA_LOG_INFO("%s: n_batch       = %u\n",   __func__, cparams.n_batch);
     LLAMA_LOG_INFO("%s: n_ubatch      = %u\n",   __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: causal_attn   = %d\n",   __func__, cparams.causal_attn);
@@ -125,14 +123,14 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: freq_base     = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale    = %g\n",   __func__, cparams.rope_freq_scale);
 
-    if (n_ctx_per_seq < hparams.n_ctx_train) {
+    if (n_ctx_per_seq() < hparams.n_ctx_train) {
         LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
-                __func__, n_ctx_per_seq, hparams.n_ctx_train);
+                __func__, n_ctx_per_seq(), hparams.n_ctx_train);
     }
 
-    if (n_ctx_per_seq > hparams.n_ctx_train) {
+    if (n_ctx_per_seq() > hparams.n_ctx_train) {
         LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
-                __func__, n_ctx_per_seq, hparams.n_ctx_train);
+                __func__, n_ctx_per_seq(), hparams.n_ctx_train);
     }
 
     if (!hparams.vocab_only) {
@@ -454,7 +452,7 @@ uint32_t llama_context::n_ctx() const {
 }
 
 uint32_t llama_context::n_ctx_per_seq() const {
-    return cparams.n_ctx / cparams.n_seq_max;
+    return cparams.kv_unified ? cparams.n_ctx : cparams.n_ctx / cparams.n_seq_max;
 }
 
 uint32_t llama_context::n_batch() const {
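
The core behavioral change above is in n_ctx_per_seq(): with a unified KV cache every sequence can use the full context window, while with a per-sequence (split) cache the window is divided by n_seq_max. Below is a minimal standalone sketch of that rule for illustration only; the struct and function names (ctx_params_sketch, n_ctx_per_seq_sketch) are assumptions and not part of llama.cpp.

// Illustrative sketch of the per-sequence context rule introduced by this diff.
// Assumption: names here are hypothetical stand-ins for cparams and n_ctx_per_seq().
#include <cstdint>
#include <cstdio>

struct ctx_params_sketch {
    uint32_t n_ctx;      // total context size
    uint32_t n_seq_max;  // maximum number of parallel sequences
    bool     kv_unified; // true: one shared KV cache, false: split per sequence
};

// Unified cache: all sequences share the full n_ctx.
// Split cache: the context is divided evenly across n_seq_max sequences.
static uint32_t n_ctx_per_seq_sketch(const ctx_params_sketch & p) {
    return p.kv_unified ? p.n_ctx : p.n_ctx / p.n_seq_max;
}

int main() {
    const ctx_params_sketch split   = { 8192, 4, false };
    const ctx_params_sketch unified = { 8192, 4, true  };

    printf("split:   n_ctx_per_seq = %u\n", (unsigned) n_ctx_per_seq_sketch(split));   // 2048
    printf("unified: n_ctx_per_seq = %u\n", (unsigned) n_ctx_per_seq_sketch(unified)); // 8192
    return 0;
}

This also explains why the constructor now calls n_ctx_per_seq() for logging and the training-context warnings instead of caching the old division result: the old local variable would have reported a too-small value in the unified-cache case.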