9
9
using System . IO ;
10
10
using System . Net . Http ;
11
11
using System . Text . RegularExpressions ;
12
+ using System . Threading ;
12
13
using System . Threading . Tasks ;
13
14
14
15
namespace Microsoft . ML . Tokenizers
@@ -346,32 +347,41 @@ private static readonly (string Prefix, ModelEncoding Encoding)[] _modelPrefixTo
346
347
/// <param name="modelName">Model name</param>
347
348
/// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the model</param>
348
349
/// <param name="normalizer">To normalize the text before tokenization</param>
350
+ /// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
349
351
/// <returns>The tokenizer</returns>
350
- public static async Task < Tokenizer > CreateByModelNameAsync (
352
+ public static Task < Tokenizer > CreateByModelNameAsync (
351
353
string modelName ,
352
354
IReadOnlyDictionary < string , int > ? extraSpecialTokens = null ,
353
- Normalizer ? normalizer = null )
355
+ Normalizer ? normalizer = null ,
356
+ CancellationToken cancellationToken = default )
354
357
{
355
- ModelEncoding encoder ;
356
-
357
- if ( ! _modelToEncoding . TryGetValue ( modelName , out encoder ) )
358
+ try
358
359
{
359
- foreach ( ( string Prefix , ModelEncoding Encoding ) in _modelPrefixToEncoding )
360
+ ModelEncoding encoder ;
361
+
362
+ if ( ! _modelToEncoding . TryGetValue ( modelName , out encoder ) )
360
363
{
361
- if ( modelName . StartsWith ( Prefix , StringComparison . OrdinalIgnoreCase ) )
364
+ foreach ( ( string Prefix , ModelEncoding Encoding ) in _modelPrefixToEncoding )
362
365
{
363
- encoder = Encoding ;
364
- break ;
366
+ if ( modelName . StartsWith ( Prefix , StringComparison . OrdinalIgnoreCase ) )
367
+ {
368
+ encoder = Encoding ;
369
+ break ;
370
+ }
365
371
}
366
372
}
367
- }
368
373
369
- if ( encoder == ModelEncoding . None )
374
+ if ( encoder == ModelEncoding . None )
375
+ {
376
+ throw new NotImplementedException ( $ "Doesn't support this model [{ modelName } ]") ;
377
+ }
378
+
379
+ return CreateByEncoderNameAsync ( encoder , extraSpecialTokens , normalizer , cancellationToken ) ;
380
+ }
381
+ catch ( Exception ex )
370
382
{
371
- throw new NotImplementedException ( $ "Doesn't support this model [ { modelName } ]" ) ;
383
+ return Task . FromException < Tokenizer > ( ex ) ;
372
384
}
373
-
374
- return await CreateByEncoderNameAsync ( encoder , extraSpecialTokens , normalizer ) . ConfigureAwait ( false ) ;
375
385
}
376
386
377
387
private const string Cl100kBaseRegexPattern = @"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" ;
@@ -402,36 +412,38 @@ public static async Task<Tokenizer> CreateByModelNameAsync(
402
412
/// <param name="modelEncoding">Encoder label</param>
403
413
/// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the encoder</param>
404
414
/// <param name="normalizer">To normalize the text before tokenization</param>
415
+ /// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
405
416
/// <returns>The tokenizer</returns>
406
417
/// <exception cref="NotImplementedException">Throws if the encoder is not supported</exception>
407
- private static async Task < Tokenizer > CreateByEncoderNameAsync (
418
+ private static Task < Tokenizer > CreateByEncoderNameAsync (
408
419
ModelEncoding modelEncoding ,
409
420
IReadOnlyDictionary < string , int > ? extraSpecialTokens ,
410
- Normalizer ? normalizer )
421
+ Normalizer ? normalizer ,
422
+ CancellationToken cancellationToken )
411
423
{
412
424
switch ( modelEncoding )
413
425
{
414
426
case ModelEncoding . Cl100kBase :
415
427
var specialTokens = new Dictionary < string , int >
416
428
{ { EndOfText , 100257 } , { FimPrefix , 100258 } , { FimMiddle , 100259 } , { FimSuffix , 100260 } , { EndOfPrompt , 100276 } } ;
417
- return await CreateTikTokenTokenizerAsync ( Cl100kBaseRegex ( ) , Cl100kBaseVocabUrl , specialTokens , extraSpecialTokens , normalizer ) . ConfigureAwait ( false ) ;
429
+ return CreateTikTokenTokenizerAsync ( Cl100kBaseRegex ( ) , Cl100kBaseVocabUrl , specialTokens , extraSpecialTokens , normalizer , cancellationToken ) ;
418
430
419
431
case ModelEncoding . P50kBase :
420
432
specialTokens = new Dictionary < string , int > { { EndOfText , 50256 } } ;
421
- return await CreateTikTokenTokenizerAsync ( P50kBaseRegex ( ) , P50RanksUrl , specialTokens , extraSpecialTokens , normalizer ) . ConfigureAwait ( false ) ;
433
+ return CreateTikTokenTokenizerAsync ( P50kBaseRegex ( ) , P50RanksUrl , specialTokens , extraSpecialTokens , normalizer , cancellationToken ) ;
422
434
423
435
case ModelEncoding . P50kEdit :
424
436
specialTokens = new Dictionary < string , int >
425
437
{ { EndOfText , 50256 } , { FimPrefix , 50281 } , { FimMiddle , 50282 } , { FimSuffix , 50283 } } ;
426
- return await CreateTikTokenTokenizerAsync ( P50kBaseRegex ( ) , P50RanksUrl , specialTokens , extraSpecialTokens , normalizer ) . ConfigureAwait ( false ) ;
438
+ return CreateTikTokenTokenizerAsync ( P50kBaseRegex ( ) , P50RanksUrl , specialTokens , extraSpecialTokens , normalizer , cancellationToken ) ;
427
439
428
440
case ModelEncoding . R50kBase :
429
441
specialTokens = new Dictionary < string , int > { { EndOfText , 50256 } } ;
430
- return await CreateTikTokenTokenizerAsync ( P50kBaseRegex ( ) , R50RanksUrl , specialTokens , extraSpecialTokens , normalizer ) . ConfigureAwait ( false ) ;
442
+ return CreateTikTokenTokenizerAsync ( P50kBaseRegex ( ) , R50RanksUrl , specialTokens , extraSpecialTokens , normalizer , cancellationToken ) ;
431
443
432
444
case ModelEncoding . GPT2 :
433
445
specialTokens = new Dictionary < string , int > { { EndOfText , 50256 } , } ;
434
- return await CreateTikTokenTokenizerAsync ( P50kBaseRegex ( ) , GPT2Url , specialTokens , extraSpecialTokens , normalizer ) . ConfigureAwait ( false ) ;
446
+ return CreateTikTokenTokenizerAsync ( P50kBaseRegex ( ) , GPT2Url , specialTokens , extraSpecialTokens , normalizer , cancellationToken ) ;
435
447
436
448
default :
437
449
Debug . Assert ( false , $ "Unexpected encoder [{ modelEncoding } ]") ;
@@ -449,13 +461,15 @@ private static async Task<Tokenizer> CreateByEncoderNameAsync(
449
461
/// <param name="specialTokens">Special tokens mapping. This may be mutated by the method.</param>
450
462
/// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the encoder</param>
451
463
/// <param name="normalizer">To normalize the text before tokenization</param>
464
+ /// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
452
465
/// <returns>The tokenizer</returns>
453
466
private static async Task < Tokenizer > CreateTikTokenTokenizerAsync (
454
467
Regex regex ,
455
468
string mergeableRanksFileUrl ,
456
469
Dictionary < string , int > specialTokens ,
457
470
IReadOnlyDictionary < string , int > ? extraSpecialTokens ,
458
- Normalizer ? normalizer )
471
+ Normalizer ? normalizer ,
472
+ CancellationToken cancellationToken )
459
473
{
460
474
if ( extraSpecialTokens is not null )
461
475
{
@@ -467,9 +481,9 @@ private static async Task<Tokenizer> CreateTikTokenTokenizerAsync(
467
481
468
482
if ( ! _tiktokenCache . TryGetValue ( mergeableRanksFileUrl , out ( Dictionary < ReadOnlyMemory < byte > , int > encoder , Dictionary < string , int > vocab , IReadOnlyDictionary < int , byte [ ] > decoder ) cache ) )
469
483
{
470
- using ( Stream stream = await _httpClient . GetStreamAsync ( mergeableRanksFileUrl ) . ConfigureAwait ( false ) )
484
+ using ( Stream stream = await Helpers . GetStreamAsync ( _httpClient , mergeableRanksFileUrl , cancellationToken ) . ConfigureAwait ( false ) )
471
485
{
472
- cache = await Tiktoken . LoadTikTokenBpeAsync ( stream , useAsync : true ) . ConfigureAwait ( false ) ;
486
+ cache = await Tiktoken . LoadTikTokenBpeAsync ( stream , useAsync : true , cancellationToken ) . ConfigureAwait ( false ) ;
473
487
}
474
488
475
489
_tiktokenCache . TryAdd ( mergeableRanksFileUrl , cache ) ;
0 commit comments