Skip to content

Commit 024ed11

Browse files
refactor: Improve tokenizer type detection logic when not specified.
1 parent 57d019e commit 024ed11

File tree

10 files changed

+172
-214
lines changed

10 files changed

+172
-214
lines changed

src/Normalizers/BertNormalizer.php

Lines changed: 76 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,57 @@
22

33
declare(strict_types=1);
44

5-
65
namespace Codewithkyrian\Transformers\Normalizers;
76

8-
use Codewithkyrian\Transformers\Tokenizers\TokenizerModel;
9-
107
/**
118
* A class representing a normalizer used in BERT tokenization.
129
*/
1310
class BertNormalizer extends Normalizer
1411
{
12+
/**
 * Removes invalid characters from the text and normalises whitespace.
 *
 * NUL (U+0000), the replacement character (U+FFFD) and any character that
 * {@see isControl()} reports as a control character are dropped. Each
 * whitespace character is replaced by one ASCII space; nothing is collapsed.
 *
 * @param string $text The text to clean.
 *
 * @return string The cleaned text.
 */
public function cleanText(string $text): string
{
    $output = [];

    // Split once up front instead of re-measuring and re-indexing the string
    // with mb_strlen()/mb_substr() on every iteration.
    foreach (mb_str_split($text) as $char) {
        $cp = mb_ord($char);

        if ($cp === 0 || $cp === 0xFFFD || $this->isControl($char)) {
            continue;
        }

        // The `u` modifier is required here: without it the subject is matched
        // byte-wise, so a multi-byte whitespace character such as U+00A0 or
        // U+3000 can never satisfy `^\s$` and would be kept verbatim.
        $output[] = preg_match('/^\s$/u', $char) === 1 ? ' ' : $char;
    }

    return implode('', $output);
}
35+
36+
/**
 * Runs the normalisation steps that are enabled in the configuration.
 *
 * A step is applied only when its config flag is present and truthy. The steps
 * run in this fixed order: `clean_text`, `handle_chinese_chars`, `lowercase`,
 * `strip_accents`.
 *
 * @param string $text The text to normalize.
 *
 * @return string The normalized text.
 */
public function normalize(string $text): string
{
    // Config flag => transformation, listed in the order they must run.
    $pipeline = [
        'clean_text' => fn ($value) => $this->cleanText($value),
        'handle_chinese_chars' => fn ($value) => $this->tokenizeChineseChars($value),
        'lowercase' => fn ($value) => mb_strtolower($value),
        'strip_accents' => fn ($value) => $this->stripAccents($value),
    ];

    foreach ($pipeline as $flag => $step) {
        if ($this->config[$flag] ?? false) {
            $text = $step($text);
        }
    }

    return $text;
}
1556

1657
/**
1758
* Strips accents from the given text.
@@ -43,47 +84,49 @@ protected function isControl(string $char): bool
4384
}
4485

4586
/**
 * Tells whether a Unicode codepoint falls inside one of the CJK ranges below.
 *
 * A "chinese character" is any codepoint inside one of the inclusive ranges
 * listed in the method body.
 *
 * @param int $cp The Unicode codepoint to check.
 *
 * @return bool True if the codepoint lies in one of the ranges, false otherwise.
 */
protected function isChineseChar(int $cp): bool
{
    // Inclusive [first, last] codepoint pairs.
    $ranges = [
        [0x4E00, 0x9FFF],
        [0x3400, 0x4DBF],
        [0x20000, 0x2A6DF],
        [0x2A700, 0x2B73F],
        [0x2B740, 0x2B81F],
        [0x2B820, 0x2CEAF],
        [0xF900, 0xFAFF],
        [0x2F800, 0x2FA1F],
    ];

    foreach ($ranges as [$first, $last]) {
        if ($first <= $cp && $cp <= $last) {
            return true;
        }
    }

    return false;
}
108+
109+
/**
 * Adds whitespace around any CJK (Chinese, Japanese, or Korean) character in the input text.
 *
 * Every character for which {@see isChineseChar()} returns true is emitted with
 * one space before it and one after it. All other characters are copied through
 * unchanged.
 *
 * @param string $text The input text to tokenize.
 *
 * @return string The text with a space on each side of every CJK character.
 */
public function tokenizeChineseChars(string $text): string
{
    $output = [];

    // Split once up front instead of calling mb_strlen()/mb_substr() on every
    // iteration.
    foreach (mb_str_split($text) as $char) {
        $cp = mb_ord($char);

        // mb_ord() returns false for a sequence it cannot decode. Such a
        // character is never CJK, and handing false to isChineseChar(int)
        // would raise a TypeError under strict_types.
        $isCjk = is_int($cp) && $this->isChineseChar($cp);

        $output[] = $isCjk ? " {$char} " : $char;
    }

    return implode('', $output);
}
68-
69-
public function normalize(string $text): string
70-
{
71-
if ($this->config['clean_text'] ?? false) {
72-
$text = $this->cleanText($text);
73-
}
74-
75-
if ($this->config['handle_chinese_chars'] ?? false) {
76-
$text = TokenizerModel::tokenizeChineseChars($text);
77-
}
78-
79-
if ($this->config['lowercase'] ?? false) {
80-
$text = mb_strtolower($text);
81-
}
82-
83-
if ($this->config['strip_accents'] ?? false) {
84-
$text = $this->stripAccents($text);
85-
}
86-
87-
return $text;
88-
}
89132
}

