Skip to content

Commit 1acb9f4

Browse files
committed
Handle edge case when tensor name is cls.output
1 parent 5a304b8 commit 1acb9f4

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

examples/quantize/quantize.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,8 +319,15 @@ static bool parse_tensor_type(const char * data, std::vector<tensor_quantization
319319

320320
bool found = false;
321321
for (const auto & allowed : ALLOWED_TENSOR_TYPE) {
322+
std::string tensor;
323+
tensor = tn.rfind('.') != std::string::npos ? tn.substr(tn.rfind('.') + 1) : tn;
324+
// handle special case of cls.output
325+
std::string cls_output = "cls.output";
326+
if (tn.find(cls_output) != std::string::npos) {
327+
tensor = "cls.output";
328+
}
322329
// check if an allowed tensor exists and it's at the end of the kv string
323-
if (tn.length() - allowed.length() == tn.find(allowed) && tn == allowed) {
330+
if (tensor == allowed) {
324331
found = true;
325332
break;
326333
}

0 commit comments

Comments
 (0)