Skip to content

Commit 2ed6850

Browse files
committed
added override tensor
1 parent 17360a3 commit 2ed6850

File tree

3 files changed

+59
-3
lines changed

3 files changed

+59
-3
lines changed

expose.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ struct load_model_inputs
6262
const int moe_experts = -1;
6363
const bool no_bos_token = false;
6464
const char * override_kv = nullptr;
65+
const char * override_tensors = nullptr;
6566
const bool flash_attention = false;
6667
const float tensor_split[tensor_split_max] = {};
6768
const int quant_k = 0;

gpttype_adapter.cpp

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2172,6 +2172,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
21722172
}
21732173

21742174
std::vector<llama_model_kv_override> kvos; //ensure it keeps in scope until model is created
2175+
std::vector<llama_model_tensor_buft_override> tenos; //ensure it keeps in scope until model is created
2176+
std::vector<std::string> temp_tensor_names; //store temp tensor names to have mem references.
21752177
if(inputs.moe_experts>0)
21762178
{
21772179
printf("\nOverriding number of experts to %d\n",inputs.moe_experts);
@@ -2195,13 +2197,58 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
21952197
{
21962198
printf("\nAttempting to apply KV override: %s...\n",override_kv.c_str());
21972199
bool kvo_ok = string_parse_kv_override(override_kv.c_str(),kvos);
2198-
LLAMA_LOG_INFO("\nKV override result: %s\n",(kvo_ok?"success":"failed"));
2200+
LLAMA_LOG_INFO("\nKV override parse: %s\n",(kvo_ok?"success":"failed"));
21992201
fflush(stdout);
22002202
}
22012203
if(kvos.size()>0)
22022204
{
2205+
kvos.emplace_back();
2206+
kvos.back().key[0] = 0;
22032207
model_params.kv_overrides = kvos.data();
22042208
}
2209+
//handle override tensor
2210+
std::string tensoroverrides = inputs.override_tensors;
2211+
if(tensoroverrides!="" && ggml_backend_dev_count()>1)
2212+
{
2213+
printf("Handling Override Tensors for backends: ");
2214+
std::map<std::string, ggml_backend_buffer_type_t> buft_list;
2215+
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
2216+
auto * dev = ggml_backend_dev_get(i);
2217+
auto * buft = ggml_backend_dev_buffer_type(dev);
2218+
if (buft) {
2219+
std::string name = ggml_backend_buft_name(buft);
2220+
printf("%s ", name.c_str());
2221+
buft_list[name] = buft;
2222+
}
2223+
}
2224+
printf("\n\n");
2225+
for (const auto & override : string_split<std::string>(tensoroverrides, ',')) {
2226+
std::string::size_type pos = override.find('=');
2227+
if (pos == std::string::npos) {
2228+
printf("\nInvalid Override Tensor: %s\n",override.c_str());
2229+
continue;
2230+
}
2231+
std::string tensor_name = override.substr(0, pos);
2232+
std::string buffer_type = override.substr(pos + 1);
2233+
2234+
if (buft_list.find(buffer_type) == buft_list.end()) {
2235+
printf("\nUnknown Buffer Type: %s\n",buffer_type.c_str());
2236+
continue;
2237+
}
2238+
llama_model_tensor_buft_override nto;
2239+
temp_tensor_names.push_back(tensor_name);
2240+
nto.pattern = temp_tensor_names[temp_tensor_names.size()-1].c_str();
2241+
nto.buft = buft_list.at(buffer_type);
2242+
tenos.push_back(nto);
2243+
printf("Override Tensor: %s to %s\n",tensor_name.c_str(),buffer_type.c_str());
2244+
}
2245+
}
2246+
if(tenos.size()>0)
2247+
{
2248+
tenos.push_back({nullptr, nullptr});
2249+
model_params.tensor_buft_overrides = tenos.data();
2250+
}
2251+
22052252
llama_model * llamamodel = llama_model_load_from_file(kcpp_data->model_filename.c_str(), model_params);
22062253

22072254
if(overwriteRope)

koboldcpp.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ class load_model_inputs(ctypes.Structure):
181181
("moe_experts", ctypes.c_int),
182182
("no_bos_token", ctypes.c_bool),
183183
("override_kv", ctypes.c_char_p),
184+
("override_tensors", ctypes.c_char_p),
184185
("flash_attention", ctypes.c_bool),
185186
("tensor_split", ctypes.c_float * tensor_split_max),
186187
("quant_k", ctypes.c_int),
@@ -1214,6 +1215,7 @@ def load_model(model_filename):
12141215
inputs.moe_experts = args.moeexperts
12151216
inputs.no_bos_token = args.nobostoken
12161217
inputs.override_kv = args.overridekv.encode("UTF-8") if args.overridekv else "".encode("UTF-8")
1218+
inputs.override_tensors = args.overridetensors.encode("UTF-8") if args.overridetensors else "".encode("UTF-8")
12171219
inputs = set_backend_props(inputs)
12181220
ret = handle.load_model(inputs)
12191221
return ret
@@ -3868,6 +3870,7 @@ def hide_tooltip(event):
38683870
defaultgenamt_var = ctk.StringVar(value=str(512))
38693871
nobostoken_var = ctk.IntVar(value=0)
38703872
override_kv_var = ctk.StringVar(value="")
3873+
override_tensors_var = ctk.StringVar(value="")
38713874

