Commit 88df3fd

Update llama-run to include temperature option

This commit updates the `examples/run/README.md` file to include a new option
for setting the temperature and updates the `run.cpp` file to parse this
option.

Signed-off-by: Eric Curtin <[email protected]>

1 parent a3c33b1 · commit 88df3fd

File tree

2 files changed: +37 −7 lines changed

  examples/run/README.md
  examples/run/run.cpp

examples/run/README.md

Lines changed: 2 additions & 0 deletions

@@ -19,6 +19,8 @@ Options:
       Context size (default: 2048)
   -n, --ngl <value>
       Number of GPU layers (default: 0)
+  --temp <value>
+      Temperature (default: 0.8)
   -v, --verbose, --log-verbose
       Set verbosity level to infinity (i.e. log all messages, useful for debugging)
   -h, --help
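
With the option documented, an illustrative invocation (the model file is a placeholder named after the help-text examples):

    llama-run --temp 0.2 some-file4.gguf "Hello World"

Lower values make sampling more deterministic; omitting the flag keeps the 0.8 default.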

examples/run/run.cpp

Lines changed: 35 additions & 7 deletions
@@ -55,6 +55,12 @@ static int printe(const char * fmt, ...) {
 class Opt {
   public:
     int init(int argc, const char ** argv) {
+        if (argc < 2) {
+            printe("Error: No arguments provided.\n");
+            help();
+            return 1;
+        }
+
         // Parse arguments
         if (parse(argc, argv)) {
             printe("Error: Failed to parse arguments.\n");
@@ -74,6 +80,8 @@ class Opt {
     std::string model_;
     std::string user_;
     int context_size_ = -1, ngl_ = -1;
+    static constexpr const float temperature_default = 0.8;
+    float temperature_ = -1;
     bool verbose_ = false;

   private:
@@ -89,6 +97,17 @@ class Opt {
         }

         option_value = std::atoi(argv[++i]);
+
+        return 0;
+    }
+
+    int handle_option_with_value(int argc, const char ** argv, int & i, float & option_value) {
+        if (i + 1 >= argc) {
+            return 1;
+        }
+
+        option_value = std::atof(argv[++i]);
+
         return 0;
     }

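The new float overload mirrors the existing int overload, so the parse loop selects std::atoi or std::atof through ordinary C++ overload resolution, based on the type of the member being filled. A minimal standalone sketch of the pattern (the toy main and variable names are illustrative, not part of run.cpp):

    #include <cstdio>
    #include <cstdlib>
    #include <cstring>

    // Overload for integer-valued options (used by -c/--context-size and -n/--ngl).
    static int handle_option_with_value(int argc, const char ** argv, int & i, int & option_value) {
        if (i + 1 >= argc) {
            return 1;  // flag given without a value
        }
        option_value = std::atoi(argv[++i]);
        return 0;
    }

    // Overload for float-valued options (used by the temperature option).
    static int handle_option_with_value(int argc, const char ** argv, int & i, float & option_value) {
        if (i + 1 >= argc) {
            return 1;
        }
        option_value = std::atof(argv[++i]);
        return 0;
    }

    int main(int argc, const char ** argv) {
        int   ngl         = -1;  // -1 is the "unset" sentinel, as in Opt
        float temperature = -1;
        for (int i = 1; i < argc; ++i) {
            if (strcmp(argv[i], "--ngl") == 0) {
                if (handle_option_with_value(argc, argv, i, ngl) == 1) { return 1; }
            } else if (strcmp(argv[i], "--temperature") == 0) {
                if (handle_option_with_value(argc, argv, i, temperature) == 1) { return 1; }
            }
        }
        printf("ngl=%d temperature=%.2f\n", ngl, temperature);
        return 0;
    }
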
@@ -103,6 +122,10 @@
             if (handle_option_with_value(argc, argv, i, ngl_) == 1) {
                 return 1;
             }
+        } else if (options_parsing && strcmp(argv[i], "--temperature") == 0) {
+            if (handle_option_with_value(argc, argv, i, temperature_) == 1) {
+                return 1;
+            }
         } else if (options_parsing &&
                    (parse_flag(argv, i, "-v", "--verbose") || parse_flag(argv, i, "-v", "--log-verbose"))) {
             verbose_ = true;
@@ -142,6 +165,8 @@
             "  Context size (default: %d)\n"
             "  -n, --ngl <value>\n"
             "      Number of GPU layers (default: %d)\n"
+            "  --temp <value>\n"
+            "      Temperature (default: %.1f)\n"
             "  -v, --verbose, --log-verbose\n"
             "      Set verbosity level to infinity (i.e. log all messages, useful for debugging)\n"
             "  -h, --help\n"
@@ -170,7 +195,8 @@
             "  llama-run file://some-file3.gguf\n"
             "  llama-run --ngl 999 some-file4.gguf\n"
             "  llama-run --ngl 999 some-file5.gguf Hello World\n",
-            llama_context_default_params().n_batch, llama_model_default_params().n_gpu_layers);
+            llama_context_default_params().n_batch, llama_model_default_params().n_gpu_layers,
+            Opt::temperature_default);
     }
 };

@@ -495,12 +521,12 @@ class LlamaData {
             return 1;
         }

-        context = initialize_context(model, opt.context_size_);
+        context = initialize_context(model, opt);
         if (!context) {
             return 1;
         }

-        sampler = initialize_sampler();
+        sampler = initialize_sampler(opt);
         return 0;
     }

@@ -636,9 +662,9 @@
     }

     // Initializes the context with the specified parameters
-    llama_context_ptr initialize_context(const llama_model_ptr & model, const int n_ctx) {
+    llama_context_ptr initialize_context(const llama_model_ptr & model, const Opt & opt) {
         llama_context_params ctx_params = llama_context_default_params();
-        ctx_params.n_ctx = ctx_params.n_batch = n_ctx >= 0 ? n_ctx : ctx_params.n_batch;
+        ctx_params.n_ctx = ctx_params.n_batch = opt.context_size_ >= 0 ? opt.context_size_ : ctx_params.n_batch;
         llama_context_ptr context(llama_new_context_with_model(model.get(), ctx_params));
         if (!context) {
             printe("%s: error: failed to create the llama_context\n", __func__);
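
Passing the whole Opt through, rather than a lone int, keeps the signatures of initialize_context and initialize_sampler stable as options accumulate; the temperature field also reuses the sentinel convention already used for context_size_, where a negative value means "not set, fall back to the default".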
@@ -648,10 +674,12 @@
     }

     // Initializes and configures the sampler
-    llama_sampler_ptr initialize_sampler() {
+    llama_sampler_ptr initialize_sampler(const Opt & opt) {
         llama_sampler_ptr sampler(llama_sampler_chain_init(llama_sampler_chain_default_params()));
         llama_sampler_chain_add(sampler.get(), llama_sampler_init_min_p(0.05f, 1));
-        llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(0.8f));
+        llama_sampler_chain_add(
+            sampler.get(),
+            llama_sampler_init_temp(opt.temperature_ >= 0 ? opt.temperature_ : Opt::temperature_default));
         llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

         return sampler;
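
The chain now runs min-p, then temperature, then the final random draw (dist). Conceptually, the temperature stage divides each logit by T before probabilities are formed: T < 1 sharpens the distribution toward the top token, T > 1 flattens it. A self-contained sketch of the standard softmax-with-temperature computation (illustrative only; not llama.cpp's internal implementation):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Temperature-scaled softmax: p_i = exp(l_i / T) / sum_j exp(l_j / T).
    static std::vector<float> softmax_with_temp(const std::vector<float> & logits, float temp) {
        const float max_logit = *std::max_element(logits.begin(), logits.end());
        std::vector<float> probs(logits.size());
        float sum = 0.0f;
        for (size_t i = 0; i < logits.size(); ++i) {
            probs[i] = std::exp((logits[i] - max_logit) / temp);  // subtract max for numerical stability
            sum += probs[i];
        }
        for (float & p : probs) {
            p /= sum;
        }
        return probs;
    }

    int main() {
        const std::vector<float> logits = { 2.0f, 1.0f, 0.1f };
        const float temps[] = { 0.2f, 0.8f, 2.0f };
        for (float temp : temps) {
            printf("T=%.1f:", temp);
            for (float p : softmax_with_temp(logits, temp)) {
                printf(" %.3f", p);
            }
            printf("\n");
        }
        return 0;
    }

With these logits, T = 0.2 puts nearly all probability on the first token, while T = 2.0 spreads it across all three, which is the behavior the new temperature knob exposes.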
