@@ -36,6 +36,46 @@ static uint64_t get_time_ns() {
3636 return std::chrono::nanoseconds (clock::now ().time_since_epoch ()).count ();
3737}
3838
39+ static bool tensor_buft_override_equal (const llama_model_tensor_buft_override& a, const llama_model_tensor_buft_override& b) {
40+ if (a.pattern != b.pattern ) {
41+ // cString comparison that may be null
42+ if (a.pattern == nullptr || b.pattern == nullptr ) {
43+ return false ;
44+ }
45+ if (strcmp (a.pattern , b.pattern ) != 0 ) {
46+ return false ;
47+ }
48+ }
49+ if (a.buft != b.buft ) {
50+ return false ;
51+ }
52+ return true ;
53+ }
54+
55+ static bool vec_tensor_buft_override_equal (const std::vector<llama_model_tensor_buft_override>& a, const std::vector<llama_model_tensor_buft_override>& b) {
56+ if (a.size () != b.size ()) {
57+ return false ;
58+ }
59+ for (size_t i = 0 ; i < a.size (); i++) {
60+ if (!tensor_buft_override_equal (a[i], b[i])) {
61+ return false ;
62+ }
63+ }
64+ return true ;
65+ }
66+
67+ static bool vec_vec_tensor_buft_override_equal (const std::vector<std::vector<llama_model_tensor_buft_override>>& a, const std::vector<std::vector<llama_model_tensor_buft_override>>& b) {
68+ if (a.size () != b.size ()) {
69+ return false ;
70+ }
71+ for (size_t i = 0 ; i < a.size (); i++) {
72+ if (!vec_tensor_buft_override_equal (a[i], b[i])) {
73+ return false ;
74+ }
75+ }
76+ return true ;
77+ }
78+
3979template <class T > static std::string join (const std::vector<T> & values, const std::string & delim) {
4080 std::ostringstream str;
4181 for (size_t i = 0 ; i < values.size (); i++) {
@@ -175,6 +215,7 @@ struct cmd_params {
175215 std::vector<bool > no_kv_offload;
176216 std::vector<bool > flash_attn;
177217 std::vector<std::vector<float >> tensor_split;
218+ std::vector<std::vector<llama_model_tensor_buft_override>> tensor_buft_overrides;
178219 std::vector<bool > use_mmap;
179220 std::vector<bool > embeddings;
180221 ggml_numa_strategy numa;
@@ -207,6 +248,7 @@ static const cmd_params cmd_params_defaults = {
207248 /* no_kv_offload */ { false },
208249 /* flash_attn */ { false },
209250 /* tensor_split */ { std::vector<float >(llama_max_devices (), 0 .0f ) },
251+ /* tensor_buft_overrides*/ { std::vector<llama_model_tensor_buft_override>{{nullptr ,nullptr }} },
210252 /* use_mmap */ { true },
211253 /* embeddings */ { false },
212254 /* numa */ GGML_NUMA_STRATEGY_DISABLED,
@@ -265,6 +307,7 @@ static void print_usage(int /* argc */, char ** argv) {
265307 printf (" -embd, --embeddings <0|1> (default: %s)\n " ,
266308 join (cmd_params_defaults.embeddings , " ," ).c_str ());
267309 printf (" -ts, --tensor-split <ts0/ts1/..> (default: 0)\n " );
310+ printf (" -ot --override-tensors <tensor name pattern>=<buffer type>;... (default: disabled)\n " );
268311 printf (" -r, --repetitions <n> (default: %d)\n " , cmd_params_defaults.reps );
269312 printf (" --prio <0|1|2|3> (default: %d)\n " , cmd_params_defaults.prio );
270313 printf (" --delay <0...N> (seconds) (default: %d)\n " , cmd_params_defaults.delay );
@@ -557,6 +600,87 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
557600 }
558601 params.tensor_split .push_back (tensor_split);
559602 }
603+ } else if (arg == " -ot" || arg == " --override-tensor" ) {
604+ if (++i >= argc) {
605+ invalid_param = true ;
606+ break ;
607+ }
608+ auto value = argv[i];
609+ /* static */ std::map<std::string, ggml_backend_buffer_type_t > buft_list;
610+ if (buft_list.empty ()) {
611+ // enumerate all the devices and add their buffer types to the list
612+ for (size_t i = 0 ; i < ggml_backend_dev_count (); ++i) {
613+ auto * dev = ggml_backend_dev_get (i);
614+ auto * buft = ggml_backend_dev_buffer_type (dev);
615+ if (buft) {
616+ buft_list[ggml_backend_buft_name (buft)] = buft;
617+ }
618+ }
619+ }
620+ auto override_group_span_len = std::strcspn (value, " ," );
621+ bool last_group = false ;
622+ do {
623+ if (override_group_span_len == 0 ) {
624+ // Adds an empty override-tensors for an empty span
625+ params.tensor_buft_overrides .push_back ({{}});
626+ if (value[override_group_span_len] == ' \0 ' ) {
627+ value = &value[override_group_span_len];
628+ last_group = true ;
629+ } else {
630+ value = &value[override_group_span_len + 1 ];
631+ override_group_span_len = std::strcspn (value, " ," );
632+ }
633+ continue ;
634+ }
635+ // Stamps null terminators into the argv
636+ // value for this option to avoid the
637+ // memory leak present in the implementation
638+ // over in arg.cpp. Acceptable because we
639+ // only parse these args once in this program.
640+ auto override_group = value;
641+ if (value[override_group_span_len] == ' \0 ' ) {
642+ value = &value[override_group_span_len];
643+ last_group = true ;
644+ } else {
645+ value[override_group_span_len] = ' \0 ' ;
646+ value = &value[override_group_span_len + 1 ];
647+ }
648+ std::vector<llama_model_tensor_buft_override> group_tensor_buft_overrides{};
649+ auto override_span_len = std::strcspn (override_group, " ;" );
650+ while (override_span_len > 0 ) {
651+ auto override = override_group;
652+ if (override_group[override_span_len] != ' \0 ' ) {
653+ override_group[override_span_len] = ' \0 ' ;
654+ override_group = &override_group[override_span_len + 1 ];
655+ } else {
656+ override_group = &override_group[override_span_len];
657+ }
658+ auto tensor_name_span_len = std::strcspn (override , " =" );
659+ if (tensor_name_span_len >= override_span_len) {
660+ invalid_param = true ;
661+ break ;
662+ }
663+ override [tensor_name_span_len] = ' \0 ' ;
664+ auto tensor_name = override ;
665+ auto buffer_type = &override [tensor_name_span_len + 1 ];
666+ if (buft_list.find (buffer_type) == buft_list.end ()) {
667+ printf (" Available buffer types:\n " );
668+ for (const auto & it : buft_list) {
669+ printf (" %s\n " , ggml_backend_buft_name (it.second ));
670+ }
671+ invalid_param = true ;
672+ break ;
673+ }
674+ group_tensor_buft_overrides.push_back ({tensor_name, buft_list.at (buffer_type)});
675+ override_span_len = std::strcspn (override_group, " ;" );
676+ }
677+ if (invalid_param) {
678+ break ;
679+ }
680+ group_tensor_buft_overrides.push_back ({nullptr ,nullptr });
681+ params.tensor_buft_overrides .push_back (group_tensor_buft_overrides);
682+ override_group_span_len = std::strcspn (value, " ," );
683+ } while (!last_group);
560684 } else if (arg == " -r" || arg == " --repetitions" ) {
561685 if (++i >= argc) {
562686 invalid_param = true ;
@@ -648,6 +772,9 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
648772 if (params.tensor_split .empty ()) {
649773 params.tensor_split = cmd_params_defaults.tensor_split ;
650774 }
775+ if (params.tensor_buft_overrides .empty ()) {
776+ params.tensor_buft_overrides = cmd_params_defaults.tensor_buft_overrides ;
777+ }
651778 if (params.use_mmap .empty ()) {
652779 params.use_mmap = cmd_params_defaults.use_mmap ;
653780 }
@@ -689,6 +816,7 @@ struct cmd_params_instance {
689816 bool no_kv_offload;
690817 bool flash_attn;
691818 std::vector<float > tensor_split;
819+ std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
692820 bool use_mmap;
693821 bool embeddings;
694822
@@ -733,13 +861,20 @@ struct cmd_params_instance {
733861 mparams.tensor_split = tensor_split.data ();
734862 mparams.use_mmap = use_mmap;
735863
864+ if (tensor_buft_overrides.empty ()) {
865+ mparams.tensor_buft_overrides = nullptr ;
866+ } else {
867+ GGML_ASSERT (tensor_buft_overrides.back ().pattern == nullptr && " Tensor buffer overrides not terminated with empty pattern" );
868+ mparams.tensor_buft_overrides = tensor_buft_overrides.data ();
869+ }
870+
736871 return mparams;
737872 }
738873
739874 bool equal_mparams (const cmd_params_instance & other) const {
740875 return model == other.model && n_gpu_layers == other.n_gpu_layers && rpc_servers_str == other.rpc_servers_str &&
741876 split_mode == other.split_mode && main_gpu == other.main_gpu && use_mmap == other.use_mmap &&
742- tensor_split == other.tensor_split ;
877+ tensor_split == other.tensor_split && vec_tensor_buft_override_equal (tensor_buft_overrides, other. tensor_buft_overrides ) ;
743878 }
744879
745880 llama_context_params to_llama_cparams () const {
@@ -769,6 +904,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
769904 for (const auto & sm : params.split_mode )
770905 for (const auto & mg : params.main_gpu )
771906 for (const auto & ts : params.tensor_split )
907+ for (const auto & ot : params.tensor_buft_overrides )
772908 for (const auto & mmp : params.use_mmap )
773909 for (const auto & embd : params.embeddings )
774910 for (const auto & nb : params.n_batch )
@@ -804,6 +940,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
804940 /* .no_kv_offload= */ nkvo,
805941 /* .flash_attn = */ fa,
806942 /* .tensor_split = */ ts,
943+ /* .tensor_buft_overrides = */ ot,
807944 /* .use_mmap = */ mmp,
808945 /* .embeddings = */ embd,
809946 };
@@ -833,6 +970,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
833970 /* .no_kv_offload= */ nkvo,
834971 /* .flash_attn = */ fa,
835972 /* .tensor_split = */ ts,
973+ /* .tensor_buft_overrides = */ ot,
836974 /* .use_mmap = */ mmp,
837975 /* .embeddings = */ embd,
838976 };
@@ -862,6 +1000,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
8621000 /* .no_kv_offload= */ nkvo,
8631001 /* .flash_attn = */ fa,
8641002 /* .tensor_split = */ ts,
1003+ /* .tensor_buft_overrides = */ ot,
8651004 /* .use_mmap = */ mmp,
8661005 /* .embeddings = */ embd,
8671006 };
@@ -896,6 +1035,7 @@ struct test {
8961035 bool no_kv_offload;
8971036 bool flash_attn;
8981037 std::vector<float > tensor_split;
1038+ std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
8991039 bool use_mmap;
9001040 bool embeddings;
9011041 int n_prompt;
@@ -927,6 +1067,7 @@ struct test {
9271067 no_kv_offload = inst.no_kv_offload ;
9281068 flash_attn = inst.flash_attn ;
9291069 tensor_split = inst.tensor_split ;
1070+ tensor_buft_overrides = inst.tensor_buft_overrides ;
9301071 use_mmap = inst.use_mmap ;
9311072 embeddings = inst.embeddings ;
9321073 n_prompt = inst.n_prompt ;
@@ -972,9 +1113,9 @@ struct test {
9721113 " build_commit" , " build_number" , " cpu_info" , " gpu_info" , " backends" , " model_filename" ,
9731114 " model_type" , " model_size" , " model_n_params" , " n_batch" , " n_ubatch" , " n_threads" ,
9741115 " cpu_mask" , " cpu_strict" , " poll" , " type_k" , " type_v" , " n_gpu_layers" ,
975- " split_mode" , " main_gpu" , " no_kv_offload" , " flash_attn" , " tensor_split" , " use_mmap " ,
976- " embeddings " , " n_prompt " , " n_gen " , " test_time " , " avg_ns " , " stddev_ns " ,
977- " avg_ts" , " stddev_ts" ,
1116+ " split_mode" , " main_gpu" , " no_kv_offload" , " flash_attn" , " tensor_split" , " tensor_buft_overrides " ,
1117+ " use_mmap " , " embeddings " , " n_prompt " , " n_gen " , " test_time " , " avg_ns " ,
1118+ " stddev_ns " , " avg_ts" , " stddev_ts" ,
9781119 };
9791120 return fields;
9801121 }
@@ -1000,6 +1141,7 @@ struct test {
10001141
10011142 std::vector<std::string> get_values () const {
10021143 std::string tensor_split_str;
1144+ std::string tensor_buft_overrides_str;
10031145 int max_nonzero = 0 ;
10041146 for (size_t i = 0 ; i < llama_max_devices (); i++) {
10051147 if (tensor_split[i] > 0 ) {
@@ -1014,6 +1156,26 @@ struct test {
10141156 tensor_split_str += " /" ;
10151157 }
10161158 }
1159+ if (tensor_buft_overrides.size () == 1 ) {
1160+ // Last element of tensor_buft_overrides is always a null pattern
1161+ // so if it is only one element long, it must be a null pattern.
1162+ GGML_ASSERT (tensor_buft_overrides[0 ].pattern == nullptr );
1163+ tensor_buft_overrides_str += " none" ;
1164+ } else {
1165+ for (size_t i = 0 ; i < tensor_buft_overrides.size ()-1 ; i++) {
1166+ // Last element of tensor_buft_overrides is always a null pattern
1167+ if (tensor_buft_overrides[i].pattern == nullptr ) {
1168+ tensor_buft_overrides_str += " none" ;
1169+ } else {
1170+ tensor_buft_overrides_str += tensor_buft_overrides[i].pattern ;
1171+ tensor_buft_overrides_str += " =" ;
1172+ tensor_buft_overrides_str += ggml_backend_buft_name (tensor_buft_overrides[i].buft );
1173+ }
1174+ if (i + 2 < tensor_buft_overrides.size ()) {
1175+ tensor_buft_overrides_str += " ;" ;
1176+ }
1177+ }
1178+ }
10171179 std::vector<std::string> values = { build_commit,
10181180 std::to_string (build_number),
10191181 cpu_info,
@@ -1037,6 +1199,7 @@ struct test {
10371199 std::to_string (no_kv_offload),
10381200 std::to_string (flash_attn),
10391201 tensor_split_str,
1202+ tensor_buft_overrides_str,
10401203 std::to_string (use_mmap),
10411204 std::to_string (embeddings),
10421205 std::to_string (n_prompt),
@@ -1254,6 +1417,9 @@ struct markdown_printer : public printer {
12541417 if (field == " tensor_split" ) {
12551418 return " ts" ;
12561419 }
1420+ if (field == " tensor_buft_overrides" ) {
1421+ return " ot" ;
1422+ }
12571423 return field;
12581424 }
12591425
@@ -1307,6 +1473,9 @@ struct markdown_printer : public printer {
13071473 if (params.tensor_split .size () > 1 || params.tensor_split != cmd_params_defaults.tensor_split ) {
13081474 fields.emplace_back (" tensor_split" );
13091475 }
1476+ if (params.tensor_buft_overrides .size () > 1 || !vec_vec_tensor_buft_override_equal (params.tensor_buft_overrides , cmd_params_defaults.tensor_buft_overrides )) {
1477+ fields.emplace_back (" tensor_buft_overrides" );
1478+ }
13101479 if (params.use_mmap .size () > 1 || params.use_mmap != cmd_params_defaults.use_mmap ) {
13111480 fields.emplace_back (" use_mmap" );
13121481 }
0 commit comments