38723875
model_var = ctk.StringVar()
38733876
lora_var = ctk.StringVar()
@@ -4393,8 +4396,9 @@ def togglerope(a,b,c):
43934396
qkvslider,qkvlabel,qkvtitle = makeslider(tokens_tab, "Quantize KV Cache:", quantkv_text, quantkv_var, 0, 2, 30, set=0,tooltip="Enable quantization of KV cache.\nRequires FlashAttention for full effect, otherwise only K cache is quantized.")
43944397
quantkv_var.trace("w", toggleflashattn)
43954398
makecheckbox(tokens_tab, "No BOS Token", nobostoken_var, 43, tooltiptxt="Prevents BOS token from being added at the start of any prompt. Usually NOT recommended for most models.")
4396-
makelabelentry(tokens_tab, "MoE Experts:", moeexperts_var, row=45, padx=100, singleline=True, tooltip="Override number of MoE experts.")
4397-
makelabelentry(tokens_tab, "Override KV:", override_kv_var, row=47, padx=100, singleline=True, width=150, tooltip="Advanced option to override model metadata by key, same as in llama.cpp. Mainly for debugging, not intended for general use. Types: int, float, bool, str")
4399+
makelabelentry(tokens_tab, "MoE Experts:", moeexperts_var, row=45, padx=120, singleline=True, tooltip="Override number of MoE experts.")
4400+
makelabelentry(tokens_tab, "Override KV:", override_kv_var, row=47, padx=120, singleline=True, width=150, tooltip="Advanced option to override model metadata by key, same as in llama.cpp. Mainly for debugging, not intended for general use. Types: int, float, bool, str")
4401+
makelabelentry(tokens_tab, "Override Tensors:", override_tensors_var, row=49, padx=120, singleline=True, width=150, tooltip="Advanced option to override tensor backend selection, same as in llama.cpp.")
43984402

43994403
# Model Tab
44004404
model_tab = tabcontent["Loaded Files"]
@@ -4667,6 +4671,7 @@ def export_vars():
46674671
args.defaultgenamt = int(defaultgenamt_var.get()) if defaultgenamt_var.get()!="" else 512
46684672
args.nobostoken = (nobostoken_var.get()==1)
46694673
args.overridekv = None if override_kv_var.get() == "" else override_kv_var.get()
4674+
args.overridetensors = None if override_tensors_var.get() == "" else override_tensors_var.get()
46704675
args.chatcompletionsadapter = None if chatcompletionsadapter_var.get() == "" else chatcompletionsadapter_var.get()
46714676
try:
46724677
if kcpp_exporting_template and isinstance(args.chatcompletionsadapter, str) and args.chatcompletionsadapter!="" and os.path.exists(args.chatcompletionsadapter):
@@ -4861,6 +4866,8 @@ def import_vars(dict):
48614866
nobostoken_var.set(dict["nobostoken"] if ("nobostoken" in dict) else 0)
48624867
if "overridekv" in dict and dict["overridekv"]:
48634868
override_kv_var.set(dict["overridekv"])
4869+
if "overridetensors" in dict and dict["overridetensors"]:
4870+
override_tensors_var.set(dict["overridetensors"])
48644871

48654872
if "blasbatchsize" in dict and dict["blasbatchsize"]:
48664873
blas_size_var.set(blasbatchsize_values.index(str(dict["blasbatchsize"])))
@@ -6588,6 +6595,7 @@ def range_checker(arg: str):
65886595
advparser.add_argument("--nobostoken", help="Prevents BOS token from being added at the start of any prompt. Usually NOT recommended for most models.", action='store_true')
65896596
advparser.add_argument("--maxrequestsize", metavar=('[size in MB]'), help="Specify a max request payload size. Any requests to the server larger than this size will be dropped. Do not change if unsure.", type=int, default=32)
65906597
advparser.add_argument("--overridekv", metavar=('[name=type:value]'), help="Advanced option to override a metadata by key, same as in llama.cpp. Mainly for debugging, not intended for general use. Types: int, float, bool, str", default="")
6598+
advparser.add_argument("--overridetensors", metavar=('[tensor name pattern=buffer type]'), help="Advanced option to override tensor backend selection, same as in llama.cpp.", default="")
65916599
compatgroup2 = parser.add_mutually_exclusive_group()
65926600
compatgroup2.add_argument("--showgui", help="Always show the GUI instead of launching the model right away when loading settings from a .kcpps file.", action='store_true')
65936601
compatgroup2.add_argument("--skiplauncher", help="Doesn't display or use the GUI launcher.", action='store_true')

0 commit comments

Comments
 (0)