Skip to content

Commit 47aaa78

Browse files
committed
Improve progress bar
Set default width to whatever the terminal is. Also fixed a small bug around default n_gpu_layers value. Signed-off-by: Eric Curtin <[email protected]>
1 parent 56eea07 commit 47aaa78

File tree

1 file changed

+89
-22
lines changed

1 file changed

+89
-22
lines changed

examples/run/run.cpp

Lines changed: 89 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#if defined(_WIN32)
22
# include <windows.h>
33
#else
4+
# include <sys/ioctl.h>
45
# include <unistd.h>
56
#endif
67

@@ -70,7 +71,7 @@ class Opt {
7071
")\n"
7172
" -n, --ngl <value>\n"
7273
" Number of GPU layers (default: " +
73-
std::to_string(ngl_);
74+
std::to_string(llama_model_default_params().n_gpu_layers);
7475
help_str_ +=
7576
")\n"
7677
" -h, --help\n"
@@ -96,8 +97,8 @@ class Opt {
9697
" llama-run https://example.com/some-file1.gguf\n"
9798
" llama-run some-file2.gguf\n"
9899
" llama-run file://some-file3.gguf\n"
99-
" llama-run --ngl 99 some-file4.gguf\n"
100-
" llama-run --ngl 99 some-file5.gguf Hello World\n";
100+
" llama-run --ngl 999 some-file4.gguf\n"
101+
" llama-run --ngl 999 some-file5.gguf Hello World\n";
101102
}
102103

103104
int parse(int argc, const char ** argv) {
@@ -119,6 +120,10 @@ class Opt {
119120
help_ = true;
120121
return 0;
121122
} else if (!positional_args_i) {
123+
if (!argv[i][0] || argv[i][1] == '-') {
124+
return 1;
125+
}
126+
122127
++positional_args_i;
123128
model_ = argv[i];
124129
} else if (positional_args_i == 1) {
@@ -151,6 +156,18 @@ struct FileDeleter {
151156

152157
typedef std::unique_ptr<FILE, FileDeleter> FILE_ptr;
153158

159+
static int get_terminal_width() {
160+
#if defined(_WIN32)
161+
CONSOLE_SCREEN_BUFFER_INFO csbi;
162+
GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &csbi);
163+
return csbi.srWindow.Right - csbi.srWindow.Left + 1;
164+
#else
165+
struct winsize w;
166+
ioctl(STDOUT_FILENO, TIOCGWINSZ, &w);
167+
return w.ws_col;
168+
#endif
169+
}
170+
154171
#ifdef LLAMA_USE_CURL
155172
class CurlWrapper {
156173
public:
@@ -270,9 +287,9 @@ class CurlWrapper {
270287

271288
static std::string human_readable_size(curl_off_t size) {
272289
static const char * suffix[] = { "B", "KB", "MB", "GB", "TB" };
273-
char length = sizeof(suffix) / sizeof(suffix[0]);
274-
int i = 0;
275-
double dbl_size = size;
290+
char length = sizeof(suffix) / sizeof(suffix[0]);
291+
int i = 0;
292+
double dbl_size = size;
276293
if (size > 1024) {
277294
for (i = 0; (size / 1024) > 0 && i < length - 1; i++, size /= 1024) {
278295
dbl_size = size / 1024.0;
@@ -293,27 +310,75 @@ class CurlWrapper {
293310

294311
total_to_download += data->file_size;
295312
const curl_off_t now_downloaded_plus_file_size = now_downloaded + data->file_size;
296-
const curl_off_t percentage = (now_downloaded_plus_file_size * 100) / total_to_download;
297-
const curl_off_t pos = (percentage / 5);
313+
const curl_off_t percentage = calculate_percentage(now_downloaded_plus_file_size, total_to_download);
314+
std::string progress_prefix = generate_progress_prefix(percentage);
315+
316+
const double speed = calculate_speed(now_downloaded, data->start_time);
317+
const double time = (total_to_download - now_downloaded) / speed;
318+
std::string progress_suffix =
319+
generate_progress_suffix(now_downloaded_plus_file_size, total_to_download, speed, time);
320+
321+
int progress_bar_width = calculate_progress_bar_width(progress_prefix, progress_suffix);
298322
std::string progress_bar;
299-
for (int i = 0; i < 20; ++i) {
300-
progress_bar.append((i < pos) ? "" : " ");
301-
}
323+
generate_progress_bar(progress_bar_width, percentage, progress_bar);
302324

303-
// Calculate download speed and estimated time to completion
304-
const auto now = std::chrono::steady_clock::now();
305-
const std::chrono::duration<double> elapsed_seconds = now - data->start_time;
306-
const double speed = now_downloaded / elapsed_seconds.count();
307-
const double estimated_time = (total_to_download - now_downloaded) / speed;
308-
printe("\r%ld%% |%s| %s/%s %.2f MB/s %s ", percentage, progress_bar.c_str(),
309-
human_readable_size(now_downloaded).c_str(), human_readable_size(total_to_download).c_str(),
310-
speed / (1024 * 1024), human_readable_time(estimated_time).c_str());
311-
fflush(stderr);
325+
print_progress(progress_prefix, progress_bar, progress_suffix);
312326
data->printed = true;
313327

314328
return 0;
315329
}
316330

331+
static curl_off_t calculate_percentage(curl_off_t now_downloaded_plus_file_size, curl_off_t total_to_download) {
332+
return (now_downloaded_plus_file_size * 100) / total_to_download;
333+
}
334+
335+
static std::string generate_progress_prefix(curl_off_t percentage) {
336+
std::ostringstream progress_output;
337+
progress_output << percentage << "% |";
338+
return progress_output.str();
339+
}
340+
341+
static double calculate_speed(curl_off_t now_downloaded, const std::chrono::steady_clock::time_point & start_time) {
342+
const auto now = std::chrono::steady_clock::now();
343+
const std::chrono::duration<double> elapsed_seconds = now - start_time;
344+
return now_downloaded / elapsed_seconds.count();
345+
}
346+
347+
static std::string generate_progress_suffix(curl_off_t now_downloaded_plus_file_size, curl_off_t total_to_download,
348+
double speed, double estimated_time) {
349+
std::ostringstream progress_output;
350+
progress_output << human_readable_size(now_downloaded_plus_file_size).c_str() << "/"
351+
<< human_readable_size(total_to_download).c_str() << " " << std::fixed << std::setprecision(2)
352+
<< speed / (1024 * 1024) << " MB/s " << human_readable_time(estimated_time).c_str();
353+
return progress_output.str();
354+
}
355+
356+
static int calculate_progress_bar_width(const std::string & progress_prefix, const std::string & progress_suffix) {
357+
int progress_bar_width = get_terminal_width() - progress_prefix.size() - progress_suffix.size() - 5;
358+
if (progress_bar_width < 10) {
359+
progress_bar_width = 10;
360+
}
361+
return progress_bar_width;
362+
}
363+
364+
static std::string generate_progress_bar(int progress_bar_width, curl_off_t percentage,
365+
std::string & progress_bar) {
366+
const curl_off_t pos = (percentage * progress_bar_width) / 100;
367+
for (int i = 0; i < progress_bar_width; ++i) {
368+
progress_bar.append((i < pos) ? "" : " ");
369+
}
370+
371+
return progress_bar;
372+
}
373+
374+
static void print_progress(const std::string & progress_prefix, const std::string & progress_bar,
375+
const std::string & progress_suffix) {
376+
std::ostringstream progress_output;
377+
progress_output << progress_prefix << progress_bar << "| " << progress_suffix;
378+
printe("\r%*s\r%s", get_terminal_width(), " ", progress_output.str().c_str());
379+
fflush(stderr);
380+
}
381+
317382
// Function to write data to a file
318383
static size_t write_data(void * ptr, size_t size, size_t nmemb, void * stream) {
319384
FILE * out = static_cast<FILE *>(stream);
@@ -467,6 +532,7 @@ class LlamaData {
467532
llama_model_params model_params = llama_model_default_params();
468533
model_params.n_gpu_layers = opt.ngl_ >= 0 ? opt.ngl_ : model_params.n_gpu_layers;
469534
resolve_model(opt.model_);
535+
printe("Loading model");
470536
llama_model_ptr model(llama_load_model_from_file(opt.model_.c_str(), model_params));
471537
if (!model) {
472538
printe("%s: error: unable to load model from file: %s\n", __func__, opt.model_.c_str());
@@ -642,8 +708,9 @@ static int handle_user_input(std::string & user_input, const std::string & user_
642708
}
643709

644710
printf(
645-
"\r "
646-
"\r\033[32m> \033[0m");
711+
"\r%*s"
712+
"\r\033[32m> \033[0m",
713+
get_terminal_width(), " ");
647714
return read_user_input(user_input); // Returns true if input ends the loop
648715
}
649716

0 commit comments

Comments
 (0)