 
 typedef std::unique_ptr<char []> char_array_ptr;
 
-struct Argument {
-    std::string flag;
-    std::string help_text;
-};
-
-struct Options {
-    std::string model_path, prompt_non_interactive;
-    int ngl = 99;
-    int n_ctx = 2048;
-};
+class Opt {
+  public:
+    int init_opt(int argc, const char ** argv) {
+        construct_help_str_();
+        // Parse arguments
+        if (parse(argc, argv)) {
+            fprintf(stderr, "Error: Failed to parse arguments.\n");
+            help();
+            return 1;
+        }
 
-class ArgumentParser {
-  public:
-    ArgumentParser(const char * program_name) : program_name(program_name) {}
+        // If help is requested, show help and exit
+        if (help_) {
+            help();
+            return 2;
+        }
 
-    void add_argument(const std::string & flag, std::string & var, const std::string & help_text = "") {
-        string_args[flag] = &var;
-        arguments.push_back({flag, help_text});
+        return 0; // Success
     }
 
-    void add_argument(const std::string & flag, int & var, const std::string & help_text = "") {
-        int_args[flag] = &var;
-        arguments.push_back({flag, help_text});
+    const char * model_ = nullptr;
+    std::string prompt_;
+    int context_size_ = 2048, ngl_ = 0;
+
+  private:
+    std::string help_str_;
+    bool help_ = false;
+
+    void construct_help_str_() {
+        help_str_ =
+            "Description:\n"
+            "  Runs a llm\n"
+            "\n"
+            "Usage:\n"
+            "  llama-run [options] MODEL [PROMPT]\n"
+            "\n"
+            "Options:\n"
+            "  -c, --context-size <value>\n"
+            "      Context size (default: " +
+            std::to_string(context_size_);
+        help_str_ +=
+            ")\n"
+            "  -n, --ngl <value>\n"
+            "      Number of GPU layers (default: " +
+            std::to_string(ngl_);
+        help_str_ +=
+            ")\n"
+            "  -h, --help\n"
+            "      Show help message\n"
+            "\n"
+            "Examples:\n"
+            "  llama-run ~/.local/share/ramalama/models/ollama/smollm\\:135m\n"
+            "  llama-run --ngl 99 ~/.local/share/ramalama/models/ollama/smollm\\:135m\n"
+            "  llama-run --ngl 99 ~/.local/share/ramalama/models/ollama/smollm\\:135m Hello World\n";
     }
 
     int parse(int argc, const char ** argv) {
+        if (parse_arguments(argc, argv) || !model_) {
+            return 1;
+        }
+
+        return 0;
+    }
+
+    int parse_arguments(int argc, const char ** argv) {
+        int positional_args_i = 0;
         for (int i = 1; i < argc; ++i) {
-            std::string arg = argv[i];
-            if (string_args.count(arg)) {
-                if (i + 1 < argc) {
-                    *string_args[arg] = argv[++i];
-                } else {
-                    fprintf(stderr, "error: missing value for %s\n", arg.c_str());
-                    print_usage();
+            if (std::strcmp(argv[i], "-c") == 0 || std::strcmp(argv[i], "--context-size") == 0) {
+                if (i + 1 >= argc) {
                     return 1;
                 }
-            } else if (int_args.count(arg)) {
-                if (i + 1 < argc) {
-                    if (parse_int_arg(argv[++i], *int_args[arg]) != 0) {
-                        fprintf(stderr, "error: invalid value for %s: %s\n", arg.c_str(), argv[i]);
-                        print_usage();
-                        return 1;
-                    }
-                } else {
-                    fprintf(stderr, "error: missing value for %s\n", arg.c_str());
-                    print_usage();
+
+                context_size_ = std::atoi(argv[++i]);
+            } else if (std::strcmp(argv[i], "-n") == 0 || std::strcmp(argv[i], "--ngl") == 0) {
+                if (i + 1 >= argc) {
                     return 1;
                 }
+
+                ngl_ = std::atoi(argv[++i]);
+            } else if (std::strcmp(argv[i], "-h") == 0 || std::strcmp(argv[i], "--help") == 0) {
+                help_ = true;
+            } else if (!positional_args_i) {
+                ++positional_args_i;
+                model_ = argv[i];
+            } else if (positional_args_i == 1) {
+                prompt_ = argv[i];
             } else {
-                fprintf(stderr, "error: unrecognized argument %s\n", arg.c_str());
-                print_usage();
-                return 1;
+                prompt_ += " " + std::string(argv[i]);
             }
         }
 
-        if (string_args["-m"]->empty()) {
-            fprintf(stderr, "error: -m is required\n");
-            print_usage();
-            return 1;
-        }
-
         return 0;
     }
 
-  private:
-    const char * program_name;
-    std::unordered_map<std::string, std::string *> string_args;
-    std::unordered_map<std::string, int *> int_args;
-    std::vector<Argument> arguments;
-
-    int parse_int_arg(const char * arg, int & value) {
-        char * end;
-        const long val = std::strtol(arg, &end, 10);
-        if (*end == '\0' && val >= INT_MIN && val <= INT_MAX) {
-            value = static_cast<int>(val);
-            return 0;
-        }
-        return 1;
-    }
-
-    void print_usage() const {
-        printf("\nUsage:\n");
-        printf("  %s [OPTIONS]\n\n", program_name);
-        printf("Options:\n");
-        for (const auto & arg : arguments) {
-            printf("  %-10s %s\n", arg.flag.c_str(), arg.help_text.c_str());
-        }
-
-        printf("\n");
-    }
+    void help() const { printf("%s", help_str_.c_str()); }
 };
 
 class LlamaData {
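Note that the new parser reads numeric option values with std::atoi, which returns 0 on malformed input (so `llama-run -c foo MODEL` silently yields a context size of 0), whereas the deleted parse_int_arg() rejected trailing characters and out-of-range values via std::strtol. A minimal sketch of that stricter check, reconstructed from the removed code, in case it is worth keeping:

    #include <climits>
    #include <cstdlib>

    // Strict integer parsing as the removed parse_int_arg() did it:
    // reject trailing characters and values outside int range.
    static int parse_int_strict(const char * arg, int & value) {
        char * end;
        const long val = std::strtol(arg, &end, 10);
        if (*end == '\0' && val >= INT_MIN && val <= INT_MAX) {
            value = static_cast<int>(val);
            return 0; // parsed the whole string, within int range
        }
        return 1; // malformed or out of range
    }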
@@ -116,13 +121,13 @@ class LlamaData {
     llama_context_ptr context;
     std::vector<llama_chat_message> messages;
 
-    int init(const Options & opt) {
-        model = initialize_model(opt.model_path, opt.ngl);
+    int init(const Opt & opt) {
+        model = initialize_model(opt.model_, opt.ngl_);
         if (!model) {
             return 1;
         }
 
-        context = initialize_context(model, opt.n_ctx);
+        context = initialize_context(model, opt.context_size_);
         if (!context) {
             return 1;
         }
@@ -273,19 +278,6 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
     return 0;
 }
 
-static int parse_arguments(const int argc, const char ** argv, Options & opt) {
-    ArgumentParser parser(argv[0]);
-    parser.add_argument("-m", opt.model_path, "model");
-    parser.add_argument("-p", opt.prompt_non_interactive, "prompt");
-    parser.add_argument("-c", opt.n_ctx, "context_size");
-    parser.add_argument("-ngl", opt.ngl, "n_gpu_layers");
-    if (parser.parse(argc, argv)) {
-        return 1;
-    }
-
-    return 0;
-}
-
 static int read_user_input(std::string & user) {
     std::getline(std::cin, user);
     return user.empty(); // Indicate an error or empty input
@@ -382,17 +374,20 @@ static std::string read_pipe_data() {
 }
 
 int main(int argc, const char ** argv) {
-    Options opt;
-    if (parse_arguments(argc, argv, opt)) {
+    Opt opt;
+    const int opt_ret = opt.init_opt(argc, argv);
+    if (opt_ret == 2) {
+        return 0;
+    } else if (opt_ret) {
         return 1;
     }
 
     if (!is_stdin_a_terminal()) {
-        if (!opt.prompt_non_interactive.empty()) {
-            opt.prompt_non_interactive += "\n\n";
+        if (!opt.prompt_.empty()) {
+            opt.prompt_ += "\n\n";
         }
 
-        opt.prompt_non_interactive += read_pipe_data();
+        opt.prompt_ += read_pipe_data();
     }
 
     llama_log_set(log_callback, nullptr);
@@ -401,7 +396,7 @@ int main(int argc, const char ** argv) {
         return 1;
     }
 
-    if (chat_loop(llama_data, opt.prompt_non_interactive)) {
+    if (chat_loop(llama_data, opt.prompt_)) {
         return 1;
     }
 
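The positional handling in Opt::parse_arguments() takes the first non-flag argument as the model path and joins all remaining ones into the prompt with single spaces, so quoting a multi-word prompt is optional. A quick illustrative check (hypothetical test code, not part of the patch, assuming Opt is visible to the test translation unit):

    #include <cassert>
    #include <string>

    int main() {
        const char * argv[] = { "llama-run", "model.gguf", "Hello", "there", "world" };
        Opt opt;

        // 0 means: parsed successfully and help was not requested
        assert(opt.init_opt(5, argv) == 0);
        assert(std::string(opt.model_) == "model.gguf");
        assert(opt.prompt_ == "Hello there world"); // joined with single spaces
        return 0;
    }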