Skip to content

Commit 4e743d8

Browse files
authored
Merge branch 'ggml-org:master' into master
2 parents ac6206c + 0a338ed commit 4e743d8

File tree

20 files changed

+747
-545
lines changed

20 files changed

+747
-545
lines changed

common/minja/chat-template.hpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@
1313
#include <chrono>
1414
#include <cstddef>
1515
#include <cstdio>
16+
#include <ctime>
1617
#include <exception>
1718
#include <iomanip>
1819
#include <memory>
1920
#include <sstream>
21+
#include <stdexcept>
2022
#include <string>
2123
#include <vector>
2224

@@ -393,8 +395,8 @@ class chat_template {
393395

394396
for (const auto & message_ : adjusted_messages) {
395397
auto message = message_;
396-
if (!message.contains("role") || !message.contains("content")) {
397-
throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
398+
if (!message.contains("role") || (!message.contains("content") && !message.contains("tool_calls"))) {
399+
throw std::runtime_error("message must have 'role' and one of 'content' or 'tool_calls' fields: " + message.dump());
398400
}
399401
std::string role = message.at("role");
400402

@@ -415,7 +417,6 @@ class chat_template {
415417
}
416418
}
417419
if (polyfill_tool_calls) {
418-
auto content = message.at("content");
419420
auto tool_calls = json::array();
420421
for (const auto & tool_call : message.at("tool_calls")) {
421422
if (tool_call.at("type") != "function") {
@@ -434,8 +435,11 @@ class chat_template {
434435
auto obj = json {
435436
{"tool_calls", tool_calls},
436437
};
437-
if (!content.is_null() && !content.empty()) {
438-
obj["content"] = content;
438+
if (message.contains("content")) {
439+
auto content = message.at("content");
440+
if (!content.is_null() && !content.empty()) {
441+
obj["content"] = content;
442+
}
439443
}
440444
message["content"] = obj.dump(2);
441445
message.erase("tool_calls");

common/minja/minja.hpp

Lines changed: 69 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <algorithm>
1212
#include <cctype>
1313
#include <cstddef>
14+
#include <cstdint>
1415
#include <cmath>
1516
#include <exception>
1617
#include <functional>
@@ -233,7 +234,7 @@ class Value : public std::enable_shared_from_this<Value> {
233234
}
234235
} else if (is_object()) {
235236
if (!index.is_hashable())
236-
throw std::runtime_error("Unashable type: " + index.dump());
237+
throw std::runtime_error("Unhashable type: " + index.dump());
237238
auto it = object_->find(index.primitive_);
238239
if (it == object_->end())
239240
throw std::runtime_error("Key not found: " + index.dump());
@@ -252,7 +253,7 @@ class Value : public std::enable_shared_from_this<Value> {
252253
auto index = key.get<int>();
253254
return array_->at(index < 0 ? array_->size() + index : index);
254255
} else if (object_) {
255-
if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
256+
if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
256257
auto it = object_->find(key.primitive_);
257258
if (it == object_->end()) return Value();
258259
return it->second;
@@ -261,7 +262,7 @@ class Value : public std::enable_shared_from_this<Value> {
261262
}
262263
void set(const Value& key, const Value& value) {
263264
if (!object_) throw std::runtime_error("Value is not an object: " + dump());
264-
if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
265+
if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
265266
(*object_)[key.primitive_] = value;
266267
}
267268
Value call(const std::shared_ptr<Context> & context, ArgumentsValue & args) const {
@@ -398,7 +399,7 @@ class Value : public std::enable_shared_from_this<Value> {
398399
}
399400
return false;
400401
} else if (object_) {
401-
if (!value.is_hashable()) throw std::runtime_error("Unashable type: " + value.dump());
402+
if (!value.is_hashable()) throw std::runtime_error("Unhashable type: " + value.dump());
402403
return object_->find(value.primitive_) != object_->end();
403404
} else {
404405
throw std::runtime_error("contains can only be called on arrays and objects: " + dump());
@@ -416,7 +417,7 @@ class Value : public std::enable_shared_from_this<Value> {
416417
return const_cast<Value*>(this)->at(index);
417418
}
418419
Value& at(const Value & index) {
419-
if (!index.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
420+
if (!index.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
420421
if (is_array()) return array_->at(index.get<int>());
421422
if (is_object()) return object_->at(index.primitive_);
422423
throw std::runtime_error("Value is not an array or object: " + dump());
@@ -676,8 +677,8 @@ class Expression {
676677
class VariableExpr : public Expression {
677678
std::string name;
678679
public:
679-
VariableExpr(const Location & location, const std::string& n)
680-
: Expression(location), name(n) {}
680+
VariableExpr(const Location & loc, const std::string& n)
681+
: Expression(loc), name(n) {}
681682
std::string get_name() const { return name; }
682683
Value do_evaluate(const std::shared_ptr<Context> & context) const override {
683684
if (!context->contains(name)) {
@@ -1200,9 +1201,9 @@ class DictExpr : public Expression {
12001201

12011202
class SliceExpr : public Expression {
12021203
public:
1203-
std::shared_ptr<Expression> start, end;
1204-
SliceExpr(const Location & loc, std::shared_ptr<Expression> && s, std::shared_ptr<Expression> && e)
1205-
: Expression(loc), start(std::move(s)), end(std::move(e)) {}
1204+
std::shared_ptr<Expression> start, end, step;
1205+
SliceExpr(const Location & loc, std::shared_ptr<Expression> && s, std::shared_ptr<Expression> && e, std::shared_ptr<Expression> && st = nullptr)
1206+
: Expression(loc), start(std::move(s)), end(std::move(e)), step(std::move(st)) {}
12061207
Value do_evaluate(const std::shared_ptr<Context> &) const override {
12071208
throw std::runtime_error("SliceExpr not implemented");
12081209
}
@@ -1219,18 +1220,35 @@ class SubscriptExpr : public Expression {
12191220
if (!index) throw std::runtime_error("SubscriptExpr.index is null");
12201221
auto target_value = base->evaluate(context);
12211222
if (auto slice = dynamic_cast<SliceExpr*>(index.get())) {
1222-
auto start = slice->start ? slice->start->evaluate(context).get<int64_t>() : 0;
1223-
auto end = slice->end ? slice->end->evaluate(context).get<int64_t>() : (int64_t) target_value.size();
1223+
auto len = target_value.size();
1224+
auto wrap = [len](int64_t i) -> int64_t {
1225+
if (i < 0) {
1226+
return i + len;
1227+
}
1228+
return i;
1229+
};
1230+
int64_t step = slice->step ? slice->step->evaluate(context).get<int64_t>() : 1;
1231+
if (!step) {
1232+
throw std::runtime_error("slice step cannot be zero");
1233+
}
1234+
int64_t start = slice->start ? wrap(slice->start->evaluate(context).get<int64_t>()) : (step < 0 ? len - 1 : 0);
1235+
int64_t end = slice->end ? wrap(slice->end->evaluate(context).get<int64_t>()) : (step < 0 ? -1 : len);
12241236
if (target_value.is_string()) {
12251237
std::string s = target_value.get<std::string>();
1226-
if (start < 0) start = s.size() + start;
1227-
if (end < 0) end = s.size() + end;
1228-
return s.substr(start, end - start);
1238+
1239+
std::string result;
1240+
if (start < end && step == 1) {
1241+
result = s.substr(start, end - start);
1242+
} else {
1243+
for (int64_t i = start; step > 0 ? i < end : i > end; i += step) {
1244+
result += s[i];
1245+
}
1246+
}
1247+
return result;
1248+
12291249
} else if (target_value.is_array()) {
1230-
if (start < 0) start = target_value.size() + start;
1231-
if (end < 0) end = target_value.size() + end;
12321250
auto result = Value::array();
1233-
for (auto i = start; i < end; ++i) {
1251+
for (int64_t i = start; step > 0 ? i < end : i > end; i += step) {
12341252
result.push_back(target_value.at(i));
12351253
}
12361254
return result;
@@ -1305,6 +1323,8 @@ class BinaryOpExpr : public Expression {
13051323
if (name == "iterable") return l.is_iterable();
13061324
if (name == "sequence") return l.is_array();
13071325
if (name == "defined") return !l.is_null();
1326+
if (name == "true") return l.to_bool();
1327+
if (name == "false") return !l.to_bool();
13081328
throw std::runtime_error("Unknown type for 'is' operator: " + name);
13091329
};
13101330
auto value = eval();
@@ -1520,6 +1540,10 @@ class MethodCallExpr : public Expression {
15201540
vargs.expectArgs("endswith method", {1, 1}, {0, 0});
15211541
auto suffix = vargs.args[0].get<std::string>();
15221542
return suffix.length() <= str.length() && std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
1543+
} else if (method->get_name() == "startswith") {
1544+
vargs.expectArgs("startswith method", {1, 1}, {0, 0});
1545+
auto prefix = vargs.args[0].get<std::string>();
1546+
return prefix.length() <= str.length() && std::equal(prefix.begin(), prefix.end(), str.begin());
15231547
} else if (method->get_name() == "title") {
15241548
vargs.expectArgs("title method", {0, 0}, {0, 0});
15251549
auto res = str;
@@ -2082,28 +2106,37 @@ class Parser {
20822106

20832107
while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) {
20842108
if (!consumeToken("[").empty()) {
2085-
std::shared_ptr<Expression> index;
2109+
std::shared_ptr<Expression> index;
2110+
auto slice_loc = get_location();
2111+
std::shared_ptr<Expression> start, end, step;
2112+
bool has_first_colon = false, has_second_colon = false;
2113+
2114+
if (!peekSymbols({ ":" })) {
2115+
start = parseExpression();
2116+
}
2117+
2118+
if (!consumeToken(":").empty()) {
2119+
has_first_colon = true;
2120+
if (!peekSymbols({ ":", "]" })) {
2121+
end = parseExpression();
2122+
}
20862123
if (!consumeToken(":").empty()) {
2087-
auto slice_end = parseExpression();
2088-
index = std::make_shared<SliceExpr>(slice_end->location, nullptr, std::move(slice_end));
2089-
} else {
2090-
auto slice_start = parseExpression();
2091-
if (!consumeToken(":").empty()) {
2092-
consumeSpaces();
2093-
if (peekSymbols({ "]" })) {
2094-
index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), nullptr);
2095-
} else {
2096-
auto slice_end = parseExpression();
2097-
index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), std::move(slice_end));
2098-
}
2099-
} else {
2100-
index = std::move(slice_start);
2124+
has_second_colon = true;
2125+
if (!peekSymbols({ "]" })) {
2126+
step = parseExpression();
21012127
}
21022128
}
2103-
if (!index) throw std::runtime_error("Empty index in subscript");
2104-
if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");
2129+
}
2130+
2131+
if ((has_first_colon || has_second_colon) && (start || end || step)) {
2132+
index = std::make_shared<SliceExpr>(slice_loc, std::move(start), std::move(end), std::move(step));
2133+
} else {
2134+
index = std::move(start);
2135+
}
2136+
if (!index) throw std::runtime_error("Empty index in subscript");
2137+
if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");
21052138

2106-
value = std::make_shared<SubscriptExpr>(value->location, std::move(value), std::move(index));
2139+
value = std::make_shared<SubscriptExpr>(value->location, std::move(value), std::move(index));
21072140
} else if (!consumeToken(".").empty()) {
21082141
auto identifier = parseIdentifier();
21092142
if (!identifier) throw std::runtime_error("Expected identifier in subscript");

convert_hf_to_gguf.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2072,6 +2072,9 @@ def set_gguf_parameters(self):
20722072
self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size_moe"])
20732073

20742074
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
2075+
if name.startswith("language_model."):
2076+
name = name.replace("language_model.", "")
2077+
20752078
# split the gate_up into gate and up
20762079
if "gate_up_proj" in name:
20772080
name_up = name.replace("gate_up_proj", "up_proj.weight")

docs/backend/SYCL.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
731731
| GGML_SYCL_DEVICE_ARCH | Optional (except for AMD) | Set the SYCL device architecture, optional except for AMD. Setting the device architecture can improve the performance. See the table [--offload-arch](https://github.com/intel/llvm/blob/sycl/sycl/doc/design/OffloadDesign.md#--offload-arch) for a list of valid architectures. |
732732
| GGML_SYCL_F16 | OFF *(default)* \|ON *(optional)* | Enable FP16 build with SYCL code path. |
733733
| GGML_SYCL_GRAPH | ON *(default)* \|OFF *(Optional)* | Enable build with [SYCL Graph extension](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc). |
734+
| GGML_SYCL_DNN | ON *(default)* \|OFF *(Optional)* | Enable build with oneDNN. |
734735
| CMAKE_C_COMPILER | `icx` *(Linux)*, `icx/cl` *(Windows)* | Set `icx` compiler for SYCL code path. |
735736
| CMAKE_CXX_COMPILER | `icpx` *(Linux)*, `icx` *(Windows)* | Set `icpx/icx` compiler for SYCL code path. |
736737

@@ -741,6 +742,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
741742
| GGML_SYCL_DEBUG | 0 (default) or 1 | Enable log function by macro: GGML_SYCL_DEBUG |
742743
| GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimize features based on Intel GPU type, to compare the performance increase |
743744
| GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. |
745+
| GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. |
744746
| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
745747

746748

ggml/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ option(GGML_RPC "ggml: use RPC"
193193
option(GGML_SYCL "ggml: use SYCL" OFF)
194194
option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF)
195195
option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON)
196+
option(GGML_SYCL_DNN "ggml: enable oneDNN in the SYCL backend" ON)
196197
set (GGML_SYCL_TARGET "INTEL" CACHE STRING
197198
"ggml: sycl target device")
198199
set (GGML_SYCL_DEVICE_ARCH "" CACHE STRING

ggml/src/ggml-sycl/CMakeLists.txt

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -49,34 +49,38 @@ endif()
4949
target_compile_options(ggml-sycl PRIVATE "-Wno-narrowing")
5050

5151
# Link against oneDNN
52-
find_package(DNNL)
5352
set(GGML_SYCL_DNNL 0)
54-
if(DNNL_FOUND)
55-
if (NOT DEFINED DNNL_GPU_VENDOR)
56-
# default to intel target
57-
set(DNNL_GPU_VENDOR "INTEL")
58-
if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL")
59-
message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target")
53+
if(GGML_SYCL_DNN)
54+
find_package(DNNL)
55+
if(DNNL_FOUND)
56+
if (NOT DEFINED DNNL_GPU_VENDOR)
57+
# default to intel target
58+
set(DNNL_GPU_VENDOR "INTEL")
59+
if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL")
60+
message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target")
61+
endif()
6062
endif()
61-
endif()
6263

63-
# Verify oneDNN was compiled for the same target as llama
64-
if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}")
65-
target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
66-
set(GGML_SYCL_DNNL 1)
67-
get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS)
68-
foreach(CONFIG ${CONFIGS})
69-
get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG})
70-
message(STATUS "Found oneDNN: ${DNNL_LIB}")
71-
endforeach()
64+
# Verify oneDNN was compiled for the same target as llama
65+
if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}")
66+
target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
67+
set(GGML_SYCL_DNNL 1)
68+
get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS)
69+
foreach(CONFIG ${CONFIGS})
70+
get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG})
71+
message(STATUS "Found oneDNN: ${DNNL_LIB}")
72+
endforeach()
73+
else()
74+
message(WARNING
75+
"oneDNN must be compiled for the same target as llama.cpp.
76+
llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
77+
Disabling oneDNN support.")
78+
endif()
7279
else()
73-
message(WARNING
74-
"oneDNN must be compiled for the same target as llama.cpp.
75-
llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
76-
Disabling oneDNN support.")
80+
message(STATUS "oneDNN not found, disabling oneDNN support")
7781
endif()
7882
else()
79-
message(STATUS "oneDNN not found, disabling oneDNN support")
83+
message(STATUS "oneDNN support disabled by the user")
8084
endif()
8185
target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL})
8286

0 commit comments

Comments
 (0)