@@ -20,7 +20,7 @@ public sealed class Tiktoken : Model
20
20
{
21
21
private readonly Dictionary < ReadOnlyMemory < byte > , int > _encoder = null ! ;
22
22
private readonly IReadOnlyDictionary < int , byte [ ] > _decoder = null ! ;
23
- private readonly LruCache < string , int [ ] > _cache ;
23
+ private readonly LruCache < string , int [ ] > ? _cache ;
24
24
private readonly IReadOnlyDictionary < string , int > ? _specialTokensEncoder ;
25
25
private readonly Dictionary < int , string > ? _specialTokensDecoder ;
26
26
private readonly Dictionary < string , int > _vocab = null ! ;
@@ -96,7 +96,14 @@ private Tiktoken(Stream tikTokenBpeFileStream, IReadOnlyDictionary<string, int>?
96
96
97
97
private Tiktoken ( int cacheSize )
98
98
{
99
- _cache = new LruCache < string , int [ ] > ( cacheSize ) ;
99
+ if ( cacheSize < 0 )
100
+ {
101
+ throw new ArgumentOutOfRangeException ( nameof ( cacheSize ) ) ;
102
+ }
103
+ else if ( cacheSize > 0 )
104
+ {
105
+ _cache = new LruCache < string , int [ ] > ( cacheSize ) ;
106
+ }
100
107
}
101
108
102
109
/// <summary>
@@ -198,7 +205,7 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
198
205
throw new InvalidOperationException ( $ "The special token { sequence } doesn't exist in the tokenizer") ;
199
206
}
200
207
201
- if ( _cache . Lookup ( sequence , out int [ ] ids ) )
208
+ if ( _cache ? . Lookup ( sequence , out int [ ] ids ) is true )
202
209
{
203
210
tokens = new Token [ ids . Length ] ;
204
211
tokens [ 0 ] = new Token ( ids [ 0 ] , sequence , ( 0 , sequence . Length ) ) ;
@@ -222,7 +229,7 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
222
229
223
230
int [ ] encodedIds = BytePairEncoder . BytePairEncode ( arrayPoolArray . AsMemory ( 0 , encodedLength ) , _encoder ) ;
224
231
Debug . Assert ( encodedIds . Length > 0 ) ;
225
- _cache . Add ( sequence , encodedIds ) ;
232
+ _cache ? . Add ( sequence , encodedIds ) ;
226
233
227
234
tokens = new Token [ encodedIds . Length ] ;
228
235
tokens [ 0 ] = new Token ( encodedIds [ 0 ] , sequence , ( 0 , sequence . Length ) ) ;
@@ -259,7 +266,7 @@ public override void TokenizeToIds(string sequence, bool isSpecialToken, IList<i
259
266
return ;
260
267
}
261
268
262
- if ( _cache . Lookup ( sequence , out int [ ] tokenIds ) )
269
+ if ( _cache ? . Lookup ( sequence , out int [ ] tokenIds ) is true )
263
270
{
264
271
accumulatedIds . AddRange ( tokenIds ) ;
265
272
return ;
@@ -275,7 +282,7 @@ public override void TokenizeToIds(string sequence, bool isSpecialToken, IList<i
275
282
int encodedLength = GetUtf8Bytes ( sequence . AsSpan ( ) , arrayPoolArray ) ;
276
283
277
284
int [ ] encodedIds = BytePairEncoder . BytePairEncode ( arrayPoolArray . AsMemory ( 0 , encodedLength ) , _encoder ) ;
278
- _cache . Add ( sequence , encodedIds ) ;
285
+ _cache ? . Add ( sequence , encodedIds ) ;
279
286
280
287
accumulatedIds . AddRange ( encodedIds ) ;
281
288
@@ -301,7 +308,7 @@ public override int CountTokens(string sequence, bool isSpecialToken)
301
308
return _specialTokensEncoder . TryGetValue ( sequence , out _ ) ? 1 : 0 ;
302
309
}
303
310
304
- if ( _cache . Lookup ( sequence , out int [ ] ids ) )
311
+ if ( _cache ? . Lookup ( sequence , out int [ ] ids ) is true )
305
312
{
306
313
return ids . Length ;
307
314
}
@@ -315,7 +322,7 @@ public override int CountTokens(string sequence, bool isSpecialToken)
315
322
int encodedLength = GetUtf8Bytes ( sequence . AsSpan ( ) , arrayPoolArray ) ;
316
323
317
324
int [ ] encodedIds = BytePairEncoder . BytePairEncode ( arrayPoolArray . AsMemory ( 0 , encodedLength ) , _encoder ) ;
318
- _cache . Add ( sequence , encodedIds ) ;
325
+ _cache ? . Add ( sequence , encodedIds ) ;
319
326
320
327
ArrayPool < byte > . Shared . Return ( arrayPoolArray ) ;
321
328
return encodedIds . Length ;
@@ -346,7 +353,7 @@ public override int CountTokens(string sequence, bool isSpecialToken)
346
353
return specialTokenId ;
347
354
}
348
355
349
- if ( _cache . Lookup ( token , out int [ ] ids ) )
356
+ if ( _cache ? . Lookup ( token , out int [ ] ids ) is true )
350
357
{
351
358
if ( ids . Length == 1 )
352
359
{
@@ -367,7 +374,7 @@ public override int CountTokens(string sequence, bool isSpecialToken)
367
374
int encodedLength = GetUtf8Bytes ( token . AsSpan ( ) , arrayPoolArray ) ;
368
375
369
376
int [ ] idsToCache = BytePairEncoder . BytePairEncode ( arrayPoolArray . AsMemory ( 0 , encodedLength ) , _encoder ) ;
370
- _cache . Add ( token , idsToCache ) ;
377
+ _cache ? . Add ( token , idsToCache ) ;
371
378
372
379
if ( idsToCache . Length == 1 )
373
380
{
0 commit comments