Skip to content

Commit eb66d73

Browse files
authored
Avoid LruCache in Tiktoken when cacheSize specified is 0 (#7016)
1 parent f976424 commit eb66d73

File tree

1 file changed

+17
-10
lines changed

1 file changed

+17
-10
lines changed

src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs

Lines changed: 17 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@ public sealed class Tiktoken : Model
2020
{
2121
private readonly Dictionary<ReadOnlyMemory<byte>, int> _encoder = null!;
2222
private readonly IReadOnlyDictionary<int, byte[]> _decoder = null!;
23-
private readonly LruCache<string, int[]> _cache;
23+
private readonly LruCache<string, int[]>? _cache;
2424
private readonly IReadOnlyDictionary<string, int>? _specialTokensEncoder;
2525
private readonly Dictionary<int, string>? _specialTokensDecoder;
2626
private readonly Dictionary<string, int> _vocab = null!;
@@ -96,7 +96,14 @@ private Tiktoken(Stream tikTokenBpeFileStream, IReadOnlyDictionary<string, int>?
9696

9797
private Tiktoken(int cacheSize)
9898
{
99-
_cache = new LruCache<string, int[]>(cacheSize);
99+
if (cacheSize < 0)
100+
{
101+
throw new ArgumentOutOfRangeException(nameof(cacheSize));
102+
}
103+
else if (cacheSize > 0)
104+
{
105+
_cache = new LruCache<string, int[]>(cacheSize);
106+
}
100107
}
101108

102109
/// <summary>
@@ -198,7 +205,7 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
198205
throw new InvalidOperationException($"The special token {sequence} doesn't exist in the tokenizer");
199206
}
200207

201-
if (_cache.Lookup(sequence, out int[] ids))
208+
if (_cache?.Lookup(sequence, out int[] ids) is true)
202209
{
203210
tokens = new Token[ids.Length];
204211
tokens[0] = new Token(ids[0], sequence, (0, sequence.Length));
@@ -222,7 +229,7 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
222229

223230
int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
224231
Debug.Assert(encodedIds.Length > 0);
225-
_cache.Add(sequence, encodedIds);
232+
_cache?.Add(sequence, encodedIds);
226233

227234
tokens = new Token[encodedIds.Length];
228235
tokens[0] = new Token(encodedIds[0], sequence, (0, sequence.Length));
@@ -259,7 +266,7 @@ public override void TokenizeToIds(string sequence, bool isSpecialToken, IList<i
259266
return;
260267
}
261268

262-
if (_cache.Lookup(sequence, out int[] tokenIds))
269+
if (_cache?.Lookup(sequence, out int[] tokenIds) is true)
263270
{
264271
accumulatedIds.AddRange(tokenIds);
265272
return;
@@ -275,7 +282,7 @@ public override void TokenizeToIds(string sequence, bool isSpecialToken, IList<i
275282
int encodedLength = GetUtf8Bytes(sequence.AsSpan(), arrayPoolArray);
276283

277284
int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
278-
_cache.Add(sequence, encodedIds);
285+
_cache?.Add(sequence, encodedIds);
279286

280287
accumulatedIds.AddRange(encodedIds);
281288

@@ -301,7 +308,7 @@ public override int CountTokens(string sequence, bool isSpecialToken)
301308
return _specialTokensEncoder.TryGetValue(sequence, out _) ? 1 : 0;
302309
}
303310

304-
if (_cache.Lookup(sequence, out int[] ids))
311+
if (_cache?.Lookup(sequence, out int[] ids) is true)
305312
{
306313
return ids.Length;
307314
}
@@ -315,7 +322,7 @@ public override int CountTokens(string sequence, bool isSpecialToken)
315322
int encodedLength = GetUtf8Bytes(sequence.AsSpan(), arrayPoolArray);
316323

317324
int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
318-
_cache.Add(sequence, encodedIds);
325+
_cache?.Add(sequence, encodedIds);
319326

320327
ArrayPool<byte>.Shared.Return(arrayPoolArray);
321328
return encodedIds.Length;
@@ -346,7 +353,7 @@ public override int CountTokens(string sequence, bool isSpecialToken)
346353
return specialTokenId;
347354
}
348355

349-
if (_cache.Lookup(token, out int[] ids))
356+
if (_cache?.Lookup(token, out int[] ids) is true)
350357
{
351358
if (ids.Length == 1)
352359
{
@@ -367,7 +374,7 @@ public override int CountTokens(string sequence, bool isSpecialToken)
367374
int encodedLength = GetUtf8Bytes(token.AsSpan(), arrayPoolArray);
368375

369376
int[] idsToCache = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
370-
_cache.Add(token, idsToCache);
377+
_cache?.Add(token, idsToCache);
371378

372379
if (idsToCache.Length == 1)
373380
{

0 commit comments

Comments
 (0)