Skip to content

Commit b115c28

Browse files
Cache tokenizer output to improve speed of tasks like Zero Shot Classification
1 parent 05e5588 commit b115c28

File tree

2 files changed

+29
-5
lines changed

2 files changed

+29
-5
lines changed

examples/pipelines/zero-shot-classification.php

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,27 @@
44

55
use function Codewithkyrian\Transformers\Pipelines\pipeline;
66
use function Codewithkyrian\Transformers\Utils\memoryUsage;
7+
use function Codewithkyrian\Transformers\Utils\timeUsage;
78

89
require_once './bootstrap.php';
910

1011

1112
//$classifier = pipeline('zero-shot-classification', 'Xenova/mobilebert-uncased-mnli');
1213
//$result = $classifier('Who are you voting for in 2020?', ['politics', 'public health', 'economics', 'elections']);
1314

14-
ini_set('memory_limit', '160M');
15+
ini_set('memory_limit', -1);
1516
$classifier = pipeline('zero-shot-classification', 'Xenova/nli-deberta-v3-xsmall');
17+
18+
$input = "The tension was thick as fog in the arena tonight as the underdogs, the Nets, clawed their way back from a significant deficit to steal a victory from the heavily favored The BUlls in a final score of 120 - Nets to 80 - Bulls
19+
20+
The game was a nail-biter from the start. The Bulls jumped out to an early lead, showcasing their signature fast-paced offense. Net's defense struggled to contain their star player, Frank, who racked up points in the first half.
21+
22+
However, just before halftime, the tide began to turn. The NEts's forward - James hit a series of clutch three-pointers, igniting a spark in the home crowd. The team rallied behind his energy, tightening up their defense and chipping away at the lead.
23+
24+
The second half was a back-and-forth affair, with neither team able to establish a clear advantage. Both sides traded baskets, steals, and blocks, keeping the fans on the edge of their seats. With seconds remaining on the clock, the score was tied.";
1625
$result = $classifier(
17-
'I have a problem with my iphone that needs to be resolved asap!',
18-
['urgent', 'not urgent', 'phone', 'tablet', 'computer'],
26+
$input,
27+
['politics', 'public health', 'economics', 'elections', 'sports', 'entertainment', 'technology', 'business', 'finance', 'education', 'science', 'religion', 'history', 'culture', 'environment', 'weather'],
1928
multiLabel: true
2029
);
2130

@@ -29,5 +38,6 @@
2938
//
3039
//$result = $classifier('Apple just announced the newest iPhone 13', ["technology", "sports", "politics"]);
3140

32-
dd(memoryUsage(), $result);
41+
dd( $result, timeUsage(), memoryUsage());
3342

43+
// Improved from 11.7687s to 2.9687s, 3.5x faster (75% improvement)

src/PretrainedTokenizers/PretrainedTokenizer.php

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
use Codewithkyrian\Transformers\Tokenizers\AddedToken;
1515
use Codewithkyrian\Transformers\Tokenizers\Tokenizer;
1616
use Codewithkyrian\Transformers\Utils\Tensor;
17+
use function Codewithkyrian\Transformers\Utils\timeUsage;
1718

1819
class PretrainedTokenizer
1920
{
@@ -70,6 +71,7 @@ class PretrainedTokenizer
7071

7172
protected mixed $chatTemplate;
7273
protected array $compiledTemplateCache = [];
74+
protected array $tokenizationCache = [];
7375

7476
/**
7577
* @param array $tokenizerJSON The JSON of the tokenizer.
@@ -404,6 +406,13 @@ protected function encodeText(?string $text): ?array
404406
return null;
405407
}
406408

409+
// Hash the text and check if it is in the cache
410+
$hash = hash('sha256', $text);
411+
412+
if (isset($this->tokenizationCache[$hash])) {
413+
return $this->tokenizationCache[$hash];
414+
}
415+
407416
// Actual function which does encoding, for a single text
408417
// First, we take care of special tokens. Needed to avoid issues arising from
409418
// normalization and/or pretokenization (which may not preserve special tokens)
@@ -442,7 +451,12 @@ protected function encodeText(?string $text): ?array
442451
}
443452
}, $sections, array_keys($sections));
444453

445-
return array_merge(...$tokens);
454+
$result = array_merge(...$tokens);
455+
456+
// Cache the result
457+
$this->tokenizationCache[$hash] = $result;
458+
459+
return $result;
446460
}
447461

448462

0 commit comments

Comments
 (0)