
Commit 803ee23

Merge branch 'ggml-org:master' into mradermacher
2 parents d43c493 + 5fce5f9 commit 803ee23

23 files changed, +682 -52 lines

common/chat-parser.cpp

Lines changed: 5 additions & 0 deletions
@@ -49,6 +49,7 @@ bool common_chat_msg_parser::add_tool_call(const std::string & name, const std::
 
     // LOG_DBG("Tool call arguments:\n\traw: %s\n\tresult: %s\n", arguments.c_str(), tool_call.arguments.c_str());
     result_.tool_calls.emplace_back(tool_call);
+
     return true;
 }
 bool common_chat_msg_parser::add_tool_call(const json & tool_call) {
@@ -378,3 +379,7 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
         /* .is_partial = */ found_healing_marker,
     };
 }
+
+void common_chat_msg_parser::clear_tools() {
+    result_.tool_calls.clear();
+}

common/chat-parser.h

Lines changed: 2 additions & 0 deletions
@@ -115,4 +115,6 @@ class common_chat_msg_parser {
         const std::vector<std::vector<std::string>> & args_paths = {},
         const std::vector<std::vector<std::string>> & content_paths = {}
     );
+
+    void clear_tools();
 };

common/chat.cpp

Lines changed: 3 additions & 1 deletion
@@ -1921,7 +1921,9 @@ common_chat_msg common_chat_parse(const std::string & input, bool is_partial, co
     } catch (const common_chat_msg_partial_exception & ex) {
         LOG_DBG("Partial parse: %s\n", ex.what());
         if (!is_partial) {
-            throw std::runtime_error(ex.what());
+            builder.clear_tools();
+            builder.move_to(0);
+            common_chat_parse_content_only(builder);
         }
     }
     auto msg = builder.result();
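Taken together with the clear_tools() helper added above, this chat.cpp change means a fully received (non-partial) message whose tool-call syntax fails to parse is no longer rejected with an exception: any partially collected tool calls are discarded, the parser rewinds to the start of the input, and the message is re-parsed as plain content. Below is a rough Python sketch of that control flow; it is purely illustrative, not the project's actual API, and the helper names only loosely mirror the C++ above.

# Illustrative sketch of the new fallback behaviour in common_chat_parse (not the real C++).
class MessageBuilder:
    def __init__(self, text: str):
        self.text = text
        self.pos = 0
        self.tool_calls: list[dict] = []
        self.content = ""

    def clear_tools(self) -> None:          # mirrors common_chat_msg_parser::clear_tools()
        self.tool_calls.clear()

    def move_to(self, pos: int) -> None:
        self.pos = pos


def parse_message(text: str, is_partial: bool) -> MessageBuilder:
    builder = MessageBuilder(text)
    try:
        # Stand-in for the format-specific tool-call parsing, which raises on malformed syntax.
        raise ValueError("tool-call syntax did not parse")
    except ValueError:
        if not is_partial:
            # Old behaviour: re-raise and fail the request.
            # New behaviour: drop any partially collected tool calls, rewind to the
            # start, and treat the whole message as plain content (the stand-in for
            # common_chat_parse_content_only below).
            builder.clear_tools()
            builder.move_to(0)
            builder.content = builder.text
    return builder


print(parse_message("<tool_call>{broken json", is_partial=False).content)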

convert_hf_to_gguf.py

Lines changed: 28 additions & 0 deletions
@@ -5262,6 +5262,34 @@ def prepare_tensors(self):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@ModelBase.register("Dots1ForCausalLM")
+class Dots1Model(Qwen2MoeModel):
+    model_arch = gguf.MODEL_ARCH.DOTS1
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.hparams["num_experts"] = self.hparams["n_routed_experts"]
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_leading_dense_block_count(self.hparams["first_k_dense_replace"])
+        self.gguf_writer.add_expert_shared_count(self.hparams["n_shared_experts"])
+        self.gguf_writer.add_expert_weights_scale(self.hparams["routed_scaling_factor"])
+        self.gguf_writer.add_expert_weights_norm(self.hparams["norm_topk_prob"])
+
+        if self.hparams["scoring_func"] == "noaux_tc":
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+        else:
+            raise ValueError(f"Unsupported scoring_func value: {self.hparams['scoring_func']}")
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
+        if name.endswith("e_score_correction_bias"):
+            name = name.replace("e_score_correction_bias", "e_score_correction.bias")
+        if "shared_experts" in name:
+            return [(self.map_tensor_name(name), data_torch)]
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("PLMForCausalLM")
 class PLMModel(TextModel):
     model_arch = gguf.MODEL_ARCH.PLM
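The only tensor-name adjustment Dots1Model makes before deferring to the Qwen2MoeModel logic is renaming the expert-gating bias so it lines up with the GGUF tensor mapping. A minimal standalone sketch of that rename follows, using a hypothetical checkpoint tensor name for illustration:

# Standalone illustration of the rename done in Dots1Model.modify_tensors.
# The example tensor name is hypothetical; real names come from the HF checkpoint.
def rename_gate_bias(name: str) -> str:
    if name.endswith("e_score_correction_bias"):
        name = name.replace("e_score_correction_bias", "e_score_correction.bias")
    return name

print(rename_gate_bias("model.layers.3.mlp.gate.e_score_correction_bias"))
# -> model.layers.3.mlp.gate.e_score_correction.bias, i.e. the base name
#    "model.layers.{bid}.mlp.gate.e_score_correction" registered for
#    MODEL_TENSOR.FFN_EXP_PROBS_B in gguf-py/gguf/tensor_mapping.py (see below)
#    plus a ".bias" suffix, which is how the mapping machinery matches it.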

docs/function-calling.md

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ Function calling is supported for all models (see https://github.com/ggml-org/ll
 - Llama 3.1 / 3.3 (including builtin tools support - tool names for `wolfram_alpha`, `web_search` / `brave_search`, `code_interpreter`), Llama 3.2
 - Functionary v3.1 / v3.2
 - Hermes 2/3, Qwen 2.5
-- Qwen 2.5 Coder (WIP: https://github.com/ggml-org/llama.cpp/pull/12034)
+- Qwen 2.5 Coder
 - Mistral Nemo
 - Firefunction v2
 - Command R7B

gguf-py/gguf/constants.py

Lines changed: 26 additions & 0 deletions
@@ -343,6 +343,7 @@ class MODEL_ARCH(IntEnum):
     WAVTOKENIZER_DEC = auto()
     PLM              = auto()
     BAILINGMOE       = auto()
+    DOTS1            = auto()
 
 
 class VISION_PROJECTOR_TYPE(IntEnum):
@@ -623,6 +624,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
     MODEL_ARCH.PLM:              "plm",
     MODEL_ARCH.BAILINGMOE:       "bailingmoe",
+    MODEL_ARCH.DOTS1:            "dots1"
 }
 
 VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -2044,6 +2046,30 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN_SHEXP,
         MODEL_TENSOR.FFN_UP_SHEXP,
     ],
+    MODEL_ARCH.DOTS1: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_EXP_PROBS_B,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+    ],
     # TODO
 }
 
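As a quick way to see what these registrations buy, here is a small hedged check. It assumes the gguf-py package from this tree is importable and that the two dictionaries being extended above are MODEL_ARCH_NAMES and MODEL_TENSORS, the names they carry in constants.py:

# Sanity check of the new DOTS1 registrations (run with gguf-py on PYTHONPATH).
import gguf

# The architecture string written into GGUF metadata for converted dots1 models:
print(gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.DOTS1])                                    # dots1
# The expert-gating bias tensor is part of the allowed tensor set for this arch:
print(gguf.MODEL_TENSOR.FFN_EXP_PROBS_B in gguf.MODEL_TENSORS[gguf.MODEL_ARCH.DOTS1])  # True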

gguf-py/gguf/tensor_mapping.py

Lines changed: 1 addition & 1 deletion
@@ -305,7 +305,7 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.FFN_EXP_PROBS_B: (
-            "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3
+            "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1
         ),
 
         # Feed-forward up

include/llama.h

Lines changed: 2 additions & 2 deletions
@@ -243,14 +243,14 @@ extern "C" {
 
     typedef bool (*llama_progress_callback)(float progress, void * user_data);
 
-    // Input data for llama_decode
+    // Input data for llama_encode/llama_decode
     // A llama_batch object can contain input about one or many sequences
     // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
     //
     // - token  : the token ids of the input (used when embd is NULL)
     // - embd   : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
     // - pos    : the positions of the respective token in the sequence
-    //            (if set to NULL, the token position will be tracked automatically by llama_decode)
+    //            (if set to NULL, the token position will be tracked automatically by llama_encode/llama_decode)
     // - seq_id : the sequence to which the respective token belongs
     //            (if set to NULL, the sequence ID will be assumed to be 0)
     // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
requirements/requirements-compare-llama-bench.txt

Lines changed: 1 addition & 0 deletions
@@ -1,2 +1,3 @@
 tabulate~=0.9.0
 GitPython~=3.1.43
+matplotlib~=3.10.0

scripts/compare-llama-bench.py

Lines changed: 168 additions & 1 deletion
@@ -19,6 +19,7 @@
     print("the following Python libraries are required: GitPython, tabulate.")  # noqa: NP100
     raise e
 
+
 logger = logging.getLogger("compare-llama-bench")
 
 # All llama-bench SQL fields
@@ -122,11 +123,15 @@
 parser.add_argument("--check", action="store_true", help="check if all required Python libraries are installed")
 parser.add_argument("-s", "--show", help=help_s)
 parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
+parser.add_argument("--plot", help="generate a performance comparison plot and save to specified file (e.g., plot.png)")
+parser.add_argument("--plot_x", help="parameter to use as x axis for plotting (default: n_depth)", default="n_depth")
+parser.add_argument("--plot_log_scale", action="store_true", help="use log scale for x axis in plots (off by default)")
 
 known_args, unknown_args = parser.parse_known_args()
 
 logging.basicConfig(level=logging.DEBUG if known_args.verbose else logging.INFO)
 
+
 if known_args.check:
     # Check if all required Python libraries are installed. Would have failed earlier if not.
     sys.exit(0)
@@ -499,7 +504,6 @@ def valid_format(data_files: list[str]) -> bool:
 
 name_compare = bench_data.get_commit_name(hexsha8_compare)
 
-
 # If the user provided columns to group the results by, use them:
 if known_args.show is not None:
     show = known_args.show.split(",")
@@ -544,6 +548,14 @@ def valid_format(data_files: list[str]) -> bool:
             show.remove(prop)
         except ValueError:
             pass
+
+    # Add plot_x parameter to parameters to show if it's not already present:
+    if known_args.plot:
+        for k, v in PRETTY_NAMES.items():
+            if v == known_args.plot_x and k not in show:
+                show.append(k)
+                break
+
     rows_show = bench_data.get_rows(show, hexsha8_baseline, hexsha8_compare)
 
 if not rows_show:
@@ -600,6 +612,161 @@ def valid_format(data_files: list[str]) -> bool:
 headers = [PRETTY_NAMES[p] for p in show]
 headers += ["Test", f"t/s {name_baseline}", f"t/s {name_compare}", "Speedup"]
 
+if known_args.plot:
+    def create_performance_plot(table_data: list[list[str]], headers: list[str], baseline_name: str, compare_name: str, output_file: str, plot_x_param: str, log_scale: bool = False):
+        try:
+            import matplotlib.pyplot as plt
+            import matplotlib
+            matplotlib.use('Agg')
+        except ImportError as e:
+            logger.error("matplotlib is required for --plot.")
+            raise e
+
+        data_headers = headers[:-4]  # Exclude the last 4 columns (Test, baseline t/s, compare t/s, Speedup)
+        plot_x_index = None
+        plot_x_label = plot_x_param
+
+        if plot_x_param not in ["n_prompt", "n_gen", "n_depth"]:
+            pretty_name = PRETTY_NAMES.get(plot_x_param, plot_x_param)
+            if pretty_name in data_headers:
+                plot_x_index = data_headers.index(pretty_name)
+                plot_x_label = pretty_name
+            elif plot_x_param in data_headers:
+                plot_x_index = data_headers.index(plot_x_param)
+                plot_x_label = plot_x_param
+            else:
+                logger.error(f"Parameter '{plot_x_param}' not found in current table columns. Available columns: {', '.join(data_headers)}")
+                return
+
+        grouped_data = {}
+
+        for i, row in enumerate(table_data):
+            group_key_parts = []
+            test_name = row[-4]
+
+            base_test = ""
+            x_value = None
+
+            if plot_x_param in ["n_prompt", "n_gen", "n_depth"]:
+                for j, val in enumerate(row[:-4]):
+                    header_name = data_headers[j]
+                    if val is not None and str(val).strip():
+                        group_key_parts.append(f"{header_name}={val}")
+
+                if plot_x_param == "n_prompt" and "pp" in test_name:
+                    base_test = test_name.split("@")[0]
+                    x_value = base_test
+                elif plot_x_param == "n_gen" and "tg" in test_name:
+                    x_value = test_name.split("@")[0]
+                elif plot_x_param == "n_depth" and "@d" in test_name:
+                    base_test = test_name.split("@d")[0]
+                    x_value = int(test_name.split("@d")[1])
+                else:
+                    base_test = test_name
+
+                if base_test.strip():
+                    group_key_parts.append(f"Test={base_test}")
+            else:
+                for j, val in enumerate(row[:-4]):
+                    if j != plot_x_index:
+                        header_name = data_headers[j]
+                        if val is not None and str(val).strip():
+                            group_key_parts.append(f"{header_name}={val}")
+                    else:
+                        x_value = val
+
+                group_key_parts.append(f"Test={test_name}")
+
+            group_key = tuple(group_key_parts)
+
+            if group_key not in grouped_data:
+                grouped_data[group_key] = []
+
+            grouped_data[group_key].append({
+                'x_value': x_value,
+                'baseline': float(row[-3]),
+                'compare': float(row[-2]),
+                'speedup': float(row[-1])
+            })
+
+        if not grouped_data:
+            logger.error("No data available for plotting")
+            return
+
+        def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
+            from math import ceil
+            cols = 1 if num_groups == 1 else min(max_cols, num_groups)
+            rows = ceil(num_groups / cols)
+
+            # Scale figure size by grid dimensions
+            w, h = base_size
+            fig, ax_arr = plt.subplots(rows, cols,
+                                       figsize=(w * cols, h * rows),
+                                       squeeze=False)
+
+            axes = ax_arr.flatten()[:num_groups]
+            return fig, axes
+
+        num_groups = len(grouped_data)
+        fig, axes = make_axes(num_groups)
+
+        plot_idx = 0
+
+        for group_key, points in grouped_data.items():
+            if plot_idx >= len(axes):
+                break
+            ax = axes[plot_idx]
+
+            try:
+                points_sorted = sorted(points, key=lambda p: float(p['x_value']) if p['x_value'] is not None else 0)
+                x_values = [float(p['x_value']) if p['x_value'] is not None else 0 for p in points_sorted]
+            except ValueError:
+                points_sorted = sorted(points, key=lambda p: group_key)
+                x_values = [p['x_value'] for p in points_sorted]
+
+            baseline_vals = [p['baseline'] for p in points_sorted]
+            compare_vals = [p['compare'] for p in points_sorted]
+
+            ax.plot(x_values, baseline_vals, 'o-', color='skyblue',
+                    label=f'{baseline_name}', linewidth=2, markersize=6)
+            ax.plot(x_values, compare_vals, 's--', color='lightcoral', alpha=0.8,
+                    label=f'{compare_name}', linewidth=2, markersize=6)
+
+            if log_scale:
+                ax.set_xscale('log', base=2)
+                unique_x = sorted(set(x_values))
+                ax.set_xticks(unique_x)
+                ax.set_xticklabels([str(int(x)) for x in unique_x])
+
+            title_parts = []
+            for part in group_key:
+                if '=' in part:
+                    key, value = part.split('=', 1)
+                    title_parts.append(f"{key}: {value}")
+
+            title = ', '.join(title_parts) if title_parts else "Performance comparison"
+
+            ax.set_xlabel(plot_x_label, fontsize=12, fontweight='bold')
+            ax.set_ylabel('Tokens per second (t/s)', fontsize=12, fontweight='bold')
+            ax.set_title(title, fontsize=12, fontweight='bold')
+            ax.legend(loc='best', fontsize=10)
+            ax.grid(True, alpha=0.3)
+
+            plot_idx += 1
+
+        for i in range(plot_idx, len(axes)):
+            axes[i].set_visible(False)
+
+        fig.suptitle(f'Performance comparison: {compare_name} vs. {baseline_name}',
+                     fontsize=14, fontweight='bold')
+        fig.subplots_adjust(top=1)
+
+        plt.tight_layout()
+        plt.savefig(output_file, dpi=300, bbox_inches='tight')
+        plt.close()
+
+    create_performance_plot(table, headers, name_baseline, name_compare, known_args.plot, known_args.plot_x, known_args.plot_log_scale)
+
 print(tabulate(  # noqa: NP100
     table,
     headers=headers,
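The bulk of the new code is create_performance_plot. Its core idea: each table row becomes one data point, and points are bucketed into subplots keyed by every displayed parameter except the x-axis one (when plotting against n_depth, the test name also has its @d<depth> suffix stripped, so pp512@d1024 and pp512@d2048 land in the same subplot). Below is a toy, self-contained sketch of just that grouping step, using made-up rows in place of the real llama-bench table:

# Illustrative only: how the plotting code partitions table rows into subplots.
# Column names and row values here are invented; the real table comes from llama-bench results.
rows = [
    # (model, GPU layers, test, t/s baseline, t/s compare, speedup)
    ["llama 7B Q4_0", "99", "pp512@d1024", "1200.0", "1320.0", "1.10"],
    ["llama 7B Q4_0", "99", "pp512@d2048", "1100.0", "1250.0", "1.14"],
    ["llama 7B Q8_0", "99", "pp512@d1024", "900.0", "950.0", "1.06"],
]
headers = ["Model", "GPU layers", "Test", "t/s base", "t/s cmp", "Speedup"]

groups: dict[tuple, list[tuple[int, float, float]]] = {}
for row in rows:
    params, test = row[:-4], row[-4]
    depth = int(test.split("@d")[1])          # x value when plotting against n_depth
    key = tuple(f"{h}={v}" for h, v in zip(headers[:-4], params)) + (f"Test={test.split('@d')[0]}",)
    groups.setdefault(key, []).append((depth, float(row[-3]), float(row[-2])))

for key, points in groups.items():
    print(key, sorted(points))                # one subplot per key in the real script

From the command line the feature is driven by the new flags shown above, e.g. something like scripts/compare-llama-bench.py --plot plot.png --plot_x n_depth --plot_log_scale on top of the script's usual baseline/compare selection; matplotlib~=3.10.0 is the matching new requirement added above.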
