@@ -1092,7 +1092,8 @@ struct llm_build_context {
10921092 llama_context & lctx,
10931093 const llama_ubatch & ubatch,
10941094 const llm_build_cb & cb,
1095- bool worst_case) :
1095+ bool worst_case,
1096+ bool warmup) :
10961097 model (lctx.model),
10971098 lctx (lctx),
10981099 hparams (model.hparams),
@@ -1110,7 +1111,7 @@ struct llm_build_context {
11101111 n_embd_head_v (hparams.n_embd_head_v),
11111112 n_embd_v_gqa (hparams.n_embd_v_gqa()),
11121113 n_expert (hparams.n_expert),
1113- n_expert_used (hparams.n_expert_used),
1114+ n_expert_used (warmup ? hparams.n_expert : hparams.n_expert_used),
11141115 freq_base (cparams.rope_freq_base),
11151116 freq_scale (cparams.rope_freq_scale),
11161117 ext_factor (cparams.yarn_ext_factor),
@@ -8198,7 +8199,7 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const
81988199
81998200 llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
82008201
8201- struct llm_build_context llm (lctx, dummy, cb, false );
8202+ struct llm_build_context llm (lctx, dummy, cb, false , false );
82028203
82038204 llm.init();
82048205
@@ -8215,7 +8216,7 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
82158216
82168217 llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
82178218
8218- struct llm_build_context llm (lctx, dummy, cb, false );
8219+ struct llm_build_context llm (lctx, dummy, cb, false , false );
82198220
82208221 llm.init();
82218222
@@ -8266,7 +8267,11 @@ static struct ggml_cgraph * llama_build_graph(
82668267
82678268 struct ggml_cgraph * result = NULL ;
82688269
8269- struct llm_build_context llm (lctx, ubatch, cb, worst_case);
8270+ const llama_vocab * vocab = llama_model_get_vocab (&model);
8271+ llama_token bos = llama_vocab_bos (vocab);
8272+ llama_token eos = llama_vocab_eos (vocab);
8273+ bool is_warming_up = (ubatch.n_tokens == 2 && ubatch.token [0 ] == bos && ubatch.token [1 ] == eos);
8274+ struct llm_build_context llm (lctx, ubatch, cb, worst_case, is_warming_up);
82708275
82718276 llm.init();
82728277
0 commit comments