Skip to content

Commit 3e031bc

Browse files
committed
Improve regex matching whilst still validating allowed tensors
1 parent 6c281ad commit 3e031bc

File tree

2 files changed

+19
-34
lines changed

2 files changed

+19
-34
lines changed

examples/quantize/quantize.cpp

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,6 @@ static bool parse_tensor_type(const char * data, std::vector<tensor_quantization
261261
printf("\n%s: missing tensor name\n\n", __func__);
262262
return false;
263263
}
264-
265264
if (const size_t qt_len = strlen(sep); qt_len == 1) {
266265
printf("\n%s: missing quantization type\n\n", __func__);
267266
return false;
@@ -270,37 +269,15 @@ static bool parse_tensor_type(const char * data, std::vector<tensor_quantization
270269
std::string tn(data, tn_len);
271270
std::transform(tn.begin(), tn.end(), tn.begin(), tolower);
272271
sep++;
273-
const std::string qt(sep);
274-
275-
bool found = false;
276-
for (const auto & allowed : ALLOWED_TENSOR_TYPE) {
277-
std::string tensor;
278-
tensor = tn.rfind('.') != std::string::npos ? tn.substr(tn.rfind('.') + 1) : tn;
279-
// handle special case of cls.output
280-
std::string cls_output = "cls.output";
281-
if (tn.find(cls_output) != std::string::npos) {
282-
tensor = "cls.output";
283-
}
284-
// check if an allowed tensor exists and it's at the end of the kv string
285-
if (tensor == allowed) {
286-
found = true;
287-
break;
288-
}
289-
}
290-
if (!found) {
291-
printf("\n%s: invalid tensor name '%s'\n\n", __func__, tn.c_str());
292-
return false;
293-
}
294-
295-
if (parse_ggml_type(qt.c_str()) == GGML_TYPE_COUNT) {
296-
printf("\n%s: invalid quantization type '%s'\n\n", __func__, qt.c_str());
297-
return false;
298-
}
299-
300272
tensor_quantization tqz;
301273
tqz.name = tn;
302-
tqz.quant = parse_ggml_type(qt.c_str());
274+
tqz.quant = parse_ggml_type(sep);
303275
tensor_type.emplace_back(std::move(tqz));
276+
if (tqz.quant == GGML_TYPE_COUNT) {
277+
printf("\n%s: invalid quantization type '%s'\n\n", __func__, sep);
278+
return false;
279+
}
280+
304281
return true;
305282
}
306283

src/llama-quant.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -790,17 +790,25 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
790790
// unless the user specifies a type
791791
if (params->tensor_types) {
792792
const std::vector<tensor_quantization> & tensor_types = *static_cast<const std::vector<tensor_quantization> *>(params->tensor_types);
793+
const std::string tensor_name(tensor->name);
793794
for (const auto & [tname, qtype] : tensor_types) {
794-
if (std::regex pattern(tname); std::regex_search(tensor->name, pattern)) {
795-
if (qtype != new_type) {
796-
LLAMA_LOG_DEBUG("(overriding %s -> %s), ", ggml_type_name(new_type), ggml_type_name(qtype));
795+
if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) {
796+
for (const auto & allowed : ALLOWED_TENSOR_TYPE) {
797+
if (tensor_name.find(allowed) != std::string::npos) {
798+
if (qtype != new_type) {
799+
LLAMA_LOG_DEBUG("(overriding %s), ", ggml_type_name(new_type));
800+
new_type = qtype;
801+
break;
802+
}
803+
}
797804
}
798-
new_type = qtype;
799-
break;
805+
goto loop_exit; // if two or more types are specified for the tensor, first match wins
800806
}
801807
}
802808
}
809+
loop_exit:;
803810
}
811+
804812
if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {
805813
new_type = params->token_embedding_type;
806814
}

0 commit comments

Comments
 (0)