src/Normalizers/StripAccents.php

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,20 @@
22

33
declare(strict_types=1);
44

5-
65
namespace Codewithkyrian\Transformers\Normalizers;
76

8-
use Codewithkyrian\Transformers\Tokenizers\TokenizerModel;
9-
107
/**
 * StripAccents normalizer removes all accents from the text.
 */
class StripAccents extends Normalizer
{
    /**
     * Removes combining diacritical marks (U+0300 to U+036F) from the text.
     *
     * Only standalone combining marks are removed. The text is not decomposed
     * first, so precomposed accented characters are left as they are.
     *
     * @param string $text The text to remove accents from.
     *
     * @return string The text with accents removed, or the input unchanged when
     *                the regex engine cannot process it (e.g. malformed UTF-8).
     */
    public function normalize(string $text): string
    {
        // preg_replace() returns null on failure (for example malformed UTF-8
        // combined with the `u` modifier). Returning that null would violate
        // the declared string return type, so fall back to the untouched input.
        return preg_replace('/[\x{0300}-\x{036f}]/u', '', $text) ?? $text;
    }
}

src/PreTrainedTokenizers/PreTrainedTokenizer.php

Lines changed: 43 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,10 @@ public function __construct(protected array $tokenizerJSON, protected ?array $to
103103
$addedTokensPatterns = array_map(function ($x) {
104104
$lstrip = $x->lStrip ? '\s*' : '';
105105
$rstrip = $x->rStrip ? '\s*' : '';
106-
return $lstrip.'('.preg_quote($x->content, '/').')'.$rstrip;
106+
return $lstrip . '(' . preg_quote($x->content, '/') . ')' . $rstrip;
107107
}, $this->addedTokens);
108108

109-
$this->addedTokensRegex = '/'.implode('|', $addedTokensPatterns).'/';
109+
$this->addedTokensRegex = '/' . implode('|', $addedTokensPatterns) . '/';
110110
}
111111

