Skip to content

Commit 2fd0b41

Browse files
committed
Add regex capability for tensor selection
1 parent 625f0ae commit 2fd0b41

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

examples/quantize/quantize.cpp

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -319,7 +319,15 @@ static bool string_parse_tensor_type(const char * data, std::vector<tensor_quant
319319
sep++;
320320
const std::string qt(sep);
321321

322-
if (find(ALLOWED_TENSOR_TYPE.begin(), ALLOWED_TENSOR_TYPE.end(), tn) == ALLOWED_TENSOR_TYPE.end()) {
322+
bool found = false;
323+
for (const auto & allowed : ALLOWED_TENSOR_TYPE) {
324+
// check if an allowed tensor exists and it's at the end of the kv string
325+
if (tn.length() - allowed.length() == tn.find(allowed)) {
326+
found = true;
327+
break;
328+
}
329+
}
330+
if (!found) {
323331
printf("\n%s: invalid tensor name '%s'\n\n", __func__, tn.c_str());
324332
return false;
325333
}

src/llama-quant.cpp

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
#include <cinttypes>
1111
#include <fstream>
1212
#include <mutex>
13+
#include <regex>
1314
#include <thread>
1415
#include <unordered_map>
1516

@@ -795,9 +796,10 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
795796
// unless the user specifies a type
796797
if (params->tensor_types) {
797798
const std::vector<tensor_quantization> & tensor_types = *static_cast<const std::vector<tensor_quantization> *>(params->tensor_types);
798-
for (const auto & [name, quant] : tensor_types) {
799-
if (std::string str(tensor->name); str.find(name) != std::string::npos) {
800-
new_type = quant;
799+
for (const auto & [tname, qtype] : tensor_types) {
800+
if (std::regex pattern(tname); std::regex_search(tensor->name, pattern)) {
801+
LLAMA_LOG_DEBUG("(overriding %s -> %s), ", ggml_type_name(new_type), ggml_type_name(qtype));
802+
new_type = qtype;
801803
break;
802804
}
803805
}

0 commit comments

Comments
 (0)