Skip to content

Commit d3f5a58

Browse files
committed
Add --override-tensors option to llama-bench
1 parent bc091a4 commit d3f5a58

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed

examples/llama-bench/llama-bench.cpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ struct cmd_params {
181181
int reps;
182182
ggml_sched_priority prio;
183183
int delay;
184+
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
184185
bool verbose;
185186
bool progress;
186187
output_formats output_format;
@@ -213,6 +214,7 @@ static const cmd_params cmd_params_defaults = {
213214
/* reps */ 5,
214215
/* prio */ GGML_SCHED_PRIO_NORMAL,
215216
/* delay */ 0,
217+
/* tensor_buft_overrides*/ {},
216218
/* verbose */ false,
217219
/* progress */ false,
218220
/* output_format */ MARKDOWN,
@@ -268,6 +270,7 @@ static void print_usage(int /* argc */, char ** argv) {
268270
printf(" -r, --repetitions <n> (default: %d)\n", cmd_params_defaults.reps);
269271
printf(" --prio <0|1|2|3> (default: %d)\n", cmd_params_defaults.prio);
270272
printf(" --delay <0...N> (seconds) (default: %d)\n", cmd_params_defaults.delay);
273+
printf("  -ot --override-tensors <tensor name pattern>=<buffer type>,... (default: disabled)\n");
271274
printf(" -o, --output <csv|json|jsonl|md|sql> (default: %s)\n",
272275
output_format_str(cmd_params_defaults.output_format));
273276
printf(" -oe, --output-err <csv|json|jsonl|md|sql> (default: %s)\n",
@@ -575,6 +578,56 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
575578
break;
576579
}
577580
params.delay = std::stoi(argv[i]);
581+
} else if (arg == "-ot" || arg == "--override-tensors") {
582+
if (++i >= argc) {
583+
invalid_param = true;
584+
break;
585+
}
586+
auto value = argv[i];
587+
/* static */ std::map<std::string, ggml_backend_buffer_type_t> buft_list;
588+
if (buft_list.empty()) {
589+
// enumerate all the devices and add their buffer types to the list
590+
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
591+
auto * dev = ggml_backend_dev_get(i);
592+
auto * buft = ggml_backend_dev_buffer_type(dev);
593+
if (buft) {
594+
buft_list[ggml_backend_buft_name(buft)] = buft;
595+
}
596+
}
597+
}
598+
auto override_span_len = std::strcspn(value, ",");
599+
while (override_span_len > 0) {
600+
// Stamps null terminators into the argv
601+
// value for this option to avoid the
602+
// memory leak present in the implementation
603+
// over in arg.cpp. Maybe allowable because we
604+
// only parse these args once in this program.
605+
auto override = value;
606+
if (value[override_span_len] != '\0') {
607+
value[override_span_len] = '\0';
608+
value = &value[override_span_len + 1];
609+
} else {
610+
value = &value[override_span_len];
611+
}
612+
auto tensor_name_span_len = std::strcspn(override, "=");
613+
if (tensor_name_span_len >= override_span_len) {
614+
invalid_param = true;
615+
break;
616+
}
617+
override[tensor_name_span_len] = '\0';
618+
auto tensor_name = override;
619+
auto buffer_type = &override[tensor_name_span_len + 1];
620+
if (buft_list.find(buffer_type) == buft_list.end()) {
621+
printf("Available buffer types:\n");
622+
for (const auto & it : buft_list) {
623+
printf(" %s\n", ggml_backend_buft_name(it.second));
624+
}
625+
invalid_param = true;
626+
break;
627+
}
628+
params.tensor_buft_overrides.push_back({tensor_name, buft_list.at(buffer_type)});
629+
override_span_len = std::strcspn(value, ",");
630+
}
578631
} else if (arg == "-o" || arg == "--output") {
579632
if (++i >= argc) {
580633
invalid_param = true;
@@ -667,6 +720,11 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
667720
params.poll = cmd_params_defaults.poll;
668721
}
669722

723+
// Attach terminators to options that require them
724+
if (!params.tensor_buft_overrides.empty()) {
725+
params.tensor_buft_overrides.push_back({nullptr, nullptr});
726+
}
727+
670728
return params;
671729
}
672730

@@ -689,6 +747,7 @@ struct cmd_params_instance {
689747
bool no_kv_offload;
690748
bool flash_attn;
691749
std::vector<float> tensor_split;
750+
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
692751
bool use_mmap;
693752
bool embeddings;
694753

@@ -733,6 +792,13 @@ struct cmd_params_instance {
733792
mparams.tensor_split = tensor_split.data();
734793
mparams.use_mmap = use_mmap;
735794

795+
if (tensor_buft_overrides.empty()) {
796+
mparams.tensor_buft_overrides = nullptr;
797+
} else {
798+
GGML_ASSERT(tensor_buft_overrides.back().pattern == nullptr && "Tensor buffer overrides not terminated with empty pattern");
799+
mparams.tensor_buft_overrides = tensor_buft_overrides.data();
800+
}
801+
736802
return mparams;
737803
}
738804

@@ -804,6 +870,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
804870
/* .no_kv_offload= */ nkvo,
805871
/* .flash_attn = */ fa,
806872
/* .tensor_split = */ ts,
873+
/* .tensor_buft_overrides = */ params.tensor_buft_overrides,
807874
/* .use_mmap = */ mmp,
808875
/* .embeddings = */ embd,
809876
};
@@ -833,6 +900,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
833900
/* .no_kv_offload= */ nkvo,
834901
/* .flash_attn = */ fa,
835902
/* .tensor_split = */ ts,
903+
/* .tensor_buft_overrides = */ params.tensor_buft_overrides,
836904
/* .use_mmap = */ mmp,
837905
/* .embeddings = */ embd,
838906
};
@@ -862,6 +930,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
862930
/* .no_kv_offload= */ nkvo,
863931
/* .flash_attn = */ fa,
864932
/* .tensor_split = */ ts,
933+
/* .tensor_buft_overrides = */ params.tensor_buft_overrides,
865934
/* .use_mmap = */ mmp,
866935
/* .embeddings = */ embd,
867936
};

0 commit comments

Comments
 (0)