Commit b5fcf49

Improve progress bar
Set the default progress-bar width to whatever the terminal reports. Also fixed a small bug around the default n_gpu_layers value.

Signed-off-by: Eric Curtin <[email protected]>
1 parent 56eea07 commit b5fcf49

2 files changed: +154 -82 lines

README.md

Lines changed: 2 additions & 2 deletions

@@ -409,7 +409,7 @@ To learn more about model quantization, [read this documentation](examples/quant
 
 </details>
 
-[^1]: [examples/perplexity/README.md](examples/perplexity/README.md)
+[^1]: [examples/perplexity/README.md](https://github.com/ggerganov/llama.cpp/blob/master/examples/perplexity/README.md)
 [^2]: [https://huggingface.co/docs/transformers/perplexity](https://huggingface.co/docs/transformers/perplexity)
 
 ## [`llama-bench`](example/bench)
@@ -446,7 +446,7 @@ To learn more about model quantization, [read this documentation](examples/quant
 
 </details>
 
-[^3]: [https://github.com/containers/ramalama](RamaLama)
+[^3]: [RamaLama](https://github.com/containers/ramalama)
 
 ## [`llama-simple`](examples/simple)
 

examples/run/run.cpp

Lines changed: 152 additions & 80 deletions

@@ -1,6 +1,7 @@
 #if defined(_WIN32)
 #    include <windows.h>
 #else
+#    include <sys/ioctl.h>
 #    include <unistd.h>
 #endif
 

@@ -29,7 +30,6 @@
 class Opt {
   public:
     int init(int argc, const char ** argv) {
-        construct_help_str_();
         // Parse arguments
         if (parse(argc, argv)) {
             printe("Error: Failed to parse arguments.\n");
@@ -48,14 +48,53 @@ class Opt {
 
     std::string model_;
     std::string user_;
-    int context_size_ = 2048, ngl_ = -1;
+    int  context_size_ = -1, ngl_ = -1;
+    bool verbose_      = false;
 
   private:
-    std::string help_str_;
     bool help_ = false;
 
-    void construct_help_str_() {
-        help_str_ =
+    int parse(int argc, const char ** argv) {
+        int positional_args_i = 0;
+        for (int i = 1; i < argc; ++i) {
+            if (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0) {
+                if (i + 1 >= argc) {
+                    return 1;
+                }
+
+                context_size_ = std::atoi(argv[++i]);
+            } else if (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "--ngl") == 0) {
+                if (i + 1 >= argc) {
+                    return 1;
+                }
+
+                ngl_ = std::atoi(argv[++i]);
+            } else if (strcmp(argv[i], "-v") == 0 || strcmp(argv[i], "--verbose") == 0 ||
+                       strcmp(argv[i], "--log-verbose") == 0) {
+                verbose_ = true;
+            } else if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) {
+                help_ = true;
+                return 0;
+            } else if (!positional_args_i) {
+                if (!argv[i][0] || argv[i][0] == '-') {
+                    return 1;
+                }
+
+                ++positional_args_i;
+                model_ = argv[i];
+            } else if (positional_args_i == 1) {
+                ++positional_args_i;
+                user_ = argv[i];
+            } else {
+                user_ += " " + std::string(argv[i]);
+            }
+        }
+
+        return model_.empty(); // model_ is the only required value
+    }
+
+    void help() const {
+        printf(
             "Description:\n"
             "  Runs a llm\n"
             "\n"
@@ -64,15 +103,11 @@ class Opt {
             "\n"
             "Options:\n"
             "  -c, --context-size <value>\n"
-            "      Context size (default: " +
-            std::to_string(context_size_);
-        help_str_ +=
-            ")\n"
+            "      Context size (default: %d)\n"
             "  -n, --ngl <value>\n"
-            "      Number of GPU layers (default: " +
-            std::to_string(ngl_);
-        help_str_ +=
-            ")\n"
+            "      Number of GPU layers (default: %d)\n"
+            "  -v, --verbose, --log-verbose\n"
+            "      Set verbosity level to infinity (i.e. log all messages, useful for debugging)\n"
             "  -h, --help\n"
            "      Show help message\n"
            "\n"
@@ -96,43 +131,10 @@ class Opt {
             "  llama-run https://example.com/some-file1.gguf\n"
             "  llama-run some-file2.gguf\n"
             "  llama-run file://some-file3.gguf\n"
-            "  llama-run --ngl 99 some-file4.gguf\n"
-            "  llama-run --ngl 99 some-file5.gguf Hello World\n";
-    }
-
-    int parse(int argc, const char ** argv) {
-        int positional_args_i = 0;
-        for (int i = 1; i < argc; ++i) {
-            if (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0) {
-                if (i + 1 >= argc) {
-                    return 1;
-                }
-
-                context_size_ = std::atoi(argv[++i]);
-            } else if (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "--ngl") == 0) {
-                if (i + 1 >= argc) {
-                    return 1;
-                }
-
-                ngl_ = std::atoi(argv[++i]);
-            } else if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) {
-                help_ = true;
-                return 0;
-            } else if (!positional_args_i) {
-                ++positional_args_i;
-                model_ = argv[i];
-            } else if (positional_args_i == 1) {
-                ++positional_args_i;
-                user_ = argv[i];
-            } else {
-                user_ += " " + std::string(argv[i]);
-            }
-        }
-
-        return model_.empty(); // model_ is the only required value
+            "  llama-run --ngl 999 some-file4.gguf\n"
+            "  llama-run --ngl 999 some-file5.gguf Hello World\n",
+            llama_context_default_params().n_batch, llama_model_default_params().n_gpu_layers);
     }
-
-    void help() const { printf("%s", help_str_.c_str()); }
 };
 
 struct progress_data {
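
Beyond the relocation, the rewritten parse() (in the +48,53 hunk above) gains validation the old version lacked: the first positional argument must be non-empty and must not begin with '-'. A minimal standalone sketch of just that check; the test harness is illustrative, not part of the commit:

    // Trimmed-down copy of the first-positional validation added in this commit.
    #include <cassert>

    static bool valid_model_arg(const char * arg) {
        return arg[0] != '\0' && arg[0] != '-';  // reject empty and option-like values
    }

    int main() {
        assert(valid_model_arg("some-file2.gguf"));  // accepted as model_
        assert(!valid_model_arg("--bogus-flag"));    // rejected: looks like a flag
        assert(!valid_model_arg(""));                // rejected: empty
        return 0;
    }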
@@ -151,8 +153,20 @@ struct FileDeleter {
 
 typedef std::unique_ptr<FILE, FileDeleter> FILE_ptr;
 
+static int get_terminal_width() {
+#if defined(_WIN32)
+    CONSOLE_SCREEN_BUFFER_INFO csbi;
+    GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &csbi);
+    return csbi.srWindow.Right - csbi.srWindow.Left + 1;
+#else
+    struct winsize w;
+    ioctl(STDOUT_FILENO, TIOCGWINSZ, &w);
+    return w.ws_col;
+#endif
+}
+
 #ifdef LLAMA_USE_CURL
-class CurlWrapper {
+class HttpClient {
   public:
     int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
              const bool progress, std::string * response_str = nullptr) {
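
One caveat: get_terminal_width() as added here does not check the ioctl result, so the value is unspecified when stdout is not a terminal (for example, when output is piped). A defensive POSIX variant is sketched below; the 80-column fallback is an assumption of this sketch, not something the commit does:

    #include <sys/ioctl.h>
    #include <unistd.h>

    static int get_terminal_width_or_default() {
        struct winsize w;
        // TIOCGWINSZ fails (-1) when stdout is not a TTY; ws_col can also be 0.
        if (ioctl(STDOUT_FILENO, TIOCGWINSZ, &w) == -1 || w.ws_col == 0) {
            return 80;  // conventional fallback width
        }
        return w.ws_col;
    }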
@@ -181,7 +195,7 @@ class CurlWrapper {
         return 0;
     }
 
-    ~CurlWrapper() {
+    ~HttpClient() {
         if (chunk) {
             curl_slist_free_all(chunk);
         }
@@ -219,7 +233,7 @@ class CurlWrapper {
         if (progress) {
             curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L);
             curl_easy_setopt(curl, CURLOPT_XFERINFODATA, &data);
-            curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION, progress_callback);
+            curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION, update_progress);
         }
     }
 
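
For context, this is the standard libcurl transfer-info plumbing; the commit only renames the callback from progress_callback to update_progress. A minimal self-contained sketch of the same wiring (placeholder URL, progress printed to stderr, no write-to-file logic; the response body goes to stdout by default):

    #include <cstdio>
    #include <curl/curl.h>

    static int update_progress(void *, curl_off_t total, curl_off_t now, curl_off_t, curl_off_t) {
        if (total > 0) {
            fprintf(stderr, "\r%3ld%%", (long) ((now * 100) / total));
        }
        return 0;  // returning non-zero would abort the transfer
    }

    int main() {
        CURL * curl = curl_easy_init();
        if (!curl) {
            return 1;
        }
        curl_easy_setopt(curl, CURLOPT_URL, "https://example.com/some-file1.gguf");
        curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L);  // progress callbacks are off by default
        curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION, update_progress);
        const CURLcode res = curl_easy_perform(curl);
        curl_easy_cleanup(curl);
        return res == CURLE_OK ? 0 : 1;
    }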

@@ -270,9 +284,9 @@ class CurlWrapper {
 
     static std::string human_readable_size(curl_off_t size) {
         static const char * suffix[] = { "B", "KB", "MB", "GB", "TB" };
-        char length = sizeof(suffix) / sizeof(suffix[0]);
-        int i = 0;
-        double dbl_size = size;
+        char   length   = sizeof(suffix) / sizeof(suffix[0]);
+        int    i        = 0;
+        double dbl_size = size;
         if (size > 1024) {
             for (i = 0; (size / 1024) > 0 && i < length - 1; i++, size /= 1024) {
                 dbl_size = size / 1024.0;
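
The change above is purely indentation; the logic is untouched. For reference, a standalone rendering of the same function, with long long standing in for curl_off_t so it compiles without curl headers (the final stream formatting is inferred from the visible return out.str(); and may differ slightly from the real file):

    #include <iomanip>
    #include <iostream>
    #include <sstream>
    #include <string>

    static std::string human_readable_size(long long size) {
        static const char * suffix[] = { "B", "KB", "MB", "GB", "TB" };
        const int length   = sizeof(suffix) / sizeof(suffix[0]);  // char in the original
        int       i        = 0;
        double    dbl_size = size;
        if (size > 1024) {
            for (i = 0; (size / 1024) > 0 && i < length - 1; i++, size /= 1024) {
                dbl_size = size / 1024.0;
            }
        }
        std::ostringstream out;
        out << std::fixed << std::setprecision(2) << dbl_size << " " << suffix[i];
        return out.str();
    }

    int main() {
        std::cout << human_readable_size(1536) << "\n";          // 1.50 KB
        std::cout << human_readable_size(3221225472LL) << "\n";  // 3.00 GB
    }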
@@ -284,36 +298,89 @@ class CurlWrapper {
         return out.str();
     }
 
-    static int progress_callback(void * ptr, curl_off_t total_to_download, curl_off_t now_downloaded, curl_off_t,
-                                 curl_off_t) {
+    static int update_progress(void * ptr, curl_off_t total_to_download, curl_off_t now_downloaded, curl_off_t,
+                               curl_off_t) {
         progress_data * data = static_cast<progress_data *>(ptr);
         if (total_to_download <= 0) {
             return 0;
         }
 
         total_to_download += data->file_size;
         const curl_off_t now_downloaded_plus_file_size = now_downloaded + data->file_size;
-        const curl_off_t percentage = (now_downloaded_plus_file_size * 100) / total_to_download;
-        const curl_off_t pos = (percentage / 5);
+        const curl_off_t percentage      = calculate_percentage(now_downloaded_plus_file_size, total_to_download);
+        std::string      progress_prefix = generate_progress_prefix(percentage);
+
+        const double speed = calculate_speed(now_downloaded, data->start_time);
+        const double tim   = (total_to_download - now_downloaded) / speed;
+        std::string  progress_suffix =
+            generate_progress_suffix(now_downloaded_plus_file_size, total_to_download, speed, tim);
+
+        int         progress_bar_width = calculate_progress_bar_width(progress_prefix, progress_suffix);
         std::string progress_bar;
-        for (int i = 0; i < 20; ++i) {
-            progress_bar.append((i < pos) ? "█" : " ");
-        }
+        generate_progress_bar(progress_bar_width, percentage, progress_bar);
 
-        // Calculate download speed and estimated time to completion
-        const auto now = std::chrono::steady_clock::now();
-        const std::chrono::duration<double> elapsed_seconds = now - data->start_time;
-        const double speed = now_downloaded / elapsed_seconds.count();
-        const double estimated_time = (total_to_download - now_downloaded) / speed;
-        printe("\r%ld%% |%s| %s/%s %.2f MB/s %s ", percentage, progress_bar.c_str(),
-               human_readable_size(now_downloaded).c_str(), human_readable_size(total_to_download).c_str(),
-               speed / (1024 * 1024), human_readable_time(estimated_time).c_str());
-        fflush(stderr);
+        print_progress(progress_prefix, progress_bar, progress_suffix);
         data->printed = true;
 
         return 0;
     }
 
+    static curl_off_t calculate_percentage(curl_off_t now_downloaded_plus_file_size, curl_off_t total_to_download) {
+        return (now_downloaded_plus_file_size * 100) / total_to_download;
+    }
+
+    static std::string generate_progress_prefix(curl_off_t percentage) {
+        std::ostringstream progress_output;
+        progress_output << std::setw(3) << percentage << "% |";
+        return progress_output.str();
+    }
+
+    static double calculate_speed(curl_off_t now_downloaded, const std::chrono::steady_clock::time_point & start_time) {
+        const auto now = std::chrono::steady_clock::now();
+        const std::chrono::duration<double> elapsed_seconds = now - start_time;
+        return now_downloaded / elapsed_seconds.count();
+    }
+
+    static std::string generate_progress_suffix(curl_off_t now_downloaded_plus_file_size, curl_off_t total_to_download,
+                                                double speed, double estimated_time) {
+        const int          width = 10;
+        std::ostringstream progress_output;
+        progress_output << std::setw(width) << human_readable_size(now_downloaded_plus_file_size) << "/"
+                        << std::setw(width) << human_readable_size(total_to_download) << std::setw(width)
+                        << human_readable_size(speed) << "/s" << std::setw(width)
+                        << human_readable_time(estimated_time);
+        return progress_output.str();
+    }
+
+    static int calculate_progress_bar_width(const std::string & progress_prefix, const std::string & progress_suffix) {
+        int progress_bar_width = get_terminal_width() - progress_prefix.size() - progress_suffix.size() - 5;
+        if (progress_bar_width < 10) {
+            progress_bar_width = 10;
+        }
+
+        return progress_bar_width;
+    }
+
+    static std::string generate_progress_bar(int progress_bar_width, curl_off_t percentage,
+                                             std::string & progress_bar) {
+        const curl_off_t pos = (percentage * progress_bar_width) / 100;
+        for (int i = 0; i < progress_bar_width; ++i) {
+            progress_bar.append((i < pos) ? "█" : " ");
+        }
+
+        return progress_bar;
+    }
+
+    static void print_progress(const std::string & progress_prefix, const std::string & progress_bar,
+                               const std::string & progress_suffix) {
+        std::ostringstream progress_output;
+        progress_output << progress_prefix << progress_bar << "| " << progress_suffix;
+        printe(
+            "\r%*s"
+            "\r%s",
+            get_terminal_width(), " ", progress_output.str().c_str());
+    }
+
     // Function to write data to a file
     static size_t write_data(void * ptr, size_t size, size_t nmemb, void * stream) {
         FILE * out = static_cast<FILE *>(stream);
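
The net effect of this refactor: instead of a fixed 20-cell bar, the bar stretches to fill whatever the prefix and suffix leave free. A worked example of the arithmetic, assuming an 80-column terminal: the prefix " 42% |" is 6 characters, the suffix is at least 43 (four setw(10) fields plus "/" and "/s"), and 5 columns are reserved (covering the closing "| " separator plus slack), leaving 80 - 6 - 43 - 5 = 26 cells, of which (42 * 26) / 100 = 10 are filled at 42%. Sketched standalone, with '#' standing in for the block glyph:

    #include <string>

    static int calculate_progress_bar_width(int terminal_width, int prefix_len, int suffix_len) {
        int width = terminal_width - prefix_len - suffix_len - 5;
        if (width < 10) {
            width = 10;  // same floor as the diff, so very narrow terminals still get a bar
        }
        return width;
    }

    static std::string make_bar(int width, long percentage) {
        const long pos = (percentage * width) / 100;
        std::string bar;
        for (int i = 0; i < width; ++i) {
            bar.append(i < pos ? "#" : " ");
        }
        return bar;
    }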
@@ -357,8 +424,8 @@ class LlamaData {
 #ifdef LLAMA_USE_CURL
     int download(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
                  const bool progress, std::string * response_str = nullptr) {
-        CurlWrapper curl;
-        if (curl.init(url, headers, output_file, progress, response_str)) {
+        HttpClient http;
+        if (http.init(url, headers, output_file, progress, response_str)) {
             return 1;
         }
 

@@ -467,6 +534,10 @@ class LlamaData {
         llama_model_params model_params = llama_model_default_params();
         model_params.n_gpu_layers = opt.ngl_ >= 0 ? opt.ngl_ : model_params.n_gpu_layers;
         resolve_model(opt.model_);
+        printe(
+            "\r%*s"
+            "\rLoading model",
+            get_terminal_width(), " ");
         llama_model_ptr model(llama_load_model_from_file(opt.model_.c_str(), model_params));
         if (!model) {
             printe("%s: error: unable to load model from file: %s\n", __func__, opt.model_.c_str());
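
The "\r%*s" followed by "\r..." used here (and again in handle_user_input below) is a clear-line idiom: print a field of terminal-width spaces to wipe whatever the progress bar left on the line, then return to column 0 and write the new text. A tiny demonstration, with 80 standing in for get_terminal_width():

    #include <cstdio>

    int main() {
        fprintf(stderr, "downloading... 42%%");
        // %*s prints a single space right-justified in an 80-column field,
        // i.e. 80 characters of spaces; the second \r rewinds before the real message.
        fprintf(stderr, "\r%*s" "\r%s", 80, " ", "Loading model");
        fprintf(stderr, "\n");
        return 0;
    }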
@@ -478,8 +549,7 @@ class LlamaData {
     // Initializes the context with the specified parameters
     llama_context_ptr initialize_context(const llama_model_ptr & model, const int n_ctx) {
         llama_context_params ctx_params = llama_context_default_params();
-        ctx_params.n_ctx = n_ctx;
-        ctx_params.n_batch = n_ctx;
+        ctx_params.n_ctx = ctx_params.n_batch = n_ctx >= 0 ? n_ctx : ctx_params.n_batch;
         llama_context_ptr context(llama_new_context_with_model(model.get(), ctx_params));
         if (!context) {
             printe("%s: error: failed to create the llama_context\n", __func__);
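
Both this hunk and the new help text rely on the same fall-back pattern: use the caller's value when one was given (>= 0, since -1 now means "not set"), otherwise keep the library default. A compilable sketch of the pattern (assumes llama.h from this repository; note that, as in the diff, an unset context size falls back to the default n_batch, and n_ctx and n_batch are deliberately kept equal):

    #include "llama.h"

    static llama_context_params make_ctx_params(int n_ctx_arg) {
        llama_context_params params = llama_context_default_params();
        // -1 means "not given on the command line" after this commit.
        params.n_ctx = params.n_batch = n_ctx_arg >= 0 ? (uint32_t) n_ctx_arg : params.n_batch;
        return params;
    }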
@@ -642,8 +712,9 @@ static int handle_user_input(std::string & user_input, const std::string & user_
     }
 
     printf(
-        "\r "
-        "\r\033[32m> \033[0m");
+        "\r%*s"
+        "\r\033[32m> \033[0m",
+        get_terminal_width(), " ");
     return read_user_input(user_input); // Returns true if input ends the loop
 }
 
@@ -682,8 +753,9 @@ static int chat_loop(LlamaData & llama_data, const std::string & user_) {
     return 0;
 }
 
-static void log_callback(const enum ggml_log_level level, const char * text, void *) {
-    if (level == GGML_LOG_LEVEL_ERROR) {
+static void log_callback(const enum ggml_log_level level, const char * text, void * p) {
+    const Opt * opt = static_cast<Opt *>(p);
+    if (opt->verbose_ || level == GGML_LOG_LEVEL_ERROR) {
         printe("%s", text);
     }
 }
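
The callback now receives the Opt instance through the user-data pointer, wired up via llama_log_set in the next hunk. A standalone sketch of the gating, with a hypothetical Settings struct in place of Opt (the callback shape matches ggml's ggml_log_callback):

    #include <cstdio>
    #include "ggml.h"

    struct Settings {
        bool verbose = false;
    };

    static void log_callback(enum ggml_log_level level, const char * text, void * user_data) {
        const Settings * s = static_cast<const Settings *>(user_data);
        if (s->verbose || level == GGML_LOG_LEVEL_ERROR) {  // errors always print
            fprintf(stderr, "%s", text);
        }
    }

    // Registered the same way the diff does: llama_log_set(log_callback, &settings);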
@@ -721,7 +793,7 @@ int main(int argc, const char ** argv) {
         opt.user_ += read_pipe_data();
     }
 
-    llama_log_set(log_callback, nullptr);
+    llama_log_set(log_callback, &opt);
     LlamaData llama_data;
     if (llama_data.init(opt)) {
         return 1;
