@@ -950,33 +950,31 @@ std::vector<common_file_info> fs_list_files(const std::string & path) {
950950// Model utils
951951//
952952
953- llama_model * common_load_model_from_params (common_params & params) {
953+ struct common_init_result common_init_from_params (common_params & params) {
954+ common_init_result iparams;
954955 auto mparams = common_model_params_to_llama (params);
955956
956957 llama_model * model = llama_model_load_from_file (params.model .path .c_str (), mparams);
957958 if (model == NULL ) {
958959 LOG_ERR (" %s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n " ,
959960 __func__, params.model .path .c_str ());
960- return nullptr ;
961- }
962-
963- return model;
964- }
965-
966- struct common_init_result common_init_context_from_model (
967- llama_model * model,
968- common_params & params) {
969- common_init_result iparams;
970-
971- if (model == NULL ) {
972- LOG_ERR (" %s: model is NULL\n " , __func__);
973961 return iparams;
974962 }
975963
976964 const llama_vocab * vocab = llama_model_get_vocab (model);
977965
978966 auto cparams = common_context_params_to_llama (params);
979967
968+ // backend sampling initialization
969+ if (params.sampling .backend_sampling ) {
970+ iparams.samplers_seq_config .resize (cparams.n_seq_max );
971+ for (int i = 0 ; i < (int ) cparams.n_seq_max ; ++i) {
972+ iparams.samplers_seq_config [i] = { i, common_sampler_backend_init (model, params.sampling ) };
973+ }
974+ cparams.samplers = iparams.samplers_seq_config .data ();
975+ cparams.n_samplers = cparams.n_seq_max ;
976+ }
977+
980978 llama_context * lctx = llama_init_from_model (model, cparams);
981979 if (lctx == NULL ) {
982980 LOG_ERR (" %s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n " ,
@@ -1142,14 +1140,6 @@ struct common_init_result common_init_context_from_model(
11421140 return iparams;
11431141}
11441142
1145- struct common_init_result common_init_from_params (common_params & params) {
1146- llama_model * model = common_load_model_from_params (params);
1147- if (model == NULL ) {
1148- return common_init_result ();
1149- }
1150- return common_init_context_from_model (model, params);
1151- }
1152-
11531143std::string get_model_endpoint () {
11541144 const char * model_endpoint_env = getenv (" MODEL_ENDPOINT" );
11551145 // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
@@ -1245,9 +1235,6 @@ struct llama_context_params common_context_params_to_llama(const common_params &
12451235 cparams.type_k = params.cache_type_k ;
12461236 cparams.type_v = params.cache_type_v ;
12471237
1248- cparams.samplers = params.backend_samplers ;
1249- cparams.n_samplers = params.n_backend_samplers ;
1250-
12511238 return cparams;
12521239}
12531240
0 commit comments