44#include  < unistd.h> 
55#endif 
66
#include <climits>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
1513
1614#include  " llama-cpp.h" 
1715
1816typedef  std::unique_ptr<char []> char_array_ptr;
1917
class Opt {
  public:
    // Parses the command line and prepares the help text.
    // Returns 0 on success, 1 on a parse error (an error message and the help
    // screen are printed), and 2 when the user explicitly asked for help.
    int init_opt(int argc, const char ** argv) {
        construct_help_str_();
        // Parse arguments
        if (parse(argc, argv)) {
            fprintf(stderr, "Error: Failed to parse arguments.\n");
            help();
            return 1;
        }

        // If help is requested, show help and exit
        if (help_) {
            help();
            return 2;
        }

        return 0;  // Success
    }

    // Parsed options, read directly by the rest of the program.
    const char * model_        = nullptr;  // positional: path to the model (required)
    std::string  prompt_;                  // positional: initial prompt (optional)
    int          context_size_ = 2048;     // -c / --context-size
    int          ngl_          = 0;        // -n / --ngl (number of GPU layers)

  private:
    std::string help_str_;
    bool        help_ = false;

    // Builds the usage/help text at runtime so the printed defaults always
    // match the actual initial values of context_size_ and ngl_.
    void construct_help_str_() {
        help_str_ =
            "Description:\n"
            "  Runs a llm\n"
            "\n"
            "Usage:\n"
            "  llama-run [options] MODEL [PROMPT]\n"
            "\n"
            "Options:\n"
            "  -c, --context-size <value>\n"
            "      Context size (default: " +
            std::to_string(context_size_);
        help_str_ +=
            ")\n"
            "  -n, --ngl <value>\n"
            "      Number of GPU layers (default: " +
            std::to_string(ngl_);
        help_str_ +=
            ")\n"
            "  -h, --help\n"
            "      Show help message\n"
            "\n"
            "Examples:\n"
            "  llama-run your_model.gguf\n"
            "  llama-run --ngl 99 your_model.gguf\n"
            // Fixed: the third example duplicated the second; show a prompt instead.
            "  llama-run --ngl 99 your_model.gguf hello\n";
    }

    // Strictly parses a decimal integer option value. Rejects empty strings,
    // trailing garbage and out-of-range values — std::atoi would silently
    // turn all of those into 0. Returns 0 on success, 1 on failure.
    static int parse_int_arg(const char * arg, int & value) {
        char *     end = nullptr;
        const long val = std::strtol(arg, &end, 10);
        if (end == arg || *end != '\0' || val < INT_MIN || val > INT_MAX) {
            return 1;
        }

        value = static_cast<int>(val);
        return 0;
    }

    int parse(int argc, const char ** argv) {
        if (parse_arguments(argc, argv)) {
            return 1;
        }

        // A model path is mandatory unless the user only asked for help.
        if (!help_ && !model_) {
            return 1;
        }

        return 0;
    }

    int parse_arguments(int argc, const char ** argv) {
        int positional_args_i = 0;
        for (int i = 1; i < argc; ++i) {
            if (std::strcmp(argv[i], "-c") == 0 || std::strcmp(argv[i], "--context-size") == 0) {
                if (i + 1 >= argc || parse_int_arg(argv[++i], context_size_)) {
                    return 1;  // missing or invalid value
                }
            } else if (std::strcmp(argv[i], "-n") == 0 || std::strcmp(argv[i], "--ngl") == 0) {
                if (i + 1 >= argc || parse_int_arg(argv[++i], ngl_)) {
                    return 1;  // missing or invalid value
                }
            } else if (std::strcmp(argv[i], "-h") == 0 || std::strcmp(argv[i], "--help") == 0) {
                help_ = true;
                return 0;  // stop parsing; caller prints help and exits
            } else if (!positional_args_i) {
                ++positional_args_i;
                model_ = argv[i];  // first positional argument: model path
            } else if (positional_args_i == 1) {
                ++positional_args_i;
                prompt_ = argv[i];  // second positional argument starts the prompt
            } else {
                prompt_ += " " + std::string(argv[i]);  // further words extend the prompt
            }
        }

        return 0;
    }

    void help() const { printf("%s", help_str_.c_str()); }
};
111117
112118class  LlamaData  {
@@ -116,13 +122,13 @@ class LlamaData {
116122    llama_context_ptr context;
117123    std::vector<llama_chat_message> messages;
118124
119-     int  init (const  Options  & opt) {
120-         model = initialize_model (opt.model_path , opt.ngl );
125+     int  init (const  Opt  & opt) {
126+         model = initialize_model (opt.model_ , opt.ngl_ );
121127        if  (!model) {
122128            return  1 ;
123129        }
124130
125-         context = initialize_context (model, opt.n_ctx );
131+         context = initialize_context (model, opt.context_size_ );
126132        if  (!context) {
127133            return  1 ;
128134        }
@@ -134,6 +140,7 @@ class LlamaData {
134140   private: 
135141    //  Initializes the model and returns a unique pointer to it
136142    llama_model_ptr initialize_model (const  std::string & model_path, const  int  ngl) {
143+         ggml_backend_load_all ();
137144        llama_model_params model_params = llama_model_default_params ();
138145        model_params.n_gpu_layers  = ngl;
139146
@@ -273,19 +280,6 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
273280    return  0 ;
274281}
275282
276- static  int  parse_arguments (const  int  argc, const  char  ** argv, Options & opt) {
277-     ArgumentParser parser (argv[0 ]);
278-     parser.add_argument (" -m"  , opt.model_path , " model"  );
279-     parser.add_argument (" -p"  , opt.prompt_non_interactive , " prompt"  );
280-     parser.add_argument (" -c"  , opt.n_ctx , " context_size"  );
281-     parser.add_argument (" -ngl"  , opt.ngl , " n_gpu_layers"  );
282-     if  (parser.parse (argc, argv)) {
283-         return  1 ;
284-     }
285- 
286-     return  0 ;
287- }
288- 
289283static  int  read_user_input (std::string & user) {
290284    std::getline (std::cin, user);
291285    return  user.empty ();  //  Indicate an error or empty input
@@ -382,17 +376,20 @@ static std::string read_pipe_data() {
382376}
383377
384378int  main (int  argc, const  char  ** argv) {
385-     Options opt;
386-     if  (parse_arguments (argc, argv, opt)) {
379+     Opt       opt;
380+     const  int  opt_ret = opt.init_opt (argc, argv);
381+     if  (opt_ret == 2 ) {
382+         return  0 ;
383+     } else  if  (opt_ret) {
387384        return  1 ;
388385    }
389386
390387    if  (!is_stdin_a_terminal ()) {
391-         if  (!opt.prompt_non_interactive .empty ()) {
392-             opt.prompt_non_interactive  += " \n\n "  ;
388+         if  (!opt.prompt_ .empty ()) {
389+             opt.prompt_  += " \n\n "  ;
393390        }
394391
395-         opt.prompt_non_interactive  += read_pipe_data ();
392+         opt.prompt_  += read_pipe_data ();
396393    }
397394
398395    llama_log_set (log_callback, nullptr );
@@ -401,7 +398,7 @@ int main(int argc, const char ** argv) {
401398        return  1 ;
402399    }
403400
404-     if  (chat_loop (llama_data, opt.prompt_non_interactive )) {
401+     if  (chat_loop (llama_data, opt.prompt_ )) {
405402        return  1 ;
406403    }
407404
0 commit comments