Skip to content
Merged
5 changes: 3 additions & 2 deletions src/llama-vocab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -445,8 +445,9 @@ struct llm_tokenizer_bpe : llm_tokenizer {
break;
case LLAMA_VOCAB_PRE_TYPE_AFMOE:
regex_exprs = {
// Digits in groups of 1-3
"\\p{N}{1,3}",
// Digit handling - uses custom implementation in unicode.cpp
// Groups digits with leading 1-2 based on total length modulo 3
"\\p{Nd}+",
// CJK and Asian scripts (using direct Unicode literals)
"[一-鿿㐀-䶿豈-﫿぀-ゟ゠-ヿ・-゚⼀-⿟เ-๿຀-໿ក-៿က-႟ꩠ-ꩿꧠ-꧿가-힯ᄀ-ᇿ]+",
// Main BPE pattern
Expand Down
77 changes: 77 additions & 0 deletions src/unicode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,80 @@ static std::vector<size_t> unicode_regex_split_custom_kimi_k2(const std::string
return bpe_offsets;
}

// AFMOE digit handling: splits digits with leading 1-2 based on total length modulo 3
static std::vector<size_t> unicode_regex_split_custom_afmoe(const std::string & text, const std::vector<size_t> & offsets) {
std::vector<size_t> bpe_offsets;
bpe_offsets.reserve(offsets.size());

const auto cpts = unicode_cpts_from_utf8(text);

size_t start = 0;
for (auto offset : offsets) {
const size_t offset_ini = start;
const size_t offset_end = start + offset;
assert(offset_end <= cpts.size());
start = offset_end;

auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
};

size_t _prev_end = offset_ini;
auto _add_token = [&] (const size_t end) -> size_t {
assert(_prev_end <= end && end <= offset_end);
size_t len = end - _prev_end;
if (len > 0) {
bpe_offsets.push_back(len);
}
_prev_end = end;
return len;
};

for (size_t pos = offset_ini; pos < offset_end; ) {
const auto flags = _get_flags(pos);

// Handle digit sequences with special splitting logic
if (flags.is_number) {
size_t digit_start = pos;
size_t digit_count = 0;

// Count consecutive digits
while (_get_flags(pos).is_number && pos < offset_end) {
digit_count++;
pos++;
}

// Split based on total length modulo 3
size_t remainder = digit_count % 3;
size_t current = digit_start;

// Emit leading 1-2 digits if needed
if (remainder > 0) {
_add_token(current + remainder);
current += remainder;
}

// Emit groups of 3
while (current < digit_start + digit_count) {
_add_token(current + 3);
current += 3;
}
continue;
}

// For non-digits, just move forward
pos++;
}

// Add any remaining content
if (_prev_end < offset_end) {
_add_token(offset_end);
}
}

return bpe_offsets;
}

static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
std::vector<size_t> bpe_offsets;

Expand All @@ -742,6 +816,9 @@ static std::vector<size_t> unicode_regex_split_custom(const std::string & text,
} else if (regex_expr == "\\p{Han}+") {
// K2's first pattern - handle all K2 patterns together
bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets);
} else if (regex_expr == "\\p{Nd}+") {
// AFMOE digit pattern - use custom implementation for proper splitting
bpe_offsets = unicode_regex_split_custom_afmoe(text, offsets);
}

return bpe_offsets;
Expand Down
Loading