Commit 7e4cae5

Merge remote-tracking branch 'fairydreaming/experts-warmup' into tmp

2 parents: e162e47 + c8bc6e4
2 files changed: 11 additions, 6 deletions

src/llama-model.cpp

Lines changed: 1 addition & 1 deletion
@@ -3602,7 +3602,7 @@ size_t llama_model::size() const {
 }
 
 size_t llama_model::max_nodes() const {
-    return std::max<size_t>(8192, tensors_by_name.size()*5);
+    return std::max<size_t>(65536, tensors_by_name.size()*5);
 }
 
 size_t llama_model::n_devices() const {
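The bump from 8192 to 65536 matters because llama.cpp sizes its compute graphs from max_nodes(): a warmup graph that routes through every expert (see the llama.cpp change below) can need far more nodes than a normal top-k decode. A minimal sketch of how a node cap like this is typically consumed, assuming the standard ggml_new_graph_custom() API; the call site shown is illustrative, not the literal code in this commit:

#include "ggml.h"

// Hedged sketch: reserve storage for up to `max_nodes` graph nodes.
// If graph construction emits more nodes than were reserved (e.g. a
// warmup pass touching all experts), ggml asserts, so the cap must be
// generous enough for the largest graph the model can build.
static struct ggml_cgraph * build_graph(struct ggml_context * ctx, size_t max_nodes) {
    return ggml_new_graph_custom(ctx, max_nodes, /*grads=*/false);
}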

src/llama.cpp

Lines changed: 10 additions & 5 deletions
@@ -1092,7 +1092,8 @@ struct llm_build_context {
         llama_context & lctx,
         const llama_ubatch & ubatch,
         const llm_build_cb & cb,
-        bool worst_case) :
+        bool worst_case,
+        bool warmup) :
         model   (lctx.model),
         lctx    (lctx),
         hparams (model.hparams),
@@ -1110,7 +1111,7 @@
         n_embd_head_v (hparams.n_embd_head_v),
         n_embd_v_gqa  (hparams.n_embd_v_gqa()),
         n_expert      (hparams.n_expert),
-        n_expert_used (hparams.n_expert_used),
+        n_expert_used (warmup ? hparams.n_expert : hparams.n_expert_used),
         freq_base     (cparams.rope_freq_base),
         freq_scale    (cparams.rope_freq_scale),
         ext_factor    (cparams.yarn_ext_factor),
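Widening n_expert_used to n_expert during warmup makes the router's top-k selection cover every expert, so a single warmup decode touches (pages in or uploads) all expert weights instead of only the few experts the warmup tokens happen to route to. A hedged sketch of the selection this field controls, loosely modeled on llama.cpp's MoE FFN construction; function and tensor names here are illustrative:

#include "ggml.h"

// probs is the [n_expert, n_tokens] router output for the batch.
// Normally n_expert_used == hparams.n_expert_used; during warmup it is
// widened to n_expert, so every expert's tensors enter the graph once.
static struct ggml_tensor * select_experts(
        struct ggml_context * ctx, struct ggml_tensor * probs, int n_expert_used) {
    return ggml_top_k(ctx, probs, n_expert_used);  // [n_expert_used, n_tokens]
}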
@@ -8198,7 +8199,7 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const
 
     llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
 
-    struct llm_build_context llm(lctx, dummy, cb, false);
+    struct llm_build_context llm(lctx, dummy, cb, false, false);
 
     llm.init();
 
@@ -8215,7 +8216,7 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
 
     llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
 
-    struct llm_build_context llm(lctx, dummy, cb, false);
+    struct llm_build_context llm(lctx, dummy, cb, false, false);
 
     llm.init();
 

@@ -8266,7 +8267,11 @@ static struct ggml_cgraph * llama_build_graph(
 
     struct ggml_cgraph * result = NULL;
 
-    struct llm_build_context llm(lctx, ubatch, cb, worst_case);
+    const llama_vocab * vocab = llama_model_get_vocab(&model);
+    llama_token bos = llama_vocab_bos(vocab);
+    llama_token eos = llama_vocab_eos(vocab);
+    bool is_warming_up = (ubatch.n_tokens == 2 && ubatch.token[0] == bos && ubatch.token[1] == eos);
+    struct llm_build_context llm(lctx, ubatch, cb, worst_case, is_warming_up);
 
     llm.init();
 
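The two-token BOS+EOS signature works as a warmup detector because that is the batch llama.cpp's common warmup helper submits before real decoding starts, and ordinary inference batches essentially never consist of exactly BOS followed by EOS. A hedged sketch of the caller side this heuristic detects, modeled on the warmup in common.cpp; the exact helpers (batch constructor, KV-cache clear) vary across llama.cpp versions:

#include <vector>
#include "llama.h"

// Assumed caller-side warmup: decode [BOS, EOS] once, then discard the state
// so the warmup tokens do not leak into real generation.
static void warmup(llama_context * lctx, const llama_vocab * vocab) {
    std::vector<llama_token> tmp = { llama_vocab_bos(vocab), llama_vocab_eos(vocab) };
    llama_decode(lctx, llama_batch_get_one(tmp.data(), (int32_t) tmp.size()));
    llama_kv_cache_clear(lctx);  // drop warmup state before timing real work
}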