@@ -640,14 +640,20 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
     }
     if (arg == "--lora") {
         CHECK_ARG
-        params.lora_adapter.emplace_back(argv[i], 1.0f);
+        params.lora_adapters.push_back({
+            std::string(argv[i]),
+            1.0,
+        });
         return true;
     }
     if (arg == "--lora-scaled") {
         CHECK_ARG
         const char * lora_adapter = argv[i];
         CHECK_ARG
-        params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i]));
+        params.lora_adapters.push_back({
+            lora_adapter,
+            std::stof(argv[i]),
+        });
         return true;
     }
     if (arg == "--control-vector") {
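The aggregate initializer `{path, scale}` above implies a small two-field descriptor replacing the old `std::tuple<std::string, float>` elements of `params.lora_adapter`. Its definition is not part of this diff; a minimal sketch of the assumed shape (the real struct lives in `common.h` and may differ):

```cpp
#include <string>
#include <vector>

// Assumed descriptor shape; the initializer order above implies
// path first, then scale.
struct llama_lora_adapter_info {
    std::string path;
    float       scale;
};

int main() {
    std::vector<llama_lora_adapter_info> lora_adapters;
    lora_adapters.push_back({ "a.gguf", 1.0f }); // --lora a.gguf
    lora_adapters.push_back({ "b.gguf", 0.5f }); // --lora-scaled b.gguf 0.5
}
```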
@@ -1725,6 +1731,17 @@ std::string string_get_sortable_timestamp() {
     return std::string(timestamp_no_ns) + "." + std::string(timestamp_ns);
 }
 
+void string_replace_all(std::string & s, const std::string & search, const std::string & replace) {
+    if (search.empty()) {
+        return; // Avoid infinite loop if 'search' is an empty string
+    }
+    size_t pos = 0;
+    while ((pos = s.find(search, pos)) != std::string::npos) {
+        s.replace(pos, search.length(), replace);
+        pos += replace.length();
+    }
+}
+
 void string_process_escapes(std::string & input) {
     std::size_t input_len = input.length();
     std::size_t output_idx = 0;
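Note that the helper advances `pos` past the replacement rather than past the match, so a `replace` that contains `search` (e.g. replacing `"a"` with `"aa"`) cannot re-match its own output and loop forever. A standalone demo of the function as added:

```cpp
#include <cstdio>
#include <string>

// string_replace_all as introduced in the hunk above.
void string_replace_all(std::string & s, const std::string & search, const std::string & replace) {
    if (search.empty()) {
        return; // an empty needle would match at every position forever
    }
    size_t pos = 0;
    while ((pos = s.find(search, pos)) != std::string::npos) {
        s.replace(pos, search.length(), replace);
        pos += replace.length(); // skip past the replacement, not the match
    }
}

int main() {
    std::string s = "a--b--c";
    string_replace_all(s, "--", "-");
    std::printf("%s\n", s.c_str()); // prints: a-b-c
}
```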
@@ -1998,8 +2015,8 @@ std::string fs_get_cache_file(const std::string & filename) {
 //
 // Model utils
 //
-
-std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
+struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
+    llama_init_result iparams;
     auto mparams = llama_model_params_from_gpt_params(params);
 
     llama_model * model = nullptr;
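Call sites migrate from `std::tie` over a tuple to the result struct. A sketch of a migrated caller, assuming the `llama_init_result` declaration sits in `common.h` and uses the field names assigned later in this diff (`model`, `context`):

```cpp
#include "common.h" // gpt_params, llama_init_result (llama.cpp common library)

int init_example(gpt_params & params) {
    // before: std::tie(model, ctx) = llama_init_from_gpt_params(params);
    llama_init_result llama_init = llama_init_from_gpt_params(params);

    llama_model   * model = llama_init.model;
    llama_context * ctx   = llama_init.context;
    if (model == nullptr || ctx == nullptr) {
        return 1; // failure is now an empty result, not a tuple of nullptrs
    }
    // ... use model/ctx ...
    return 0;
}
```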
@@ -2014,7 +2031,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
 
     if (model == NULL) {
         fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
-        return std::make_tuple(nullptr, nullptr);
+        return iparams;
     }
 
     auto cparams = llama_context_params_from_gpt_params(params);
@@ -2023,7 +2040,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
     if (lctx == NULL) {
         fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
         llama_free_model(model);
-        return std::make_tuple(nullptr, nullptr);
+        return iparams;
     }
 
     if (!params.control_vectors.empty()) {
@@ -2034,7 +2051,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
         if (cvec.n_embd == -1) {
             llama_free(lctx);
             llama_free_model(model);
-            return std::make_tuple(nullptr, nullptr);
+            return iparams;
         }
 
         int err = llama_control_vector_apply(lctx,
@@ -2046,21 +2063,26 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
         if (err) {
             llama_free(lctx);
             llama_free_model(model);
-            return std::make_tuple(nullptr, nullptr);
+            return iparams;
         }
     }
 
-    for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
-        const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]);
-        float lora_scale = std::get<1>(params.lora_adapter[i]);
-        auto adapter = llama_lora_adapter_init(model, lora_adapter.c_str());
-        if (adapter == nullptr) {
-            fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
+    // load and optionally apply lora adapters
+    for (auto & la : params.lora_adapters) {
+        llama_lora_adapter_container loaded_la;
+        loaded_la.path = la.path;
+        loaded_la.scale = la.scale;
+        loaded_la.adapter = llama_lora_adapter_init(model, la.path.c_str());
+        if (loaded_la.adapter == nullptr) {
+            fprintf(stderr, "%s: error: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
             llama_free(lctx);
             llama_free_model(model);
-            return std::make_tuple(nullptr, nullptr);
+            return iparams;
         }
-        llama_lora_adapter_set(lctx, adapter, lora_scale);
+        iparams.lora_adapters.push_back(loaded_la); // copy to list of loaded adapters
+    }
+    if (!params.lora_init_without_apply) {
+        llama_lora_adapters_apply(lctx, iparams.lora_adapters);
     }
 
     if (params.ignore_eos) {
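Loading is now separated from application: adapters are always loaded into `iparams.lora_adapters`, and `llama_lora_adapters_apply` runs only when `lora_init_without_apply` is unset. A sketch of deferred application under the same assumptions as the migration example above (the adapter path is hypothetical):

```cpp
#include "common.h"

void deferred_lora_example(gpt_params & params) {
    params.lora_adapters.push_back({ "style.gguf", 1.0f }); // hypothetical adapter
    params.lora_init_without_apply = true;                  // load, but do not apply yet

    llama_init_result llama_init = llama_init_from_gpt_params(params);

    // ... later, when the adapter should take effect:
    llama_lora_adapters_apply(llama_init.context, llama_init.lora_adapters);
}
```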
@@ -2088,13 +2110,26 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
             tmp.clear();
             tmp.push_back(decoder_start_token_id);
         }
-        llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
+        if (llama_model_has_decoder(model)) {
+            llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
+        }
         llama_kv_cache_clear(lctx);
         llama_synchronize(lctx);
         llama_reset_timings(lctx);
     }
 
-    return std::make_tuple(model, lctx);
+    iparams.model = model;
+    iparams.context = lctx;
+    return iparams;
+}
+
+void llama_lora_adapters_apply(struct llama_context * ctx, std::vector<llama_lora_adapter_container> & lora_adapters) {
+    llama_lora_adapter_clear(ctx);
+    for (auto & la : lora_adapters) {
+        if (la.scale != 0.0f) {
+            llama_lora_adapter_set(ctx, la.adapter, la.scale);
+        }
+    }
 }
 
 struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params) {
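Because `llama_lora_adapters_apply` first clears all adapters and skips any with a scale of `0.0f`, callers can rescale or disable adapters at runtime by editing the loaded containers and re-applying, without reloading from disk. For example:

```cpp
// llama_init is the llama_init_result obtained earlier.
if (!llama_init.lora_adapters.empty()) {
    llama_init.lora_adapters[0].scale = 0.0f; // 0.0f = loaded but disabled
    llama_lora_adapters_apply(llama_init.context, llama_init.lora_adapters);
}
```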
@@ -3126,18 +3161,16 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     }
 
     fprintf(stream, "lora:\n");
-    for (std::tuple<std::string, float> la : params.lora_adapter) {
-        if (std::get<1>(la) != 1.0f) {
-            continue;
+    for (auto & la : params.lora_adapters) {
+        if (la.scale == 1.0f) {
+            fprintf(stream, "  - %s\n", la.path.c_str());
         }
-        fprintf(stream, "  - %s\n", std::get<0>(la).c_str());
     }
     fprintf(stream, "lora_scaled:\n");
-    for (std::tuple<std::string, float> la : params.lora_adapter) {
-        if (std::get<1>(la) == 1.0f) {
-            continue;
+    for (auto & la : params.lora_adapters) {
+        if (la.scale != 1.0f) {
+            fprintf(stream, "  - %s: %f\n", la.path.c_str(), la.scale);
         }
-        fprintf(stream, "  - %s: %f\n", std::get<0>(la).c_str(), std::get<1>(la));
     }
     fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
     fprintf(stream, "min_keep: %d # default: 0 (disabled)\n", sparams.min_keep);
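The inverted conditions keep the same split as before: adapters at the default scale of 1.0 are listed under `lora:`, everything else under `lora_scaled:` with its scale (`%f` prints six decimals). With one of each (hypothetical paths), the dump would emit something like:

```yaml
lora:
  - a.gguf
lora_scaled:
  - b.gguf: 0.500000
```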