@@ -181,6 +181,7 @@ struct cmd_params {
181181 int reps;
182182 ggml_sched_priority prio;
183183 int delay;
184+ std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
184185 bool verbose;
185186 bool progress;
186187 output_formats output_format;
@@ -213,6 +214,7 @@ static const cmd_params cmd_params_defaults = {
213214 /* reps */ 5 ,
214215 /* prio */ GGML_SCHED_PRIO_NORMAL,
215216 /* delay */ 0 ,
217+ /* tensor_buft_overrides*/ {},
216218 /* verbose */ false ,
217219 /* progress */ false ,
218220 /* output_format */ MARKDOWN,
@@ -268,6 +270,7 @@ static void print_usage(int /* argc */, char ** argv) {
268270 printf (" -r, --repetitions <n> (default: %d)\n " , cmd_params_defaults.reps );
269271 printf (" --prio <0|1|2|3> (default: %d)\n " , cmd_params_defaults.prio );
270272 printf (" --delay <0...N> (seconds) (default: %d)\n " , cmd_params_defaults.delay );
273+ printf (" -ot --override-tensors <tensor name pattern>=<buffer type>,... (default:disabled)\n " );
271274 printf (" -o, --output <csv|json|jsonl|md|sql> (default: %s)\n " ,
272275 output_format_str (cmd_params_defaults.output_format ));
273276 printf (" -oe, --output-err <csv|json|jsonl|md|sql> (default: %s)\n " ,
@@ -575,6 +578,56 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
575578 break ;
576579 }
577580 params.delay = std::stoi (argv[i]);
581+ } else if (arg == " -ot" || arg == " --override-tensors" ) {
582+ if (++i >= argc) {
583+ invalid_param = true ;
584+ break ;
585+ }
586+ auto value = argv[i];
587+ /* static */ std::map<std::string, ggml_backend_buffer_type_t > buft_list;
588+ if (buft_list.empty ()) {
589+ // enumerate all the devices and add their buffer types to the list
590+ for (size_t i = 0 ; i < ggml_backend_dev_count (); ++i) {
591+ auto * dev = ggml_backend_dev_get (i);
592+ auto * buft = ggml_backend_dev_buffer_type (dev);
593+ if (buft) {
594+ buft_list[ggml_backend_buft_name (buft)] = buft;
595+ }
596+ }
597+ }
598+ auto override_span_len = std::strcspn (value, " ," );
599+ while (override_span_len > 0 ) {
600+ // Stamps null terminators into the argv
601+ // value for this option to avoid the
602+ // memory leak present in the implementation
603+ // over in arg.cpp. Maybe allowable because we
604+ // only parse these args once in this program.
605+ auto override = value;
606+ if (value[override_span_len] != ' \0 ' ) {
607+ value[override_span_len] = ' \0 ' ;
608+ value = &value[override_span_len + 1 ];
609+ } else {
610+ value = &value[override_span_len];
611+ }
612+ auto tensor_name_span_len = std::strcspn (override , " =" );
613+ if (tensor_name_span_len >= override_span_len) {
614+ invalid_param = true ;
615+ break ;
616+ }
617+ override [tensor_name_span_len] = ' \0 ' ;
618+ auto tensor_name = override ;
619+ auto buffer_type = &override [tensor_name_span_len + 1 ];
620+ if (buft_list.find (buffer_type) == buft_list.end ()) {
621+ printf (" Available buffer types:\n " );
622+ for (const auto & it : buft_list) {
623+ printf (" %s\n " , ggml_backend_buft_name (it.second ));
624+ }
625+ invalid_param = true ;
626+ break ;
627+ }
628+ params.tensor_buft_overrides .push_back ({tensor_name, buft_list.at (buffer_type)});
629+ override_span_len = std::strcspn (value, " ," );
630+ }
578631 } else if (arg == " -o" || arg == " --output" ) {
579632 if (++i >= argc) {
580633 invalid_param = true ;
@@ -667,6 +720,11 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
667720 params.poll = cmd_params_defaults.poll ;
668721 }
669722
723+ // Attach terminators to options that require them
724+ if (!params.tensor_buft_overrides .empty ()) {
725+ params.tensor_buft_overrides .push_back ({nullptr , nullptr });
726+ }
727+
670728 return params;
671729}
672730
@@ -689,6 +747,7 @@ struct cmd_params_instance {
689747 bool no_kv_offload;
690748 bool flash_attn;
691749 std::vector<float > tensor_split;
750+ std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
692751 bool use_mmap;
693752 bool embeddings;
694753
@@ -733,6 +792,13 @@ struct cmd_params_instance {
733792 mparams.tensor_split = tensor_split.data ();
734793 mparams.use_mmap = use_mmap;
735794
795+ if (tensor_buft_overrides.empty ()) {
796+ mparams.tensor_buft_overrides = nullptr ;
797+ } else {
798+ GGML_ASSERT (tensor_buft_overrides.back ().pattern == nullptr && " Tensor buffer overrides not terminated with empty pattern" );
799+ mparams.tensor_buft_overrides = tensor_buft_overrides.data ();
800+ }
801+
736802 return mparams;
737803 }
738804
@@ -804,6 +870,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
804870 /* .no_kv_offload= */ nkvo,
805871 /* .flash_attn = */ fa,
806872 /* .tensor_split = */ ts,
873+ /* .tensor_buft_overrides = */ params.tensor_buft_overrides ,
807874 /* .use_mmap = */ mmp,
808875 /* .embeddings = */ embd,
809876 };
@@ -833,6 +900,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
833900 /* .no_kv_offload= */ nkvo,
834901 /* .flash_attn = */ fa,
835902 /* .tensor_split = */ ts,
903+ /* .tensor_buft_overrides = */ params.tensor_buft_overrides ,
836904 /* .use_mmap = */ mmp,
837905 /* .embeddings = */ embd,
838906 };
@@ -862,6 +930,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
862930 /* .no_kv_offload= */ nkvo,
863931 /* .flash_attn = */ fa,
864932 /* .tensor_split = */ ts,
933+ /* .tensor_buft_overrides = */ params.tensor_buft_overrides ,
865934 /* .use_mmap = */ mmp,
866935 /* .embeddings = */ embd,
867936 };
0 commit comments