Commit ff6a907

Opt class for positional argument handling
Added support for positional arguments `MODEL` and `PROMPT`.

Signed-off-by: Eric Curtin <[email protected]>
1 parent 0eb4e12
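The commit replaces the flag-based interface (`-m`, `-p`, `-c`, `-ngl`) with positional arguments; in the style of the help text's Examples below (model path illustrative):

  llama-run -m ~/.local/share/ramalama/models/ollama/smollm\:135m -p "Hello World"   # before
  llama-run --ngl 99 ~/.local/share/ramalama/models/ollama/smollm\:135m Hello World  # after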

1 file changed, 89 insertions(+), 93 deletions(-)


examples/run/run.cpp

@@ -17,96 +17,102 @@
 
 typedef std::unique_ptr<char[]> char_array_ptr;
 
-struct Argument {
-    std::string flag;
-    std::string help_text;
-};
-
-struct Options {
-    std::string model_path, prompt_non_interactive;
-    int ngl = 99;
-    int n_ctx = 2048;
-};
+class Opt {
+  public:
+    int init_opt(int argc, const char ** argv) {
+        construct_help_str_();
+        // Parse arguments
+        if (parse(argc, argv)) {
+            fprintf(stderr, "Error: Failed to parse arguments.\n");
+            help();
+            return 1;
+        }
 
-class ArgumentParser {
-  public:
-    ArgumentParser(const char * program_name) : program_name(program_name) {}
+        // If help is requested, show help and exit
+        if (help_) {
+            help();
+            return 2;
+        }
 
-    void add_argument(const std::string & flag, std::string & var, const std::string & help_text = "") {
-        string_args[flag] = &var;
-        arguments.push_back({flag, help_text});
+        return 0; // Success
     }
 
-    void add_argument(const std::string & flag, int & var, const std::string & help_text = "") {
-        int_args[flag] = &var;
-        arguments.push_back({flag, help_text});
+    const char * model_ = nullptr;
+    std::string prompt_;
+    int context_size_ = 2048, ngl_ = 0;
+
+  private:
+    std::string help_str_;
+    bool help_ = false;
+
+    void construct_help_str_() {
+        help_str_ =
+            "Description:\n"
+            "  Runs a llm\n"
+            "\n"
+            "Usage:\n"
+            "  llama-run [options] MODEL [PROMPT]\n"
+            "\n"
+            "Options:\n"
+            "  -c, --context-size <value>\n"
+            "      Context size (default: " +
+            std::to_string(context_size_);
+        help_str_ +=
+            ")\n"
+            "  -n, --ngl <value>\n"
+            "      Number of GPU layers (default: " +
+            std::to_string(ngl_);
+        help_str_ +=
+            ")\n"
+            "  -h, --help\n"
+            "      Show help message\n"
+            "\n"
+            "Examples:\n"
+            "  llama-run ~/.local/share/ramalama/models/ollama/smollm\\:135m\n"
+            "  llama-run --ngl 99 ~/.local/share/ramalama/models/ollama/smollm\\:135m\n"
+            "  llama-run --ngl 99 ~/.local/share/ramalama/models/ollama/smollm\\:135m Hello World\n";
     }
 
     int parse(int argc, const char ** argv) {
+        if (parse_arguments(argc, argv) || !model_) {
+            return 1;
+        }
+
+        return 0;
+    }
+
+    int parse_arguments(int argc, const char ** argv) {
+        int positional_args_i = 0;
         for (int i = 1; i < argc; ++i) {
-            std::string arg = argv[i];
-            if (string_args.count(arg)) {
-                if (i + 1 < argc) {
-                    *string_args[arg] = argv[++i];
-                } else {
-                    fprintf(stderr, "error: missing value for %s\n", arg.c_str());
-                    print_usage();
+            if (std::strcmp(argv[i], "-c") == 0 || std::strcmp(argv[i], "--context-size") == 0) {
+                if (i + 1 >= argc) {
                     return 1;
                 }
-            } else if (int_args.count(arg)) {
-                if (i + 1 < argc) {
-                    if (parse_int_arg(argv[++i], *int_args[arg]) != 0) {
-                        fprintf(stderr, "error: invalid value for %s: %s\n", arg.c_str(), argv[i]);
-                        print_usage();
-                        return 1;
-                    }
-                } else {
-                    fprintf(stderr, "error: missing value for %s\n", arg.c_str());
-                    print_usage();
+
+                context_size_ = std::atoi(argv[++i]);
+            } else if (std::strcmp(argv[i], "-n") == 0 || std::strcmp(argv[i], "--ngl") == 0) {
+                if (i + 1 >= argc) {
                     return 1;
                 }
+
+                ngl_ = std::atoi(argv[++i]);
+            } else if (std::strcmp(argv[i], "-h") == 0 || std::strcmp(argv[i], "--help") == 0) {
+                help_ = true;
+            } else if (!positional_args_i) {
+                ++positional_args_i;
+                model_ = argv[i];
+            } else if (positional_args_i == 1) {
+                ++positional_args_i;
+                prompt_ = argv[i];
             } else {
-                fprintf(stderr, "error: unrecognized argument %s\n", arg.c_str());
-                print_usage();
-                return 1;
+                prompt_ += " " + std::string(argv[i]);
             }
         }
 
-        if (string_args["-m"]->empty()) {
-            fprintf(stderr, "error: -m is required\n");
-            print_usage();
-            return 1;
-        }
-
         return 0;
     }
 
-  private:
-    const char * program_name;
-    std::unordered_map<std::string, std::string *> string_args;
-    std::unordered_map<std::string, int *> int_args;
-    std::vector<Argument> arguments;
-
-    int parse_int_arg(const char * arg, int & value) {
-        char * end;
-        const long val = std::strtol(arg, &end, 10);
-        if (*end == '\0' && val >= INT_MIN && val <= INT_MAX) {
-            value = static_cast<int>(val);
-            return 0;
-        }
-        return 1;
-    }
-
-    void print_usage() const {
-        printf("\nUsage:\n");
-        printf("  %s [OPTIONS]\n\n", program_name);
-        printf("Options:\n");
-        for (const auto & arg : arguments) {
-            printf("  %-10s %s\n", arg.flag.c_str(), arg.help_text.c_str());
-        }
-
-        printf("\n");
-    }
+    void help() const { printf("%s", help_str_.c_str()); }
 };
 
 class LlamaData {
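A minimal sketch (not part of the commit) of the parsing behavior above, assuming the `Opt` class from this patch is in scope; the model filename is illustrative and never opened, since `init_opt` only parses:

#include <cassert>
#include <string>

int main() {
    // "-c" consumes the next argument; the remaining arguments are positional.
    const char * argv[] = { "llama-run", "-c", "4096", "model.gguf", "Hello", "World" };
    Opt opt;
    if (opt.init_opt(6, argv) != 0) {
        return 1;
    }

    assert(opt.context_size_ == 4096);               // parsed with std::atoi
    assert(std::string(opt.model_) == "model.gguf"); // first positional becomes MODEL
    assert(opt.prompt_ == "Hello World");            // later positionals are space-joined into PROMPT
    return 0;
}

Note one trade-off: the removed parse_int_arg validated integers via std::strtol with range checks, while the new code calls std::atoi, which silently yields 0 for malformed input such as `-c abc`.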
@@ -116,13 +122,13 @@ class LlamaData {
     llama_context_ptr context;
     std::vector<llama_chat_message> messages;
 
-    int init(const Options & opt) {
-        model = initialize_model(opt.model_path, opt.ngl);
+    int init(const Opt & opt) {
+        model = initialize_model(opt.model_, opt.ngl_);
         if (!model) {
             return 1;
         }
 
-        context = initialize_context(model, opt.n_ctx);
+        context = initialize_context(model, opt.context_size_);
         if (!context) {
             return 1;
         }
@@ -273,19 +279,6 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
     return 0;
 }
 
-static int parse_arguments(const int argc, const char ** argv, Options & opt) {
-    ArgumentParser parser(argv[0]);
-    parser.add_argument("-m", opt.model_path, "model");
-    parser.add_argument("-p", opt.prompt_non_interactive, "prompt");
-    parser.add_argument("-c", opt.n_ctx, "context_size");
-    parser.add_argument("-ngl", opt.ngl, "n_gpu_layers");
-    if (parser.parse(argc, argv)) {
-        return 1;
-    }
-
-    return 0;
-}
-
 static int read_user_input(std::string & user) {
     std::getline(std::cin, user);
     return user.empty(); // Indicate an error or empty input
@@ -382,17 +375,20 @@ static std::string read_pipe_data() {
 }
 
 int main(int argc, const char ** argv) {
-    Options opt;
-    if (parse_arguments(argc, argv, opt)) {
+    Opt opt;
+    const int opt_ret = opt.init_opt(argc, argv);
+    if (opt_ret == 2) {
+        return 0;
+    } else if (opt_ret) {
         return 1;
     }
 
     if (!is_stdin_a_terminal()) {
-        if (!opt.prompt_non_interactive.empty()) {
-            opt.prompt_non_interactive += "\n\n";
+        if (!opt.prompt_.empty()) {
+            opt.prompt_ += "\n\n";
         }
 
-        opt.prompt_non_interactive += read_pipe_data();
+        opt.prompt_ += read_pipe_data();
     }
 
     llama_log_set(log_callback, nullptr);
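Because the prompt is now positional rather than a `-p` value, piped input is appended after it, separated by a blank line; e.g. (model path illustrative):

  cat notes.txt | llama-run ~/.local/share/ramalama/models/ollama/smollm\:135m Summarize this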
@@ -401,7 +397,7 @@ int main(int argc, const char ** argv) {
         return 1;
     }
 
-    if (chat_loop(llama_data, opt.prompt_non_interactive)) {
+    if (chat_loop(llama_data, opt.prompt_)) {
         return 1;
     }
 