@@ -178,6 +178,7 @@ struct cmd_params {
178178 std::vector<std::vector<float >> tensor_split;
179179 std::vector<bool > use_mmap;
180180 std::vector<bool > embeddings;
181+ ggml_numa_strategy numa;
181182 int reps;
182183 bool verbose;
183184 output_formats output_format;
@@ -200,6 +201,7 @@ static const cmd_params cmd_params_defaults = {
200201 /* tensor_split */ {std::vector<float >(llama_max_devices (), 0 .0f )},
201202 /* use_mmap */ {true },
202203 /* embeddings */ {false },
204+ /* numa */ GGML_NUMA_STRATEGY_DISABLED,
203205 /* reps */ 5 ,
204206 /* verbose */ false ,
205207 /* output_format */ MARKDOWN
@@ -224,6 +226,7 @@ static void print_usage(int /* argc */, char ** argv) {
224226 printf (" -nkvo, --no-kv-offload <0|1> (default: %s)\n " , join (cmd_params_defaults.no_kv_offload , " ," ).c_str ());
225227 printf (" -fa, --flash-attn <0|1> (default: %s)\n " , join (cmd_params_defaults.flash_attn , " ," ).c_str ());
226228 printf (" -mmp, --mmap <0|1> (default: %s)\n " , join (cmd_params_defaults.use_mmap , " ," ).c_str ());
229+ printf (" --numa <distribute|isolate|numactl> (default: disabled)\n " );
227230 printf (" -embd, --embeddings <0|1> (default: %s)\n " , join (cmd_params_defaults.embeddings , " ," ).c_str ());
228231 printf (" -ts, --tensor-split <ts0/ts1/..> (default: 0)\n " );
229232 printf (" -r, --repetitions <n> (default: %d)\n " , cmd_params_defaults.reps );
@@ -396,6 +399,17 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
396399 }
397400 auto p = split<bool >(argv[i], split_delim);
398401 params.no_kv_offload .insert (params.no_kv_offload .end (), p.begin (), p.end ());
402+ } else if (arg == " --numa" ) {
403+ if (++i >= argc) {
404+ invalid_param = true ;
405+ break ;
406+ } else {
407+ std::string value (argv[i]);
408+ /* */ if (value == " distribute" || value == " " ) { params.numa = GGML_NUMA_STRATEGY_DISTRIBUTE; }
409+ else if (value == " isolate" ) { params.numa = GGML_NUMA_STRATEGY_ISOLATE; }
410+ else if (value == " numactl" ) { params.numa = GGML_NUMA_STRATEGY_NUMACTL; }
411+ else { invalid_param = true ; break ; }
412+ }
399413 } else if (arg == " -fa" || arg == " --flash-attn" ) {
400414 if (++i >= argc) {
401415 invalid_param = true ;
@@ -1215,6 +1229,7 @@ int main(int argc, char ** argv) {
12151229 llama_log_set (llama_null_log_callback, NULL );
12161230 }
12171231 llama_backend_init ();
1232+ llama_numa_init (params.numa );
12181233
12191234 // initialize printer
12201235 std::unique_ptr<printer> p;
0 commit comments