11#if  defined(_WIN32)
22#    include  < windows.h> 
33#else 
4+ #    include  < sys/ioctl.h> 
45#    include  < unistd.h> 
56#endif 
67
2930class  Opt  {
3031  public: 
3132    int  init (int  argc, const  char  ** argv) {
32-         construct_help_str_ ();
3333        //  Parse arguments
3434        if  (parse (argc, argv)) {
3535            printe (" Error: Failed to parse arguments.\n "  );
@@ -48,14 +48,53 @@ class Opt {
4848
4949    std::string model_;
5050    std::string user_;
51-     int          context_size_ = 2048 , ngl_ = -1 ;
51+     int          context_size_ = -1 , ngl_ = -1 ;
52+     bool         verbose_ = false ;
5253
5354  private: 
54-     std::string help_str_;
5555    bool         help_ = false ;
5656
57-     void  construct_help_str_ () {
58-         help_str_ =
57+     int  parse (int  argc, const  char  ** argv) {
58+         int  positional_args_i = 0 ;
59+         for  (int  i = 1 ; i < argc; ++i) {
60+             if  (strcmp (argv[i], " -c"  ) == 0  || strcmp (argv[i], " --context-size"  ) == 0 ) {
61+                 if  (i + 1  >= argc) {
62+                     return  1 ;
63+                 }
64+ 
65+                 context_size_ = std::atoi (argv[++i]);
66+             } else  if  (strcmp (argv[i], " -n"  ) == 0  || strcmp (argv[i], " --ngl"  ) == 0 ) {
67+                 if  (i + 1  >= argc) {
68+                     return  1 ;
69+                 }
70+ 
71+                 ngl_ = std::atoi (argv[++i]);
72+             } else  if  (strcmp (argv[i], " -v"  ) == 0  || strcmp (argv[i], " --verbose"  ) == 0  ||
73+                        strcmp (argv[i], " --log-verbose"  ) == 0 ) {
74+                 verbose_ = true ;
75+             } else  if  (strcmp (argv[i], " -h"  ) == 0  || strcmp (argv[i], " --help"  ) == 0 ) {
76+                 help_ = true ;
77+                 return  0 ;
78+             } else  if  (!positional_args_i) {
79+                 if  (!argv[i][0 ] || argv[i][0 ] == ' -'  ) {
80+                     return  1 ;
81+                 }
82+ 
83+                 ++positional_args_i;
84+                 model_ = argv[i];
85+             } else  if  (positional_args_i == 1 ) {
86+                 ++positional_args_i;
87+                 user_ = argv[i];
88+             } else  {
89+                 user_ += "  "   + std::string (argv[i]);
90+             }
91+         }
92+ 
93+         return  model_.empty ();  //  model_ is the only required value
94+     }
95+ 
96+     void  help () const  {
97+         printf (
5998            " Description:\n " 
6099            "   Runs a llm\n " 
61100            " \n " 
@@ -64,15 +103,11 @@ class Opt {
64103            " \n " 
65104            " Options:\n " 
66105            "   -c, --context-size <value>\n " 
67-             "       Context size (default: "   +
68-             std::to_string (context_size_);
69-         help_str_ +=
70-             " )\n " 
106+             "       Context size (default: %d)\n " 
71107            "   -n, --ngl <value>\n " 
72-             "       Number of GPU layers (default: "   +
73-             std::to_string (ngl_);
74-         help_str_ +=
75-             " )\n " 
108+             "       Number of GPU layers (default: %d)\n " 
109+             "   -v, --verbose, --log-verbose\n " 
110+             "       Set verbosity level to infinity (i.e. log all messages, useful for debugging)\n " 
76111            "   -h, --help\n " 
77112            "       Show help message\n " 
78113            " \n " 
@@ -96,43 +131,10 @@ class Opt {
96131            "   llama-run https://example.com/some-file1.gguf\n " 
97132            "   llama-run some-file2.gguf\n " 
98133            "   llama-run file://some-file3.gguf\n " 
99-             "   llama-run --ngl 99 some-file4.gguf\n " 
100-             "   llama-run --ngl 99 some-file5.gguf Hello World\n "  ;
101-     }
102- 
103-     int  parse (int  argc, const  char  ** argv) {
104-         int  positional_args_i = 0 ;
105-         for  (int  i = 1 ; i < argc; ++i) {
106-             if  (strcmp (argv[i], " -c"  ) == 0  || strcmp (argv[i], " --context-size"  ) == 0 ) {
107-                 if  (i + 1  >= argc) {
108-                     return  1 ;
109-                 }
110- 
111-                 context_size_ = std::atoi (argv[++i]);
112-             } else  if  (strcmp (argv[i], " -n"  ) == 0  || strcmp (argv[i], " --ngl"  ) == 0 ) {
113-                 if  (i + 1  >= argc) {
114-                     return  1 ;
115-                 }
116- 
117-                 ngl_ = std::atoi (argv[++i]);
118-             } else  if  (strcmp (argv[i], " -h"  ) == 0  || strcmp (argv[i], " --help"  ) == 0 ) {
119-                 help_ = true ;
120-                 return  0 ;
121-             } else  if  (!positional_args_i) {
122-                 ++positional_args_i;
123-                 model_ = argv[i];
124-             } else  if  (positional_args_i == 1 ) {
125-                 ++positional_args_i;
126-                 user_ = argv[i];
127-             } else  {
128-                 user_ += "  "   + std::string (argv[i]);
129-             }
130-         }
131- 
132-         return  model_.empty ();  //  model_ is the only required value
134+             "   llama-run --ngl 999 some-file4.gguf\n " 
135+             "   llama-run --ngl 999 some-file5.gguf Hello World\n "  ,
136+             llama_context_default_params ().n_batch , llama_model_default_params ().n_gpu_layers );
133137    }
134- 
135-     void  help () const  { printf (" %s"  , help_str_.c_str ()); }
136138};
137139
138140struct  progress_data  {
@@ -151,8 +153,20 @@ struct FileDeleter {
151153
152154typedef  std::unique_ptr<FILE, FileDeleter> FILE_ptr;
153155
// Best-effort query of the terminal width in columns.
// Falls back to 80 when the query fails (e.g. output is redirected or not a
// tty) instead of reading an uninitialized struct / returning 0.
static int get_terminal_width() {
#if defined(_WIN32)
    CONSOLE_SCREEN_BUFFER_INFO csbi;
    // GetConsoleScreenBufferInfo fails when stdout is not a console.
    if (!GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &csbi)) {
        return 80;
    }

    return csbi.srWindow.Right - csbi.srWindow.Left + 1;
#else
    struct winsize w;
    // ioctl leaves `w` untouched on failure, so it must not be read then;
    // some environments also report a valid struct with ws_col == 0.
    if (ioctl(STDOUT_FILENO, TIOCGWINSZ, &w) != 0 || w.ws_col == 0) {
        return 80;
    }

    return w.ws_col;
#endif
}
167+ 
154168#ifdef  LLAMA_USE_CURL
155- class  CurlWrapper  {
169+ class  HttpClient  {
156170  public: 
157171    int  init (const  std::string & url, const  std::vector<std::string> & headers, const  std::string & output_file,
158172             const  bool  progress, std::string * response_str = nullptr ) {
@@ -181,7 +195,7 @@ class CurlWrapper {
181195        return  0 ;
182196    }
183197
184-     ~CurlWrapper  () {
198+     ~HttpClient  () {
185199        if  (chunk) {
186200            curl_slist_free_all (chunk);
187201        }
@@ -219,7 +233,7 @@ class CurlWrapper {
219233        if  (progress) {
220234            curl_easy_setopt (curl, CURLOPT_NOPROGRESS, 0L );
221235            curl_easy_setopt (curl, CURLOPT_XFERINFODATA, &data);
222-             curl_easy_setopt (curl, CURLOPT_XFERINFOFUNCTION, progress_callback );
236+             curl_easy_setopt (curl, CURLOPT_XFERINFOFUNCTION, update_progress );
223237        }
224238    }
225239
@@ -270,9 +284,9 @@ class CurlWrapper {
270284
271285    static  std::string human_readable_size (curl_off_t  size) {
272286        static  const  char  * suffix[] = { " B"  , " KB"  , " MB"  , " GB"  , " TB"   };
273-         char          length   = sizeof (suffix) / sizeof (suffix[0 ]);
274-         int           i        = 0 ;
275-         double        dbl_size = size;
287+         char                  length   = sizeof (suffix) / sizeof (suffix[0 ]);
288+         int                   i        = 0 ;
289+         double                dbl_size = size;
276290        if  (size > 1024 ) {
277291            for  (i = 0 ; (size / 1024 ) > 0  && i < length - 1 ; i++, size /= 1024 ) {
278292                dbl_size = size / 1024.0 ;
@@ -284,36 +298,89 @@ class CurlWrapper {
284298        return  out.str ();
285299    }
286300
287-     static  int  progress_callback (void  * ptr, curl_off_t  total_to_download, curl_off_t  now_downloaded, curl_off_t ,
288-                                   curl_off_t ) {
301+     static  int  update_progress (void  * ptr, curl_off_t  total_to_download, curl_off_t  now_downloaded, curl_off_t ,
302+                                curl_off_t ) {
289303        progress_data * data = static_cast <progress_data *>(ptr);
290304        if  (total_to_download <= 0 ) {
291305            return  0 ;
292306        }
293307
294308        total_to_download += data->file_size ;
295309        const  curl_off_t  now_downloaded_plus_file_size = now_downloaded + data->file_size ;
296-         const  curl_off_t  percentage                    = (now_downloaded_plus_file_size * 100 ) / total_to_download;
297-         const  curl_off_t  pos                           = (percentage / 5 );
310+         const  curl_off_t  percentage      = calculate_percentage (now_downloaded_plus_file_size, total_to_download);
311+         std::string      progress_prefix = generate_progress_prefix (percentage);
312+ 
313+         const  double  speed = calculate_speed (now_downloaded, data->start_time );
314+         const  double  tim  = (total_to_download - now_downloaded) / speed;
315+         std::string  progress_suffix =
316+             generate_progress_suffix (now_downloaded_plus_file_size, total_to_download, speed, tim);
317+ 
318+         int          progress_bar_width = calculate_progress_bar_width (progress_prefix, progress_suffix);
298319        std::string progress_bar;
299-         for  (int  i = 0 ; i < 20 ; ++i) {
300-             progress_bar.append ((i < pos) ? " █"   : "  "  );
301-         }
320+         generate_progress_bar (progress_bar_width, percentage, progress_bar);
302321
303-         //  Calculate download speed and estimated time to completion
304-         const  auto                           now             = std::chrono::steady_clock::now ();
305-         const  std::chrono::duration<double > elapsed_seconds = now - data->start_time ;
306-         const  double                         speed           = now_downloaded / elapsed_seconds.count ();
307-         const  double                         estimated_time  = (total_to_download - now_downloaded) / speed;
308-         printe (" \r %ld%% |%s| %s/%s  %.2f MB/s  %s      "  , percentage, progress_bar.c_str (),
309-                human_readable_size (now_downloaded).c_str (), human_readable_size (total_to_download).c_str (),
310-                speed / (1024  * 1024 ), human_readable_time (estimated_time).c_str ());
311-         fflush (stderr);
322+         print_progress (progress_prefix, progress_bar, progress_suffix);
312323        data->printed  = true ;
313324
314325        return  0 ;
315326    }
316327
328+     static  curl_off_t  calculate_percentage (curl_off_t  now_downloaded_plus_file_size, curl_off_t  total_to_download) {
329+         return  (now_downloaded_plus_file_size * 100 ) / total_to_download;
330+     }
331+ 
332+     static  std::string generate_progress_prefix (curl_off_t  percentage) {
333+         std::ostringstream progress_output;
334+         progress_output << std::setw (3 ) << percentage << " % |"  ;
335+         return  progress_output.str ();
336+     }
337+ 
338+     static  double  calculate_speed (curl_off_t  now_downloaded, const  std::chrono::steady_clock::time_point & start_time) {
339+         const  auto                           now             = std::chrono::steady_clock::now ();
340+         const  std::chrono::duration<double > elapsed_seconds = now - start_time;
341+         return  now_downloaded / elapsed_seconds.count ();
342+     }
343+ 
344+     static  std::string generate_progress_suffix (curl_off_t  now_downloaded_plus_file_size, curl_off_t  total_to_download,
345+                                                 double  speed, double  estimated_time) {
346+         const  int           width = 10 ;
347+         std::ostringstream progress_output;
348+         progress_output << std::setw (width) << human_readable_size (now_downloaded_plus_file_size) << " /" 
349+                         << std::setw (width) << human_readable_size (total_to_download) << std::setw (width)
350+                         << human_readable_size (speed) << " /s"   << std::setw (width)
351+                         << human_readable_time (estimated_time);
352+         return  progress_output.str ();
353+     }
354+ 
355+     static  int  calculate_progress_bar_width (const  std::string & progress_prefix, const  std::string & progress_suffix) {
356+         int  progress_bar_width = get_terminal_width () - progress_prefix.size () - progress_suffix.size () - 5 ;
357+         if  (progress_bar_width < 1 ) {
358+             progress_bar_width = 1 ;
359+         }
360+ 
361+         return  progress_bar_width;
362+     }
363+ 
364+     static  std::string generate_progress_bar (int  progress_bar_width, curl_off_t  percentage,
365+                                              std::string & progress_bar) {
366+         const  curl_off_t  pos = (percentage * progress_bar_width) / 100 ;
367+         for  (int  i = 0 ; i < progress_bar_width; ++i) {
368+             progress_bar.append ((i < pos) ? " █"   : "  "  );
369+         }
370+ 
371+         return  progress_bar;
372+     }
373+ 
374+     static  void  print_progress (const  std::string & progress_prefix, const  std::string & progress_bar,
375+                                const  std::string & progress_suffix) {
376+         std::ostringstream progress_output;
377+         progress_output << progress_prefix << progress_bar << " | "   << progress_suffix;
378+         printe (
379+             " \r %*s" 
380+             " \r %s"  ,
381+             get_terminal_width (), "  "  , progress_output.str ().c_str ());
382+     }
383+ 
317384    //  Function to write data to a file
318385    static  size_t  write_data (void  * ptr, size_t  size, size_t  nmemb, void  * stream) {
319386        FILE * out = static_cast <FILE *>(stream);
@@ -357,8 +424,8 @@ class LlamaData {
357424#ifdef  LLAMA_USE_CURL
358425    int  download (const  std::string & url, const  std::vector<std::string> & headers, const  std::string & output_file,
359426                 const  bool  progress, std::string * response_str = nullptr ) {
360-         CurlWrapper curl ;
361-         if  (curl .init (url, headers, output_file, progress, response_str)) {
427+         HttpClient http ;
428+         if  (http .init (url, headers, output_file, progress, response_str)) {
362429            return  1 ;
363430        }
364431
@@ -467,6 +534,10 @@ class LlamaData {
467534        llama_model_params model_params = llama_model_default_params ();
468535        model_params.n_gpu_layers        = opt.ngl_  >= 0  ? opt.ngl_  : model_params.n_gpu_layers ;
469536        resolve_model (opt.model_ );
537+         printe (
538+             " \r %*s" 
539+             " \r Loading model"  ,
540+             get_terminal_width (), "  "  );
470541        llama_model_ptr model (llama_load_model_from_file (opt.model_ .c_str (), model_params));
471542        if  (!model) {
472543            printe (" %s: error: unable to load model from file: %s\n "  , __func__, opt.model_ .c_str ());
@@ -478,8 +549,7 @@ class LlamaData {
478549    //  Initializes the context with the specified parameters
479550    llama_context_ptr initialize_context (const  llama_model_ptr & model, const  int  n_ctx) {
480551        llama_context_params ctx_params = llama_context_default_params ();
481-         ctx_params.n_ctx                 = n_ctx;
482-         ctx_params.n_batch               = n_ctx;
552+         ctx_params.n_ctx  = ctx_params.n_batch  = n_ctx >= 0  ? n_ctx : ctx_params.n_batch ;
483553        llama_context_ptr context (llama_new_context_with_model (model.get (), ctx_params));
484554        if  (!context) {
485555            printe (" %s: error: failed to create the llama_context\n "  , __func__);
@@ -642,8 +712,9 @@ static int handle_user_input(std::string & user_input, const std::string & user_
642712    }
643713
644714    printf (
645-         " \r                                                                        " 
646-         " \r\033 [32m> \033 [0m"  );
715+         " \r %*s" 
716+         " \r\033 [32m> \033 [0m"  ,
717+         get_terminal_width (), "  "  );
647718    return  read_user_input (user_input);  //  Returns true if input ends the loop
648719}
649720
@@ -682,8 +753,9 @@ static int chat_loop(LlamaData & llama_data, const std::string & user_) {
682753    return  0 ;
683754}
684755
685- static  void  log_callback (const  enum  ggml_log_level level, const  char  * text, void  *) {
686-     if  (level == GGML_LOG_LEVEL_ERROR) {
756+ static  void  log_callback (const  enum  ggml_log_level level, const  char  * text, void  * p) {
757+     const  Opt * opt = static_cast <Opt *>(p);
758+     if  (opt->verbose_  || level == GGML_LOG_LEVEL_ERROR) {
687759        printe (" %s"  , text);
688760    }
689761}
@@ -721,7 +793,7 @@ int main(int argc, const char ** argv) {
721793        opt.user_  += read_pipe_data ();
722794    }
723795
724-     llama_log_set (log_callback, nullptr );
796+     llama_log_set (log_callback, &opt );
725797    LlamaData llama_data;
726798    if  (llama_data.init (opt)) {
727799        return  1 ;
0 commit comments