
Commit 34fcc5a

pwilkin, ggerganov, JohannesGaessler, and CISC authored
model : Apertus model implementation (#15852)
* First attempt
* No permute during convert (fixes qk tensors), proper norm application.
* RoPE = NeoX
* Coherence!
* Migrate xielu params from tensors to hyperparameters
* Simple CUDA kernel
* Revert stupid LLM refactorings
* Chat template support
* configchecker / flake8 errors
* Reorder unary.cu
* I do conclude that LLMs are, in fact, stupid.
* Fix after merge
* Final newline
* Make xIELU an UNARY_OP
* Final newline
* Correctly account for parameter shift
* Argh.
* Update ggml/src/ggml-cpu/unary-ops.cpp

Co-authored-by: Georgi Gerganov <[email protected]>

* Refactor: remove unused methods, inline and factorize softplus, add const modifiers
* Revert CUDA changes, implement xIELU as a separate OP
* Pesky newline
* Add float2half / half2float for F16 inputs/outputs
* CUDA variants, attempt 2
* Actually, attempt 3
* Update ggml/src/ggml-cuda/unary.cu

Co-authored-by: Johannes Gäßler <[email protected]>

* Missing convert header
* Proper formula and reference for xIELU in the comments.
* Modify unary-ops.cpp to add the functor-based logic besides the template system to retain optimizations
* Apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Add tensor mappings for Apertus to global list instead
* Fix lazy on scalars
* Update ggml/src/ggml-cuda/unary.cu

Co-authored-by: Johannes Gäßler <[email protected]>

* Add comment about the constraints on positive/negative alpha
* Change `softplus` to `ggml_softplus`

---------

Co-authored-by: Georgi Gerganov <[email protected]>
Co-authored-by: Johannes Gäßler <[email protected]>
Co-authored-by: Sigbjørn Skjæret <[email protected]>
1 parent 91a2a56 commit 34fcc5a

27 files changed: +1082 / -7 lines

common/chat-parser.cpp

Lines changed: 29 additions & 0 deletions
@@ -75,6 +75,35 @@ bool common_chat_msg_parser::add_tool_calls(const json & arr) {
     }
     return true;
 }
+
+bool common_chat_msg_parser::add_tool_call_short_form(const json & tool_call) {
+    if (!tool_call.is_object() || tool_call.size() != 1) {
+        return false;
+    }
+
+    // Get the tool name (the single key in the object)
+    auto it = tool_call.begin();
+    std::string name = it.key();
+
+    if (name.empty()) {
+        return false;
+    }
+
+    // Get the arguments (the nested object)
+    const json & args_json = it.value();
+    std::string arguments = "";
+
+    if (args_json.is_object()) {
+        arguments = args_json.dump();
+    } else if (args_json.is_string()) {
+        arguments = args_json;
+    } else if (!args_json.is_null()) {
+        // For other types, convert to string representation
+        arguments = args_json.dump();
+    }
+
+    return add_tool_call(name, "", arguments);
+}
 void common_chat_msg_parser::finish() {
     if (!is_partial_ && pos_ != input_.size()) {
         throw std::runtime_error("Unexpected content at end of input");// + input_.substr(pos_));
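
For reference, the short form accepted by add_tool_call_short_form() wraps the arguments object under a single key that names the tool, e.g. {"get_weather": {"city": "Bern"}}. A minimal Python sketch of the same mapping (illustrative only, not the llama.cpp code path):

import json

def parse_short_form_tool_call(tool_call):
    # Expect exactly one key: the tool name, mapping to its arguments.
    if not isinstance(tool_call, dict) or len(tool_call) != 1:
        return None
    name, args = next(iter(tool_call.items()))
    if not name:
        return None
    # Arguments become a JSON string, mirroring the dump() calls above;
    # plain strings pass through, null becomes an empty arguments string.
    if isinstance(args, str):
        arguments = args
    elif args is None:
        arguments = ""
    else:
        arguments = json.dumps(args)
    return {"name": name, "id": "", "arguments": arguments}

print(parse_short_form_tool_call({"get_weather": {"city": "Bern"}}))
# -> {'name': 'get_weather', 'id': '', 'arguments': '{"city": "Bern"}'}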

common/chat-parser.h

Lines changed: 3 additions & 0 deletions
@@ -64,6 +64,9 @@ class common_chat_msg_parser {
     // Adds an array of tool calls using their "name", "id" and "arguments" fields.
     bool add_tool_calls(const nlohmann::ordered_json & arr);
 
+    // Adds a tool call using the short form: { "tool_name": { "arg1": val, "arg2": val } }
+    bool add_tool_call_short_form(const nlohmann::ordered_json & tool_call);
+
     void finish();
 
     bool consume_spaces();

common/chat.cpp

Lines changed: 110 additions & 0 deletions
@@ -638,6 +638,7 @@ const char * common_chat_format_name(common_chat_format format) {
         case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
         case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS";
         case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2";
+        case COMMON_CHAT_FORMAT_APERTUS: return "Apertus";
         default:
             throw std::runtime_error("Unknown chat format");
     }
@@ -801,6 +802,7 @@ static std::string apply(
     }
     tmpl_inputs.add_generation_prompt = inputs.add_generation_prompt;
     tmpl_inputs.extra_context = inputs.extra_context;
+    tmpl_inputs.extra_context["enable_thinking"] = inputs.enable_thinking;
     if (additional_context) {
         tmpl_inputs.extra_context.merge_patch(*additional_context);
     }
@@ -1264,6 +1266,75 @@ static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_
     }
     return data;
 }
+
+static common_chat_params common_chat_params_init_apertus(const common_chat_template & tmpl, const struct templates_params & inputs) {
+    common_chat_params data;
+
+    // Generate the prompt using the apply() function with the template
+    data.prompt = apply(tmpl, inputs);
+    data.format = COMMON_CHAT_FORMAT_APERTUS;
+
+    // Handle thinking tags appropriately based on inputs.enable_thinking
+    if (string_ends_with(data.prompt, "<|inner_prefix|>")) {
+        if (!inputs.enable_thinking) {
+            data.prompt += "<|inner_suffix|>";
+        } else {
+            data.thinking_forced_open = true;
+        }
+    }
+
+    // When tools are present, build grammar for the <|tools_prefix|> format
+    if (!inputs.tools.is_null() && inputs.tools.is_array() && !inputs.tools.empty()) {
+        data.grammar_lazy = true;
+        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+            auto schemas = json::array();
+            foreach_function(inputs.tools, [&](const json & tool) {
+                const auto & function = tool.at("function");
+                schemas.push_back({
+                    { "type", "object" },
+                    { "properties",
+                        {
+                            { function.at("name"), function.at("parameters") }
+                        } },
+                    { "required", json::array({ function.at("name") }) },
+                });
+            });
+            auto schema = json{
+                { "type", "array" },
+                { "items", schemas.size() == 1 ? schemas[0] : json{ { "anyOf", schemas } } },
+                { "minItems", 1 },
+            };
+            if (!inputs.parallel_tool_calls) {
+                schema["maxItems"] = 1;
+            }
+            builder.add_rule("root",
+                std::string(data.thinking_forced_open ? "( \"<|inner_suffix|>\" space )? " : "") +
+                "\"<|tools_prefix|>\"" + builder.add_schema("tool_calls", schema) + "\"<|tools_suffix|>\"");
+        });
+        data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
+            // If thinking_forced_open, then we capture the <|inner_suffix|> tag in the grammar,
+            // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
+            std::string(data.thinking_forced_open ?
+                "[\\s\\S]*?(<\\|inner_suffix\\|>\\s*)" :
+                "(?:<\\|inner_prefix\\|>[\\s\\S]*?<\\|inner_suffix\\|>\\s*)?") +
+                "(<\\|tools_prefix\\|>)[\\s\\S]*" });
+        data.preserved_tokens = {
+            "<|system_start|>",
+            "<|system_end|>",
+            "<|developer_start|>",
+            "<|developer_end|>",
+            "<|user_start|>",
+            "<|user_end|>",
+            "<|assistant_start|>",
+            "<|assistant_end|>",
+            "<|inner_prefix|>",
+            "<|inner_suffix|>",
+            "<|tools_prefix|>",
+            "<|tools_suffix|>",
+        };
+    }
+    return data;
+}
 static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) {
     if (!builder.syntax().parse_tool_calls) {
         builder.add_content(builder.consume_rest());
@@ -2323,6 +2394,37 @@ static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) {
     builder.add_content(builder.consume_rest());
 }
 
+static void common_chat_parse_apertus(common_chat_msg_parser & builder) {
+    // Parse thinking tags
+    builder.try_parse_reasoning("<|inner_prefix|>", "<|inner_suffix|>");
+    if (!builder.syntax().parse_tool_calls) {
+        builder.add_content(builder.consume_rest());
+        return;
+    }
+
+    // Look for tool calls
+    static const common_regex tool_call_regex(regex_escape("<|tools_prefix|>"));
+    if (auto res = builder.try_find_regex(tool_call_regex)) {
+        builder.move_to(res->groups[0].end);
+
+        auto tool_calls_data = builder.consume_json();
+        if (tool_calls_data.json.is_array()) {
+            builder.consume_spaces();
+            if (!builder.try_consume_literal("<|tools_suffix|>")) {
+                throw common_chat_msg_partial_exception("Incomplete tool call");
+            }
+            for (const auto & value : tool_calls_data.json) {
+                if (value.is_object()) {
+                    builder.add_tool_call_short_form(value);
+                }
+            }
+        } else {
+            throw common_chat_msg_partial_exception("Incomplete tool call");
+        }
+    }
+    builder.add_content(builder.consume_rest());
+}
+
 static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
     // Parse thinking tags first - this handles the main reasoning content
     builder.try_parse_reasoning("<seed:think>", "</seed:think>");
@@ -2567,6 +2669,11 @@ static common_chat_params common_chat_templates_apply_jinja(
         return common_chat_params_init_nemotron_v2(tmpl, params);
     }
 
+    // Apertus format detection
+    if (src.find("<|system_start|>") != std::string::npos && src.find("<|tools_prefix|>") != std::string::npos) {
+        return common_chat_params_init_apertus(tmpl, params);
+    }
+
     // Use generic handler when mixing tools + JSON schema.
     // TODO: support that mix in handlers below.
     if ((params.tools.is_array() && params.json_schema.is_object())) {
@@ -2734,6 +2841,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
         case COMMON_CHAT_FORMAT_NEMOTRON_V2:
            common_chat_parse_nemotron_v2(builder);
            break;
+        case COMMON_CHAT_FORMAT_APERTUS:
+           common_chat_parse_apertus(builder);
+           break;
        default:
           throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
    }
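
Taken together, the init and parse paths above expect assistant output of the shape <|inner_prefix|>reasoning<|inner_suffix|>content<|tools_prefix|>[{"tool_name": {...}}]<|tools_suffix|>. A rough Python sketch of that split for a complete, non-streamed message (illustration only; the real parser additionally handles partial input, grammar triggers and reasoning syntax options):

import json
import re

# Toy splitter for a complete Apertus assistant message.
MSG_RE = re.compile(
    r"(?:<\|inner_prefix\|>(?P<reasoning>[\s\S]*?)<\|inner_suffix\|>)?"
    r"(?P<content>[\s\S]*?)"
    r"(?:<\|tools_prefix\|>(?P<tools>[\s\S]*?)<\|tools_suffix\|>)?\s*$"
)

def split_apertus_message(text):
    m = MSG_RE.match(text)
    tools = json.loads(m.group("tools")) if m.group("tools") else []
    return {
        "reasoning": (m.group("reasoning") or "").strip(),
        "content": m.group("content").strip(),
        "tool_calls": tools,  # list of short-form {"tool_name": {...}} objects
    }

example = ("<|inner_prefix|>Need the forecast first.<|inner_suffix|>"
           "<|tools_prefix|>[{\"get_weather\": {\"city\": \"Zurich\"}}]<|tools_suffix|>")
print(split_apertus_message(example))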

common/chat.h

Lines changed: 1 addition & 0 deletions
@@ -114,6 +114,7 @@ enum common_chat_format {
     COMMON_CHAT_FORMAT_GPT_OSS,
     COMMON_CHAT_FORMAT_SEED_OSS,
     COMMON_CHAT_FORMAT_NEMOTRON_V2,
+    COMMON_CHAT_FORMAT_APERTUS,
 
     COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
 };

convert_hf_to_gguf.py

Lines changed: 38 additions & 1 deletion
@@ -8945,6 +8945,43 @@ def prepare_tensors(self):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@ModelBase.register("ApertusForCausalLM")
+class ApertusModel(LlamaModel):
+    model_arch = gguf.MODEL_ARCH.APERTUS
+    undo_permute = False
+
+    _alpha_n = {}
+    _alpha_p = {}
+    _beta = {}
+    _eps = {}
+
+    def modify_tensors(self, data_torch, name, bid):
+        # Handle xIELU activation parameters
+        n_layers = self.hparams["num_hidden_layers"]
+        if name.endswith(".act_fn.alpha_n"):
+            self._alpha_n[bid] = data_torch.to("cpu").float().item()
+            if (len(self._alpha_n) == n_layers):
+                self.gguf_writer.add_xielu_alpha_n([self._alpha_n[k] for k in sorted(self._alpha_n)])
+            return []
+        if name.endswith(".act_fn.alpha_p"):
+            self._alpha_p[bid] = data_torch.to("cpu").float().item()
+            if (len(self._alpha_p) == n_layers):
+                self.gguf_writer.add_xielu_alpha_p([self._alpha_p[k] for k in sorted(self._alpha_p)])
+            return []
+        if name.endswith(".act_fn.beta"):
+            self._beta[bid] = data_torch.to("cpu").float().item()
+            if (len(self._beta) == n_layers):
+                self.gguf_writer.add_xielu_beta([self._beta[k] for k in sorted(self._beta)])
+            return []
+        if name.endswith(".act_fn.eps"):
+            self._eps[bid] = data_torch.to("cpu").float().item()
+            if (len(self._eps) == n_layers):
+                self.gguf_writer.add_xielu_eps([self._eps[k] for k in sorted(self._eps)])
+            return []
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
 class MistralModel(LlamaModel):
     model_arch = gguf.MODEL_ARCH.LLAMA
     model_name = "Mistral"
@@ -9112,7 +9149,7 @@ def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: tuple[int, ...]) -
     def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
         dtype = cls._dtype_str_map[st_slice.get_dtype()]
         shape: tuple[int, ...] = tuple(st_slice.get_shape())
-        lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
+        lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[...] if len(s.get_shape()) == 0 else s[:])
         return cast(torch.Tensor, lazy)
 
     @classmethod
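
The converter change above turns the per-layer act_fn.alpha_n / alpha_p / beta / eps scalar tensors into GGUF hyperparameters: each scalar is stashed by layer index, and once all num_hidden_layers values have been seen the full list is written in layer order. A stripped-down sketch of that accumulate-then-flush pattern (the write_list callback is a stand-in for the gguf_writer.add_xielu_* calls):

def make_collector(n_layers, write_list):
    values = {}

    def add(layer_idx, value):
        values[layer_idx] = value
        # Only emit once every layer has contributed its scalar.
        if len(values) == n_layers:
            write_list([values[k] for k in sorted(values)])

    return add

collect_alpha_n = make_collector(3, lambda vs: print("alpha_n per layer:", vs))
for layer, v in [(2, 0.5), (0, 0.8), (1, 0.6)]:
    collect_alpha_n(layer, v)
# -> alpha_n per layer: [0.8, 0.6, 0.5]

The companion one-line change to from_safetensors_slice exists because these parameters are 0-dimensional: a scalar safetensors slice has no [:] view, so it is materialized with [...] instead.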

ggml/include/ggml.h

Lines changed: 13 additions & 0 deletions
@@ -576,6 +576,7 @@ extern "C" {
         GGML_UNARY_OP_HARDSIGMOID,
         GGML_UNARY_OP_EXP,
         GGML_UNARY_OP_GELU_ERF,
+        GGML_UNARY_OP_XIELU,
 
         GGML_UNARY_OP_COUNT,
     };
@@ -1150,6 +1151,18 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor * a);
 
+    // xIELU activation function
+    // x = x * (c_a(alpha_n) + c_b(alpha_p, beta) * sigmoid(beta * x)) + eps * (x > 0)
+    // where c_a = softplus and c_b(a, b) = softplus(a) + b are constraining functions
+    // that constrain the positive and negative source alpha values respectively
+    GGML_API struct ggml_tensor * ggml_xielu(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            float alpha_n,
+            float alpha_p,
+            float beta,
+            float eps);
+
     // gated linear unit ops
     // A: n columns, r rows,
     // result is n / 2 columns, r rows,
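
Read literally, the header comment above combines two softplus-based constraining functions with a sigmoid gate. A plain-Python transcription of that comment, for intuition only (the CPU/CUDA kernels referenced in the commit message, unary-ops.cpp and unary.cu, are the authoritative implementation):

import math

def softplus(x):
    # Numerically stable softplus; the threshold mirrors the expression
    # that ggml_softplus replaces in ops.cpp below.
    return x if x > 20.0 else math.log1p(math.exp(x))

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def xielu_from_comment(x, alpha_n, alpha_p, beta, eps):
    # x * (c_a(alpha_n) + c_b(alpha_p, beta) * sigmoid(beta * x)) + eps * (x > 0)
    c_a = softplus(alpha_n)
    c_b = softplus(alpha_p) + beta
    return x * (c_a + c_b * sigmoid(beta * x)) + eps * (1.0 if x > 0 else 0.0)

print(xielu_from_comment(1.0, alpha_n=0.8, alpha_p=0.8, beta=0.5, eps=1e-6))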

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 1 addition & 0 deletions
@@ -2187,6 +2187,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 case GGML_UNARY_OP_GELU_ERF:
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_SILU:
+                case GGML_UNARY_OP_XIELU:
                     {
                         n_tasks = n_threads;
                     } break;

ggml/src/ggml-cpu/ops.cpp

Lines changed: 6 additions & 2 deletions
@@ -8637,7 +8637,7 @@ static void ggml_compute_forward_ssm_scan_f32(
         // n_head
         for (int h = ih0; h < ih1; ++h) {
             // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
-            const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+            const float dt_soft_plus = ggml_softplus(dt[h]);
             const float dA = expf(dt_soft_plus * A[h]);
             const int g = h / (nh / ng); // repeat_interleave
 
@@ -8734,7 +8734,7 @@ static void ggml_compute_forward_ssm_scan_f32(
         // n_head
         for (int h = ih0; h < ih1; ++h) {
             // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
-            const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+            const float dt_soft_plus = ggml_softplus(dt[h]);
             const int g = h / (nh / ng); // repeat_interleave
 
             // dim
@@ -8997,6 +8997,10 @@ void ggml_compute_forward_unary(
             {
                 ggml_compute_forward_exp(params, dst);
             } break;
+        case GGML_UNARY_OP_XIELU:
+            {
+                ggml_compute_forward_xielu(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
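
The two dt_soft_plus sites now go through ggml_softplus, which presumably keeps the overflow guard of the inline expression it replaces: above the threshold, softplus(x) differs from x by less than exp(-x) (under 2e-9 at x = 20, well below float32 precision), so log1pf(expf(x)) only needs to be evaluated for small inputs. A quick Python check of that equivalence (illustrative, not the ggml source):

import math

def softplus_guarded(x, threshold=20.0):
    # Same shape as the replaced expression: x <= 20.0f ? log1pf(expf(x)) : x
    return math.log1p(math.exp(x)) if x <= threshold else x

for x in (-5.0, 0.0, 5.0, 19.5, 20.5, 50.0):
    exact = math.log1p(math.exp(x))  # safe here; exp() only overflows near x ~ 710
    print(f"x={x:6.1f}  guarded={softplus_guarded(x):.9f}  exact={exact:.9f}")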