112112
// Set mask token if present
@@ -160,7 +160,7 @@ protected function getToken(string ...$keys): ?string
160160
if ($item['__type'] == 'AddedToken') {
161161
return $item['content'];
162162
} else {
163-
throw new Exception("Unknown token: ".json_encode($item));
163+
throw new Exception("Unknown token: " . json_encode($item));
164164
}
165165
} else {
166166
return $item;
@@ -184,9 +184,8 @@ public static function fromPretrained(
184184
string $modelNameOrPath,
185185
?string $cacheDir = null,
186186
string $revision = 'main',
187-
$legacy = null,
188-
): PreTrainedTokenizer
189-
{
187+
$legacy = null,
188+
): PreTrainedTokenizer {
190189
['tokenizerJson' => $tokenizerJson, 'tokenizerConfig' => $tokenizerConfig] =
191190
TokenizerModel::load($modelNameOrPath, $cacheDir, $revision, $legacy);
192191

@@ -213,8 +212,7 @@ public function tokenize(
213212
bool $truncation = false,
214213
?int $maxLength = null,
215214
bool $returnTensor = true
216-
): array
217-
{
215+
): array {
218216
return $this->__invoke($text, $textPair, $padding, $addSpecialTokens, $truncation, $maxLength, $returnTensor);
219217
}
220218

@@ -239,8 +237,7 @@ public function __invoke(
239237
bool $truncation = false,
240238
?int $maxLength = null,
241239
bool $returnTensor = true
242-
): array
243-
{
240+
): array {
244241
$isBatched = is_array($text);
245242

246243
$encodedTokens = [];
@@ -258,13 +255,13 @@ public function __invoke(
258255
}
259256

260257
$encodedTokens = array_map(
261-
fn ($t, $i) => $this->encodePlus($t, $textPair[$i], $addSpecialTokens),
258+
fn($t, $i) => $this->encodePlus($t, $textPair[$i], $addSpecialTokens),
262259
$text,
263260
array_keys($text)
264261
);
265262
} else {
266263
$encodedTokens = array_map(
267-
fn ($x) => $this->encodePlus($x, addSpecialTokens: $addSpecialTokens),
264+
fn($x) => $this->encodePlus($x, addSpecialTokens: $addSpecialTokens),
268265
$text
269266
);
270267
}
@@ -285,7 +282,7 @@ public function __invoke(
285282
$maxLength = $this->modelMaxLength;
286283
} else {
287284
// Calculate max length from sequences
288-
$maxLength = max(array_map(fn ($x) => count($x['input_ids']), $encodedTokens));
285+
$maxLength = max(array_map(fn($x) => count($x['input_ids']), $encodedTokens));
289286
}
290287
} else {
291288
if (!$truncation) {
@@ -314,7 +311,7 @@ public function __invoke(
314311
$this->padHelper(
315312
$token,
316313
$maxLength,
317-
fn ($key) => $key === 'input_ids' ? $this->padTokenId : 0,
314+
fn($key) => $key === 'input_ids' ? $this->padTokenId : 0,
318315
$this->paddingSide
319316
);
320317
}
@@ -353,15 +350,15 @@ public function __invoke(
353350
continue;
354351
}
355352

356-
$array = array_map(fn ($x) => $x[$key], $encodedTokens);
353+
$array = array_map(fn($x) => $x[$key], $encodedTokens);
357354

358355
$result[$key] = new Tensor($array, Tensor::int64, $shape);
359356
}
360357
} else {
361358
$result = [];
362359

363360
foreach ($encodedTokens[0] as $key => $value) {
364-
$result[$key] = array_map(fn ($x) => $x[$key], $encodedTokens);
361+
$result[$key] = array_map(fn($x) => $x[$key], $encodedTokens);
365362
}
366363

367364
// If not returning a tensor, we match the input type
@@ -388,8 +385,7 @@ public function encodePlus(
388385
string|null $text,
389386
string|null $textPair = null,
390387
bool $addSpecialTokens = true
391-
): array
392-
{
388+
): array {
393389
// Function called by users to encode possibly multiple texts
394390
$tokens = $this->encodeText($text);
395391

@@ -443,9 +439,8 @@ protected function encodeText(?string $text): ?array
443439
$x = preg_replace('/\s+/', ' ', trim($x));
444440
}
445441

446-
447442
if ($this->doLowerCaseAndRemoveAccent) {
448-
$x = TokenizerModel::lowerCaseAndRemoveAccents($x);
443+
$x = $this->lowerCaseAndRemoveAccents($x);
449444
}
450445

451446
if ($this->normalizer !== null) {
@@ -522,7 +517,7 @@ protected function padHelper(array &$item, int $length, Closure $value_fn, strin
522517
*
523518
* @return array
524519
*/
525-
public function encode(string $text, string $textPair = null, bool $addSpecialTokens = true): array
520+
public function encode(string $text, ?string $textPair = null, bool $addSpecialTokens = true): array
{
    // Run the full encoding, then keep only the token ids from the result.
    $encoded = $this->encodePlus($text, $textPair, $addSpecialTokens);

    return $encoded['input_ids'];
}
@@ -539,7 +534,7 @@ public function encode(string $text, string $textPair = null, bool $addSpecialTo
539534
public function batchDecode(array|Tensor $batch, bool $skipSpecialTokens = false, ?bool $cleanUpTokenizationSpaces = null): array
{
    // Accept either a Tensor or a plain array; work on the array form.
    $sequences = $batch instanceof Tensor ? $batch->toArray() : $batch;

    // Decode each entry of the batch with the same options.
    $decodeOne = fn($tokenIds) => $this->decode($tokenIds, $skipSpecialTokens, $cleanUpTokenizationSpaces);

    return array_map($decodeOne, $sequences);
}
544539

545540
/**
@@ -574,7 +569,7 @@ private function decodeSingle(array $tokenIds, bool $skipSpecialTokens = false,
574569
$tokens = $this->model->convertIdsToTokens($tokenIds);
575570

576571
if ($skipSpecialTokens) {
577-
$tokens = array_values(array_filter($tokens, fn ($x) => !in_array($x, $this->specialTokens)));
572+
$tokens = array_values(array_filter($tokens, fn($x) => !in_array($x, $this->specialTokens)));
578573
}
579574

580575
// If `this.decoder` is null, we just join tokens with a space:
@@ -592,7 +587,6 @@ private function decodeSingle(array $tokenIds, bool $skipSpecialTokens = false,
592587
}
593588
}
594589

595-
596590
if ($cleanUpTokenizationSpaces ?? $this->cleanUpTokenizationSpaces) {
597591
$decoded = TokenizerModel::cleanUpTokenization($decoded);
598592
}
@@ -644,8 +638,7 @@ public function applyChatTemplate(
644638
bool $truncation = false,
645639
?int $maxLength = null,
646640
bool $returnTensor = true
647-
): string|array
648-
{
641+
): string|array {
649642
$chatTemplate ??= $this->chatTemplate ?? $this->getDefaultChatTemplate();
650643

651644
// Compilation function uses a cache to avoid recompiling the same template
@@ -693,4 +686,28 @@ protected function getDefaultChatTemplate(): string
693686

694687
return $this->defaultChatTemplate;
695688
}
689+
690+
/**
 * Helper function to lowercase a string and remove accents.
 *
 * The text is lowercased first and stripped afterwards. Lowercasing can itself
 * emit a combining mark (U+0130 becomes "i" followed by U+0307 under
 * mb_strtolower()'s full case mapping), and that mark would survive if the
 * accents were removed before lowercasing.
 *
 * @param string $text The text to lowercase and remove accents from.
 *
 * @return string The lowercased text with accents removed.
 */
protected function lowerCaseAndRemoveAccents(string $text): string
{
    return $this->removeAccents(mb_strtolower($text));
}
701+
702+
/**
 * Helper function to remove accents from a string.
 *
 * Removes combining diacritical marks (U+0300 to U+036F). The text is not
 * decomposed first, so precomposed accented characters are left as they are.
 *
 * @param string $text The text to remove accents from.
 *
 * @return string The text with accents removed, or the input unchanged when
 *                the regex engine cannot process it (e.g. malformed UTF-8).
 */
protected function removeAccents(string $text): string
{
    // preg_replace() returns null on failure (for example malformed UTF-8
    // combined with the `u` modifier). Returning that null would violate the
    // declared string return type, so fall back to the untouched input.
    return preg_replace('/[\x{0300}-\x{036f}]/u', '', $text) ?? $text;
}
696713
}

0 commit comments

Comments
 (0)