206 changes: 113 additions & 93 deletions common/arg.cpp

Large diffs are not rendered by default.

11 changes: 7 additions & 4 deletions common/chat-parser.cpp
@@ -170,20 +170,23 @@ std::string common_chat_msg_parser::consume_rest() {
}

// Tries to find the regex and consumes it (pos ends up right after the match); returns the prelude (text right before the match) and the regex groups.
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from) {
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from, bool add_prelude_to_content) {
auto m = regex.search(input_, from == std::string::npos ? pos_ : from);
if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
return std::nullopt;
}
auto prelude = input_.substr(pos_, m.groups[0].begin - pos_);
pos_ = m.groups[0].end;

if (add_prelude_to_content) {
add_content(prelude);
}
if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
if (is_partial()) {
throw common_chat_msg_partial_exception(regex.str());
}
return std::nullopt;
}
auto prelude = input_.substr(pos_, m.groups[0].begin - pos_);
pos_ = m.groups[0].end;

return find_regex_result{prelude, m.groups};
}
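A minimal usage sketch of the new `add_prelude_to_content` parameter (the regex, input, and setup below are hypothetical, not taken from this PR):

```cpp
// Sketch: a parser primed with input = "Hello <tool_call>...".
common_chat_msg_parser builder(input, /* is_partial= */ false, syntax);
static const common_regex open_tag(regex_escape("<tool_call>"));

// Default (true): try_find_regex itself adds the prelude "Hello " to the
// message content, so call sites no longer need to do it by hand.
if (auto res = builder.try_find_regex(open_tag)) {
    // The parser position now sits right after "<tool_call>".
}

// Alternatively, opt out to handle the prelude manually:
if (auto res = builder.try_find_regex(open_tag, std::string::npos,
                                      /* add_prelude_to_content= */ false)) {
    // res->prelude is still returned, just not added to the content.
}
```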

3 changes: 2 additions & 1 deletion common/chat-parser.h
@@ -30,6 +30,7 @@ class common_chat_msg_parser {
const std::string & healing_marker() const { return healing_marker_; }
const bool & is_partial() const { return is_partial_; }
const common_chat_msg & result() const { return result_; }
const common_chat_syntax & syntax() const { return syntax_; }

void move_to(size_t pos) {
if (pos > input_.size()) {
@@ -77,7 +78,7 @@ class common_chat_msg_parser {
std::vector<common_string_range> groups;
};

std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos);
std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos, bool add_prelude_to_content = true);

bool try_consume_literal(const std::string & literal);

95 changes: 57 additions & 38 deletions common/chat.cpp
@@ -31,6 +31,11 @@ static std::string string_diff(const std::string & last, const std::string & cur
return current;
}
if (!string_starts_with(current, last)) {
if (string_starts_with(last, current)) {
// This happens if the last generation ended on a partial stop word (not erased),
// and the current ended on a stop word (erased).
return "";
}
throw std::runtime_error("Invalid diff: '" + last + "' not found at start of '" + current + "'");
}
return current.substr(last.size());
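A worked example of the new early return (values hypothetical):

```cpp
// The previous flush kept a partial stop word; the current generation
// then completed the stop word, which was erased:
std::string last    = "Let me check.[TOOL";  // ended on a partial stop word
std::string current = "Let me check.";       // stop word matched and erased
// string_diff(last, current) now returns "" instead of throwing, because
// `last` starts with `current`.
```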
@@ -101,9 +106,9 @@ std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const comm
if (!args_diff.empty() || pref.id != newf.id) {
auto & diff = diffs.emplace_back();
diff.tool_call_index = idx;
diff.tool_call_delta.name = newf.name;
if (pref.id != newf.id) {
diff.tool_call_delta.id = newf.id;
diff.tool_call_delta.name = newf.name;
}
diff.tool_call_delta.arguments = args_diff;
}
@@ -387,22 +392,19 @@ template <> json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_di
delta["content"] = diff.content_delta;
}
if (diff.tool_call_index != std::string::npos) {
json tool_call;
tool_call["index"] = diff.tool_call_index;
if (!diff.tool_call_delta.id.empty()) {
tool_call["id"] = diff.tool_call_delta.id;
tool_call["type"] = "function";
}
json function = json::object();
if (!diff.tool_call_delta.name.empty()) {
function["name"] = diff.tool_call_delta.name;
}
if (!diff.tool_call_delta.id.empty()) {
function["id"] = diff.tool_call_delta.id;
}
if (!diff.tool_call_delta.arguments.empty()) {
function["arguments"] = diff.tool_call_delta.arguments;
}
delta["tool_calls"] = json::array({
json {
{"index", diff.tool_call_index},
{"function", function}
}
});
function["arguments"] = diff.tool_call_delta.arguments;
tool_call["function"] = function;
delta["tool_calls"] = json::array({tool_call});
}
return delta;
}
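For reference, the restructured delta now nests `id`, `type`, and `function` under a single `tool_calls` entry; a sketch of a hypothetical first chunk for a new tool call (all field values invented):

```cpp
common_chat_msg_diff diff;
diff.tool_call_index           = 0;
diff.tool_call_delta.id        = "call_abc123";  // new id => name is sent too
diff.tool_call_delta.name      = "get_weather";
diff.tool_call_delta.arguments = "{\"loc";       // streamed incrementally

// common_chat_msg_diff_to_json_oaicompat<json>(diff) would then yield
// (modulo key ordering):
// {"tool_calls":[{"index":0,"id":"call_abc123","type":"function",
//                 "function":{"name":"get_weather","arguments":"{\"loc"}}]}
```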
@@ -654,7 +656,6 @@ static void parse_json_tool_calls(
}
from = std::string::npos;

builder.add_content(res->prelude);
auto maybe_raw_python = name == "python" && allow_raw_python;
if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) {
if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) {
@@ -684,7 +685,6 @@ static void parse_json_tool_calls(
};
if (block_open) {
if (auto res = builder.try_find_regex(*block_open)) {
builder.add_content(res->prelude);
parse_tool_calls();
} else {
builder.add_content(builder.consume_rest());
@@ -697,7 +697,6 @@
static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder, const common_regex & prefix, size_t rstrip_prefix = 0) {
static const std::vector<std::vector<std::string>> args_paths = {{"arguments"}};
if (auto res = builder.try_find_regex(prefix)) {
builder.add_content(res->prelude);
builder.move_back(rstrip_prefix);
auto tool_calls = builder.consume_json_with_dumped_args(args_paths);
if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) {
@@ -833,6 +832,10 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
return data;
}
static void common_chat_parse_generic(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const std::vector<std::vector<std::string>> content_paths = {
{"response"},
};
@@ -905,6 +908,11 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
return data;
}
static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}

static const common_regex prefix(regex_escape("[TOOL_CALLS]"));
parse_prefixed_json_tool_call_array(builder, prefix);
}
@@ -999,7 +1007,6 @@ static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) {

if (auto res = builder.try_find_regex(start_action_regex)) {
// If we didn't extract thoughts, prelude includes them.
builder.add_content(res->prelude);
auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}});
for (const auto & tool_call : tool_calls.value) {
std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : "";
@@ -1014,11 +1021,7 @@
}
builder.consume_regex(end_action_regex);
} else if (auto res = builder.try_find_regex(start_response_regex)) {
// If we didn't extract thoughts, prelude includes them.
builder.add_content(res->prelude);
if (auto res = builder.try_find_regex(end_response_regex)) {
builder.add_content(res->prelude);
} else {
if (!builder.try_find_regex(end_response_regex)) {
builder.add_content(builder.consume_rest());
throw common_chat_msg_partial_exception(end_response_regex.str());
}
@@ -1126,6 +1129,11 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
return data;
}
static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}

static const common_regex function_regex(
"\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: ");
static const common_regex close_regex("\\}\\s*");
@@ -1136,8 +1144,6 @@ static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool w
if (with_builtin_tools) {
static const common_regex builtin_call_regex("<\\|python_tag\\|>");
if (auto res = builder.try_find_regex(builtin_call_regex)) {
builder.add_content(res->prelude);

auto fun_res = builder.consume_regex(function_name_regex);
auto function_name = builder.str(fun_res.groups[1]);

@@ -1253,6 +1259,10 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
}
static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}

static const common_regex tool_calls_begin("(?:<｜tool▁calls▁begin｜>|<｜tool_calls_begin｜>|<｜tool calls begin｜>|<｜tool\\\\_calls\\\\_begin｜>|<｜tool▁calls｜>)");
static const common_regex tool_calls_end("<｜tool▁calls▁end｜>");
@@ -1314,6 +1324,10 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
return data;
}
static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex prefix(regex_escape(" functools["));
parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1);
}
@@ -1455,15 +1469,12 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
return data;
}
static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) {
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
static const common_regex python_tag_regex(regex_escape("<|python_tag|>"));

if (auto res = builder.try_find_regex(python_tag_regex)) {
builder.add_content(res->prelude);
auto arguments = wrap_code_as_arguments(builder, builder.consume_rest());
builder.add_tool_call("python", "", arguments);
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
static const common_regex python_tag_regex(regex_escape("<|python_tag|>"));

static const common_regex function_regex(R"(<function=(\w+)>)");
static const common_regex close_regex(R"(</function>)");
@@ -1475,6 +1486,12 @@ static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser
function_regex,
close_regex,
std::nullopt);

if (auto res = builder.try_find_regex(python_tag_regex)) {
auto arguments = wrap_code_as_arguments(builder, builder.consume_rest());
builder.add_tool_call("python", "", arguments);
return;
}
}

static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
@@ -1593,6 +1610,10 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
}
static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}

static const common_regex open_regex(
"(?:"
@@ -1614,8 +1635,6 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
);

if (auto res = builder.try_find_regex(open_regex)) {
builder.add_content(res->prelude);

const auto & block_start = res->groups[1];
std::string block_end = block_start.empty() ? "" : "```";

@@ -1851,10 +1870,10 @@ static void common_chat_parse_content_only(common_chat_msg_parser & builder) {
builder.add_content(builder.consume_rest());
}

static void common_chat_parse(common_chat_msg_parser & builder, common_chat_format format) {
LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(format), builder.input().c_str());
static void common_chat_parse(common_chat_msg_parser & builder) {
LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(builder.syntax().format), builder.input().c_str());

switch (format) {
switch (builder.syntax().format) {
case COMMON_CHAT_FORMAT_CONTENT_ONLY:
common_chat_parse_content_only(builder);
break;
@@ -1889,15 +1908,15 @@ static void common_chat_parse(common_chat_msg_parser & builder, common_chat_form
common_chat_parse_command_r7b(builder);
break;
default:
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(format));
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
}
builder.finish();
}

common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
common_chat_msg_parser builder(input, is_partial, syntax);
try {
common_chat_parse(builder, syntax.format);
common_chat_parse(builder);
} catch (const common_chat_msg_partial_exception & ex) {
LOG_DBG("Partial parse: %s\n", ex.what());
if (!is_partial) {
1 change: 1 addition & 0 deletions common/chat.h
@@ -144,6 +144,7 @@ struct common_chat_syntax {
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
bool reasoning_in_content = false;
bool thinking_forced_open = false;
bool parse_tool_calls = true;
};

// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
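A sketch of how the new flag is consumed, mirroring the `parse_tool_calls` guards added to the format parsers in chat.cpp above (setup and values hypothetical):

```cpp
common_chat_syntax syntax;
syntax.format           = COMMON_CHAT_FORMAT_DEEPSEEK_R1;  // assumed enum value
syntax.parse_tool_calls = false;  // e.g. the request declared no tools

// Reasoning is still split out by the format parser, but the rest of the
// generation is kept as plain content instead of being parsed as tool calls.
common_chat_msg msg = common_chat_parse(input, /* is_partial= */ false, syntax);
```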
1 change: 1 addition & 0 deletions common/common.h
@@ -291,6 +291,7 @@ struct common_params {
int32_t verbosity = 0;
int32_t control_vector_layer_start = -1; // layer range for control vector
int32_t control_vector_layer_end = -1; // layer range for control vector
bool offline = false;

int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
9 changes: 9 additions & 0 deletions docs/backend/CANN.md
100644 → 100755
@@ -280,6 +280,15 @@ cmake --build build --config release
### **GitHub contribution**:
Please add the **[CANN]** prefix/tag in issue/PR titles to help the CANN team check and address them without delay.

## Updates
### Basic Flash Attention Support
A basic FA kernel built on aclnn ops has been added in aclnn_ops.cpp.
Currently, FA only supports FP16 KV tensors and no logit softcap.
Since the aclnn flash-attention interface cannot support logit softcap, future updates will cover only the quantized version.

Authors from Peking University: Bizhao Shi ([email protected]), Yuxin Yang ([email protected]), Ruiyang Ma ([email protected]), and Guojie Luo ([email protected]).

We would like to thank Tuo Dai, Shanni Li, and all of the project maintainers from Huawei Technologies Co., Ltd. for their help during code development and the pull request.

## TODO
- Support more models and data types.
4 changes: 2 additions & 2 deletions examples/embedding/embedding.cpp
@@ -41,8 +41,8 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu

// run model
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
if (llama_encode(ctx, batch) < 0) {
LOG_ERR("%s : failed to encode\n", __func__);
if (llama_decode(ctx, batch) < 0) {
LOG_ERR("%s : failed to process\n", __func__);
}

for (int i = 0; i < batch.n_tokens; i++) {
12 changes: 6 additions & 6 deletions examples/retrieval/retrieval.cpp
@@ -81,14 +81,14 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
}
}

static void batch_encode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
static void batch_process(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_self_clear(ctx);

// run model
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
if (llama_encode(ctx, batch) < 0) {
LOG_ERR("%s : failed to encode\n", __func__);
if (llama_decode(ctx, batch) < 0) {
LOG_ERR("%s : failed to process\n", __func__);
}

for (int i = 0; i < batch.n_tokens; i++) {
@@ -233,7 +233,7 @@ int main(int argc, char ** argv) {
// encode if at capacity
if (batch.n_tokens + n_toks > n_batch) {
float * out = emb + p * n_embd;
batch_encode(ctx, batch, out, s, n_embd);
batch_process(ctx, batch, out, s, n_embd);
common_batch_clear(batch);
p += s;
s = 0;
@@ -246,7 +246,7 @@

// final batch
float * out = emb + p * n_embd;
batch_encode(ctx, batch, out, s, n_embd);
batch_process(ctx, batch, out, s, n_embd);

// save embeddings to chunks
for (int i = 0; i < n_chunks; i++) {
@@ -267,7 +267,7 @@
batch_add_seq(query_batch, query_tokens, 0);

std::vector<float> query_emb(n_embd, 0);
batch_encode(ctx, query_batch, query_emb.data(), 1, n_embd);
batch_process(ctx, query_batch, query_emb.data(), 1, n_embd);

common_batch_clear(query_batch);

4 changes: 2 additions & 2 deletions examples/training/README.md
@@ -10,8 +10,8 @@ Proof of concept:

``` sh
export model_name=llama_3.2-1b && export quantization=f32
./build/bin/finetune --file wikitext-2-raw/wiki.test.raw -ngl 999 --model models/${model_name}-${quantization}.gguf -c 512 -b 512 -ub 512
./build/bin/perplexity --file wikitext-2-raw/wiki.test.raw -ngl 999 --model finetuned-model.gguf
./build/bin/llama-finetune --file wikitext-2-raw/wiki.test.raw -ngl 999 --model models/${model_name}-${quantization}.gguf -c 512 -b 512 -ub 512
./build/bin/llama-perplexity --file wikitext-2-raw/wiki.test.raw -ngl 999 --model finetuned-model.gguf
```

The perplexity value of the finetuned model should be lower after training on the test set for 2 epochs.
Empty file modified ggml/src/ggml-cann/CMakeLists.txt
100644 → 100755
Empty file.
Empty file modified ggml/src/ggml-cann/Doxyfile
100644 → 100755
Empty file.