Skip to content

Commit b26c706

Browse files
committed
common : initialize backend samplers
1 parent e2d4f08 commit b26c706

File tree

3 files changed: +16 additions, −52 deletions

common/common.cpp

Lines changed: 12 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -950,33 +950,31 @@ std::vector<common_file_info> fs_list_files(const std::string & path) {
950950
// Model utils
951951
//
952952

953-
llama_model * common_load_model_from_params(common_params & params) {
953+
struct common_init_result common_init_from_params(common_params & params) {
954+
common_init_result iparams;
954955
auto mparams = common_model_params_to_llama(params);
955956

956957
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
957958
if (model == NULL) {
958959
LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
959960
__func__, params.model.path.c_str());
960-
return nullptr;
961-
}
962-
963-
return model;
964-
}
965-
966-
struct common_init_result common_init_context_from_model(
967-
llama_model * model,
968-
common_params & params) {
969-
common_init_result iparams;
970-
971-
if (model == NULL) {
972-
LOG_ERR("%s: model is NULL\n", __func__);
973961
return iparams;
974962
}
975963

976964
const llama_vocab * vocab = llama_model_get_vocab(model);
977965

978966
auto cparams = common_context_params_to_llama(params);
979967

968+
// backend sampling initialization
969+
if (params.sampling.backend_sampling) {
970+
iparams.samplers_seq_config.resize(cparams.n_seq_max);
971+
for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
972+
iparams.samplers_seq_config[i] = { i, common_sampler_backend_init(model, params.sampling) };
973+
}
974+
cparams.samplers = iparams.samplers_seq_config.data();
975+
cparams.n_samplers = cparams.n_seq_max;
976+
}
977+
980978
llama_context * lctx = llama_init_from_model(model, cparams);
981979
if (lctx == NULL) {
982980
LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
@@ -1142,14 +1140,6 @@ struct common_init_result common_init_context_from_model(
11421140
return iparams;
11431141
}
11441142

1145-
struct common_init_result common_init_from_params(common_params & params) {
1146-
llama_model * model = common_load_model_from_params(params);
1147-
if (model == NULL) {
1148-
return common_init_result();
1149-
}
1150-
return common_init_context_from_model(model, params);
1151-
}
1152-
11531143
std::string get_model_endpoint() {
11541144
const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
11551145
// We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
@@ -1245,9 +1235,6 @@ struct llama_context_params common_context_params_to_llama(const common_params &
12451235
cparams.type_k = params.cache_type_k;
12461236
cparams.type_v = params.cache_type_v;
12471237

1248-
cparams.samplers = params.backend_samplers;
1249-
cparams.n_samplers = params.n_backend_samplers;
1250-
12511238
return cparams;
12521239
}
12531240

common/common.h

Lines changed: 3 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -523,9 +523,6 @@ struct common_params {
523523
bool has_speculative() const {
524524
return !speculative.model.path.empty() || !speculative.model.hf_repo.empty();
525525
}
526-
527-
llama_sampler_seq_config * backend_samplers = NULL;
528-
size_t n_backend_samplers = 0;
529526
};
530527

531528
// call once at the start of a program if it uses libcommon
@@ -643,18 +640,13 @@ struct common_init_result {
643640
llama_context_ptr context;
644641

645642
std::vector<llama_adapter_lora_ptr> lora;
643+
644+
std::vector<llama_sampler_ptr> samplers;
645+
std::vector<llama_sampler_seq_config> samplers_seq_config;
646646
};
647647

648648
struct common_init_result common_init_from_params(common_params & params);
649649

650-
// Load model only (allows creating backend samplers before context initialization)
651-
llama_model * common_load_model_from_params(common_params & params);
652-
653-
// Initialize context from an already-loaded model (allows pre-configuring backend samplers)
654-
struct common_init_result common_init_context_from_model(
655-
llama_model * model,
656-
common_params & params);
657-
658650
struct llama_model_params common_model_params_to_llama ( common_params & params);
659651
struct llama_context_params common_context_params_to_llama(const common_params & params);
660652
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params);

tools/main/main.cpp

Lines changed: 1 addition & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -138,22 +138,7 @@ int main(int argc, char ** argv) {
138138
// load the model and apply lora adapter, if any
139139
LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__);
140140

141-
model = common_load_model_from_params(params);
142-
if (model == NULL) {
143-
LOG_ERR("%s: error: unable to load model\n", __func__);
144-
return 1;
145-
}
146-
147-
// Configure backend sampler if configured
148-
llama_sampler * backend_sampler = common_sampler_backend_init(model, sparams);
149-
llama_sampler_seq_config sampler_config = { 0, backend_sampler };
150-
151-
if (backend_sampler) {
152-
params.backend_samplers = &sampler_config;
153-
params.n_backend_samplers = 1;
154-
}
155-
156-
common_init_result llama_init = common_init_context_from_model(model, params);
141+
common_init_result llama_init = common_init_from_params(params);
157142
ctx = llama_init.context.get();
158143
model = llama_init.model.get(); // Update pointer (now managed by llama_init)
159144

0 commit comments

Comments
 (0)