Skip to content

Commit 03231da

Browse files
committed
add model member function to build mtp graph, to be called from speculative.cpp
1 parent 1f477b3 commit 03231da

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

src/llama-model.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18673,6 +18673,22 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
1867318673
return llm->res->get_gf();
1867418674
}
1867518675

18676+
// Builds the MTP compute graph for this model and returns it.
// (MTP presumably stands for multi-token prediction - confirm.)
//
// params, hidden_state_inp, last_token_id and n_past are forwarded unchanged to the
// architecture-specific MTP graph builder. Only LLM_ARCH_GLM4_MOE has such a builder
// here; any other architecture ends in GGML_ABORT.
//
// NOTE(review): the builder object is destroyed when this function returns, so the
// returned graph must be owned by something that outlives it - confirm against
// llm_graph_context / its `res` member.
ggml_cgraph * llama_model::build_mtp_graph(
        const llm_graph_params & params,
        ggml_tensor *            hidden_state_inp,
        llama_token              last_token_id,
        int                      n_past) const {
    std::unique_ptr<llm_graph_context> mtp_builder;

    // pick the MTP builder that matches the model architecture
    switch (arch) {
        case LLM_ARCH_GLM4_MOE:
            mtp_builder = std::make_unique<llm_build_glm4_moe_mtp>(
                    *this, params, hidden_state_inp, last_token_id, n_past);
            break;
        default:
            // no MTP graph builder is wired up for this architecture
            GGML_ABORT("fatal error");
    }

    ggml_cgraph * mtp_graph = mtp_builder->res->get_gf();
    return mtp_graph;
}
18691+
1867618692
//
1867718693
// interface implementation
1867818694
//

src/llama-model.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,8 @@ struct llama_model {
475475

476476
// TODO: move this to new llm_arch_model_i interface
477477
ggml_cgraph * build_graph(const llm_graph_params & params) const;
478+
// MTP counterpart of build_graph: builds a graph from params plus the extra
// hidden_state_inp, last_token_id and n_past inputs, and returns it.
// NOTE(review): the definition in llama-model.cpp only handles LLM_ARCH_GLM4_MOE
// and calls GGML_ABORT for every other architecture.
ggml_cgraph * build_mtp_graph(const llm_graph_params & params,
479+
ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) const;
478480

479481
private:
480482
struct impl;

0 commit comments

Comments
 (0)