@@ -55,29 +55,51 @@ static int printe(const char * fmt, ...) {
 class Opt {
   public:
     int init(int argc, const char ** argv) {
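+        // Capture the library defaults so unset CLI options can fall back to them and --help can report them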
+        ctx_params = llama_context_default_params();
+        model_params = llama_model_default_params();
+        context_size_default = ctx_params.n_batch;
+        ngl_default = model_params.n_gpu_layers;
+        common_params_sampling sampling;
+        temperature_default = sampling.temp;
+
+        if (argc < 2) {
+            printe("Error: No arguments provided.\n");
+            print_help();
+            return 1;
+        }
+
         // Parse arguments
         if (parse(argc, argv)) {
             printe("Error: Failed to parse arguments.\n");
-            help();
+            print_help();
             return 1;
         }
 
         // If help is requested, show help and exit
-        if (help_) {
-            help();
+        if (help) {
+            print_help();
             return 2;
         }
 
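+        // A non-negative user-supplied value overrides the recorded default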
+        ctx_params.n_batch = context_size >= 0 ? context_size : context_size_default;
+        model_params.n_gpu_layers = ngl >= 0 ? ngl : ngl_default;
+        temperature = temperature >= 0 ? temperature : temperature_default;
+
         return 0; // Success
     }
 
+    llama_context_params ctx_params;
+    llama_model_params model_params;
     std::string model_;
-    std::string user_;
-    int context_size_ = -1, ngl_ = -1;
-    bool verbose_ = false;
+    std::string user;
+    int context_size = -1, ngl = -1;
+    float temperature = -1;
+    bool verbose = false;
 
   private:
-    bool help_ = false;
+    int context_size_default = -1, ngl_default = -1;
+    float temperature_default = -1;
+    bool help = false;
 
     bool parse_flag(const char ** argv, int i, const char * short_opt, const char * long_opt) {
         return strcmp(argv[i], short_opt) == 0 || strcmp(argv[i], long_opt) == 0;
@@ -89,25 +111,40 @@ class Opt {
         }
 
         option_value = std::atoi(argv[++i]);
+
+        return 0;
+    }
+
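+    // Overload for float-valued options (used by --temp)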
+    int handle_option_with_value(int argc, const char ** argv, int & i, float & option_value) {
+        if (i + 1 >= argc) {
+            return 1;
+        }
+
+        option_value = std::atof(argv[++i]);
+
         return 0;
     }
 
     int parse(int argc, const char ** argv) {
         bool options_parsing = true;
         for (int i = 1, positional_args_i = 0; i < argc; ++i) {
             if (options_parsing && (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0)) {
-                if (handle_option_with_value(argc, argv, i, context_size_) == 1) {
+                if (handle_option_with_value(argc, argv, i, context_size) == 1) {
                     return 1;
                 }
             } else if (options_parsing && (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "--ngl") == 0)) {
-                if (handle_option_with_value(argc, argv, i, ngl_) == 1) {
+                if (handle_option_with_value(argc, argv, i, ngl) == 1) {
+                    return 1;
+                }
+            } else if (options_parsing && strcmp(argv[i], "--temp") == 0) {
+                if (handle_option_with_value(argc, argv, i, temperature) == 1) {
                     return 1;
                 }
             } else if (options_parsing &&
                        (parse_flag(argv, i, "-v", "--verbose") || parse_flag(argv, i, "-v", "--log-verbose"))) {
-                verbose_ = true;
+                verbose = true;
             } else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
-                help_ = true;
+                help = true;
                 return 0;
             } else if (options_parsing && strcmp(argv[i], "--") == 0) {
                 options_parsing = false;
@@ -120,16 +157,16 @@ class Opt {
                 model_ = argv[i];
             } else if (positional_args_i == 1) {
                 ++positional_args_i;
-                user_ = argv[i];
+                user = argv[i];
             } else {
-                user_ += " " + std::string(argv[i]);
+                user += " " + std::string(argv[i]);
             }
         }
 
         return 0;
     }
 
-    void help() const {
+    void print_help() const {
         printf(
             "Description:\n"
             "  Runs a llm\n"
@@ -142,6 +179,8 @@ class Opt {
142179 " Context size (default: %d)\n "
143180 " -n, --ngl <value>\n "
144181 " Number of GPU layers (default: %d)\n "
182+ " --temp <value>\n "
183+ " Temperature (default: %.1f)\n "
145184 " -v, --verbose, --log-verbose\n "
146185 " Set verbosity level to infinity (i.e. log all messages, useful for debugging)\n "
147186 " -h, --help\n "
@@ -170,7 +209,7 @@ class Opt {
170209 " llama-run file://some-file3.gguf\n "
171210 " llama-run --ngl 999 some-file4.gguf\n "
172211 " llama-run --ngl 999 some-file5.gguf Hello World\n " ,
173- llama_context_default_params (). n_batch , llama_model_default_params (). n_gpu_layers );
212+ context_size_default, ngl_default, temperature_default );
174213 }
175214};
176215
@@ -495,12 +534,12 @@ class LlamaData {
             return 1;
         }
 
-        context = initialize_context(model, opt.context_size_);
+        context = initialize_context(model, opt);
         if (!context) {
             return 1;
         }
 
-        sampler = initialize_sampler();
+        sampler = initialize_sampler(opt);
         return 0;
     }
 
@@ -619,14 +658,12 @@ class LlamaData {
     // Initializes the model and returns a unique pointer to it
     llama_model_ptr initialize_model(Opt & opt) {
         ggml_backend_load_all();
-        llama_model_params model_params = llama_model_default_params();
-        model_params.n_gpu_layers = opt.ngl_ >= 0 ? opt.ngl_ : model_params.n_gpu_layers;
         resolve_model(opt.model_);
         printe(
             "\r%*s"
             "\rLoading model",
             get_terminal_width(), " ");
-        llama_model_ptr model(llama_load_model_from_file(opt.model_.c_str(), model_params));
+        llama_model_ptr model(llama_load_model_from_file(opt.model_.c_str(), opt.model_params));
         if (!model) {
             printe("%s: error: unable to load model from file: %s\n", __func__, opt.model_.c_str());
         }
@@ -636,10 +673,8 @@ class LlamaData {
     }
 
     // Initializes the context with the specified parameters
-    llama_context_ptr initialize_context(const llama_model_ptr & model, const int n_ctx) {
-        llama_context_params ctx_params = llama_context_default_params();
-        ctx_params.n_ctx = ctx_params.n_batch = n_ctx >= 0 ? n_ctx : ctx_params.n_batch;
-        llama_context_ptr context(llama_new_context_with_model(model.get(), ctx_params));
+    llama_context_ptr initialize_context(const llama_model_ptr & model, const Opt & opt) {
+        llama_context_ptr context(llama_new_context_with_model(model.get(), opt.ctx_params));
         if (!context) {
             printe("%s: error: failed to create the llama_context\n", __func__);
         }
@@ -648,10 +683,10 @@ class LlamaData {
     }
 
     // Initializes and configures the sampler
-    llama_sampler_ptr initialize_sampler() {
+    llama_sampler_ptr initialize_sampler(const Opt & opt) {
         llama_sampler_ptr sampler(llama_sampler_chain_init(llama_sampler_chain_default_params()));
         llama_sampler_chain_add(sampler.get(), llama_sampler_init_min_p(0.05f, 1));
-        llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(0.8f));
+        llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(opt.temperature));
         llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
 
         return sampler;
@@ -798,9 +833,9 @@ static int apply_chat_template_with_error_handling(LlamaData & llama_data, const
 }
 
 // Helper function to handle user input
-static int handle_user_input(std::string & user_input, const std::string & user_) {
-    if (!user_.empty()) {
-        user_input = user_;
+static int handle_user_input(std::string & user_input, const std::string & user) {
+    if (!user.empty()) {
+        user_input = user;
         return 0; // No need for interactive input
     }
 
@@ -832,17 +867,17 @@ static bool is_stdout_a_terminal() {
 }
 
 // Function to tokenize the prompt
-static int chat_loop(LlamaData & llama_data, const std::string & user_) {
+static int chat_loop(LlamaData & llama_data, const std::string & user) {
     int prev_len = 0;
     llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
     static const bool stdout_a_terminal = is_stdout_a_terminal();
     while (true) {
         // Get user input
         std::string user_input;
-        while (handle_user_input(user_input, user_)) {
+        while (handle_user_input(user_input, user)) {
         }
 
-        add_message("user", user_.empty() ? user_input : user_, llama_data);
+        add_message("user", user.empty() ? user_input : user, llama_data);
         int new_len;
         if (apply_chat_template_with_error_handling(llama_data, true, new_len) < 0) {
             return 1;
@@ -854,7 +889,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user_) {
             return 1;
         }
 
-        if (!user_.empty()) {
+        if (!user.empty()) {
             break;
         }
 
@@ -869,7 +904,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user_) {
 
 static void log_callback(const enum ggml_log_level level, const char * text, void * p) {
     const Opt * opt = static_cast<Opt *>(p);
-    if (opt->verbose_ || level == GGML_LOG_LEVEL_ERROR) {
+    if (opt->verbose || level == GGML_LOG_LEVEL_ERROR) {
         printe("%s", text);
     }
 }
@@ -890,11 +925,11 @@ int main(int argc, const char ** argv) {
     }
 
     if (!is_stdin_a_terminal()) {
-        if (!opt.user_.empty()) {
-            opt.user_ += "\n\n";
+        if (!opt.user.empty()) {
+            opt.user += "\n\n";
         }
 
-        opt.user_ += read_pipe_data();
+        opt.user += read_pipe_data();
     }
 
     llama_log_set(log_callback, &opt);
@@ -903,7 +938,7 @@ int main(int argc, const char ** argv) {
         return 1;
     }
 
-    if (chat_loop(llama_data, opt.user_)) {
+    if (chat_loop(llama_data, opt.user)) {
         return 1;
     }
 