Skip to content

Commit 2c9f775

Browse files
authored
Tweak CreateByModelNameAsync (#7015)
- Add a CancellationToken to CreateByModelNameAsync, allowing the download and parsing to be canceled. - Use ReadLineAsync(cancellationToken), which not only allows it to be canceled, but avoids ~100K task allocations - Fix Helpers.FromBase64String to support lines longer than 300 chars
1 parent 3282f44 commit 2c9f775

File tree

4 files changed

+85
-36
lines changed

4 files changed

+85
-36
lines changed

src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
using System.IO;
1010
using System.Linq;
1111
using System.Text;
12+
using System.Threading;
1213
using System.Threading.Tasks;
1314

1415
namespace Microsoft.ML.Tokenizers
@@ -111,9 +112,11 @@ private Tiktoken(int cacheSize)
111112
/// </summary>
112113
/// <param name="tikTokenBpeFileStream">Stream to the BPE rank file</param>
113114
/// <param name="useAsync">Whether to perform I/O synchronously or asynchronously.</param>
115+
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
114116
/// <returns>Map of byte[] to integer token id</returns>
115117
/// <exception cref="InvalidOperationException"></exception>
116-
internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<string, int>, IReadOnlyDictionary<int, byte[]>)> LoadTikTokenBpeAsync(Stream tikTokenBpeFileStream, bool useAsync)
118+
internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<string, int>, IReadOnlyDictionary<int, byte[]>)> LoadTikTokenBpeAsync(
119+
Stream tikTokenBpeFileStream, bool useAsync, CancellationToken cancellationToken = default)
117120
{
118121
var encoder = new Dictionary<ReadOnlyMemory<byte>, int>(ReadOnlyMemoryByteComparer.Instance);
119122
var vocab = new Dictionary<string, int>();
@@ -126,7 +129,7 @@ private Tiktoken(int cacheSize)
126129
while (true)
127130
{
128131
string? line = useAsync ?
129-
await reader.ReadLineAsync().ConfigureAwait(false) :
132+
await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) :
130133
reader.ReadLine();
131134
if (string.IsNullOrWhiteSpace(line))
132135
{
@@ -143,10 +146,10 @@ await reader.ReadLineAsync().ConfigureAwait(false) :
143146
throw new FormatException($"Invalid format in the BPE encoder file stream");
144147
}
145148

146-
byte[] tokenBytes = Helpers.FromBase64String(line, 0, spaceIndex);
147-
148149
if (Helpers.TryParseInt32(line, spaceIndex + 1, out int rank))
149150
{
151+
byte[] tokenBytes = Helpers.FromBase64String(line, 0, spaceIndex);
152+
150153
encoder[tokenBytes] = rank;
151154
decoder[rank] = tokenBytes;
152155

@@ -221,7 +224,7 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
221224
// cache miss
222225
if (_vocab.TryGetValue(sequence, out int mappedId))
223226
{
224-
return new List<Token> { new(mappedId, sequence, (0, sequence.Length)) };
227+
return new Token[1] { new(mappedId, sequence, (0, sequence.Length)) };
225228
}
226229

227230
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(sequence.Length));

src/Microsoft.ML.Tokenizers/Tokenizer.cs

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
using System.IO;
1010
using System.Net.Http;
1111
using System.Text.RegularExpressions;
12+
using System.Threading;
1213
using System.Threading.Tasks;
1314

1415
namespace Microsoft.ML.Tokenizers
@@ -346,32 +347,41 @@ private static readonly (string Prefix, ModelEncoding Encoding)[] _modelPrefixTo
346347
/// <param name="modelName">Model name</param>
347348
/// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the model</param>
348349
/// <param name="normalizer">To normalize the text before tokenization</param>
350+
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
349351
/// <returns>The tokenizer</returns>
350-
public static async Task<Tokenizer> CreateByModelNameAsync(
352+
public static Task<Tokenizer> CreateByModelNameAsync(
351353
string modelName,
352354
IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
353-
Normalizer? normalizer = null)
355+
Normalizer? normalizer = null,
356+
CancellationToken cancellationToken = default)
354357
{
355-
ModelEncoding encoder;
356-
357-
if (!_modelToEncoding.TryGetValue(modelName, out encoder))
358+
try
358359
{
359-
foreach ((string Prefix, ModelEncoding Encoding) in _modelPrefixToEncoding)
360+
ModelEncoding encoder;
361+
362+
if (!_modelToEncoding.TryGetValue(modelName, out encoder))
360363
{
361-
if (modelName.StartsWith(Prefix, StringComparison.OrdinalIgnoreCase))
364+
foreach ((string Prefix, ModelEncoding Encoding) in _modelPrefixToEncoding)
362365
{
363-
encoder = Encoding;
364-
break;
366+
if (modelName.StartsWith(Prefix, StringComparison.OrdinalIgnoreCase))
367+
{
368+
encoder = Encoding;
369+
break;
370+
}
365371
}
366372
}
367-
}
368373

369-
if (encoder == ModelEncoding.None)
374+
if (encoder == ModelEncoding.None)
375+
{
376+
throw new NotImplementedException($"Doesn't support this model [{modelName}]");
377+
}
378+
379+
return CreateByEncoderNameAsync(encoder, extraSpecialTokens, normalizer, cancellationToken);
380+
}
381+
catch (Exception ex)
370382
{
371-
throw new NotImplementedException($"Doesn't support this model [{modelName}]");
383+
return Task.FromException<Tokenizer>(ex);
372384
}
373-
374-
return await CreateByEncoderNameAsync(encoder, extraSpecialTokens, normalizer).ConfigureAwait(false);
375385
}
376386

377387
private const string Cl100kBaseRegexPattern = @"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";
@@ -402,36 +412,38 @@ public static async Task<Tokenizer> CreateByModelNameAsync(
402412
/// <param name="modelEncoding">Encoder label</param>
403413
/// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the encoder</param>
404414
/// <param name="normalizer">To normalize the text before tokenization</param>
415+
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
405416
/// <returns>The tokenizer</returns>
406417
/// <exception cref="NotImplementedException">Throws if the encoder is not supported</exception>
407-
private static async Task<Tokenizer> CreateByEncoderNameAsync(
418+
private static Task<Tokenizer> CreateByEncoderNameAsync(
408419
ModelEncoding modelEncoding,
409420
IReadOnlyDictionary<string, int>? extraSpecialTokens,
410-
Normalizer? normalizer)
421+
Normalizer? normalizer,
422+
CancellationToken cancellationToken)
411423
{
412424
switch (modelEncoding)
413425
{
414426
case ModelEncoding.Cl100kBase:
415427
var specialTokens = new Dictionary<string, int>
416428
{ { EndOfText, 100257}, { FimPrefix, 100258}, { FimMiddle, 100259}, { FimSuffix, 100260}, { EndOfPrompt, 100276} };
417-
return await CreateTikTokenTokenizerAsync(Cl100kBaseRegex(), Cl100kBaseVocabUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
429+
return CreateTikTokenTokenizerAsync(Cl100kBaseRegex(), Cl100kBaseVocabUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
418430

419431
case ModelEncoding.P50kBase:
420432
specialTokens = new Dictionary<string, int> { { EndOfText, 50256 } };
421-
return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
433+
return CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
422434

423435
case ModelEncoding.P50kEdit:
424436
specialTokens = new Dictionary<string, int>
425437
{ { EndOfText, 50256 }, { FimPrefix, 50281 }, { FimMiddle, 50282 }, { FimSuffix, 50283 } };
426-
return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
438+
return CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
427439

428440
case ModelEncoding.R50kBase:
429441
specialTokens = new Dictionary<string, int> { { EndOfText, 50256 } };
430-
return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), R50RanksUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
442+
return CreateTikTokenTokenizerAsync(P50kBaseRegex(), R50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
431443

432444
case ModelEncoding.GPT2:
433445
specialTokens = new Dictionary<string, int> { { EndOfText, 50256 }, };
434-
return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), GPT2Url, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
446+
return CreateTikTokenTokenizerAsync(P50kBaseRegex(), GPT2Url, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
435447

436448
default:
437449
Debug.Assert(false, $"Unexpected encoder [{modelEncoding}]");
@@ -449,13 +461,15 @@ private static async Task<Tokenizer> CreateByEncoderNameAsync(
449461
/// <param name="specialTokens">Special tokens mapping. This may be mutated by the method.</param>
450462
/// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the encoder</param>
451463
/// <param name="normalizer">To normalize the text before tokenization</param>
464+
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
452465
/// <returns>The tokenizer</returns>
453466
private static async Task<Tokenizer> CreateTikTokenTokenizerAsync(
454467
Regex regex,
455468
string mergeableRanksFileUrl,
456469
Dictionary<string, int> specialTokens,
457470
IReadOnlyDictionary<string, int>? extraSpecialTokens,
458-
Normalizer? normalizer)
471+
Normalizer? normalizer,
472+
CancellationToken cancellationToken)
459473
{
460474
if (extraSpecialTokens is not null)
461475
{
@@ -467,9 +481,9 @@ private static async Task<Tokenizer> CreateTikTokenTokenizerAsync(
467481

468482
if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<string, int> vocab, IReadOnlyDictionary<int, byte[]> decoder) cache))
469483
{
470-
using (Stream stream = await _httpClient.GetStreamAsync(mergeableRanksFileUrl).ConfigureAwait(false))
484+
using (Stream stream = await Helpers.GetStreamAsync(_httpClient, mergeableRanksFileUrl, cancellationToken).ConfigureAwait(false))
471485
{
472-
cache = await Tiktoken.LoadTikTokenBpeAsync(stream, useAsync: true).ConfigureAwait(false);
486+
cache = await Tiktoken.LoadTikTokenBpeAsync(stream, useAsync: true, cancellationToken).ConfigureAwait(false);
473487
}
474488

475489
_tiktokenCache.TryAdd(mergeableRanksFileUrl, cache);

src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,41 @@
1-
// Licensed to the .NET Foundation under one or more agreements.
1+
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Buffers.Text;
7+
using System.Diagnostics;
68
using System.Globalization;
9+
using System.IO;
10+
using System.Threading.Tasks;
11+
using System.Threading;
12+
using System.Net.Http;
713

814
namespace Microsoft.ML.Tokenizers
915
{
1016
internal static class Helpers
1117
{
18+
public static ValueTask<string?> ReadLineAsync(StreamReader reader, CancellationToken cancellationToken) =>
19+
reader.ReadLineAsync(cancellationToken);
20+
21+
public static Task<Stream> GetStreamAsync(HttpClient client, string url, CancellationToken cancellationToken) =>
22+
client.GetStreamAsync(url, cancellationToken);
23+
1224
public static byte[] FromBase64String(string base64String, int offset, int length)
1325
{
14-
Span<byte> bytes = stackalloc byte[300];
15-
if (!Convert.TryFromBase64Chars(base64String.AsSpan().Slice(offset, length), bytes, out int bytesWritten))
26+
if (!Base64.IsValid(base64String.AsSpan(offset, length), out int decodedLength))
1627
{
17-
throw new System.FormatException($"Invalid base64 string '{base64String.Substring(offset, length)}'");
28+
throw new FormatException($"Invalid base64 string '{base64String.Substring(offset, length)}'");
1829
}
19-
return bytes.Slice(0, bytesWritten).ToArray();
30+
31+
byte[] bytes = new byte[decodedLength];
32+
bool success = Convert.TryFromBase64Chars(base64String.AsSpan(offset, length), bytes, out int bytesWritten);
33+
Debug.Assert(success);
34+
Debug.Assert(bytes.Length == bytesWritten);
35+
return bytes;
2036
}
2137

2238
internal static bool TryParseInt32(string s, int offset, out int result)
2339
=> int.TryParse(s.AsSpan().Slice(offset), NumberStyles.None, CultureInfo.InvariantCulture, out result);
2440
}
2541
}
26-

src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,30 @@
1-
// Licensed to the .NET Foundation under one or more agreements.
1+
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.IO;
7+
using System.Net.Http;
8+
using System.Threading;
9+
using System.Threading.Tasks;
610

711
namespace Microsoft.ML.Tokenizers
812
{
913
internal static class Helpers
1014
{
15+
public static ValueTask<string> ReadLineAsync(StreamReader reader, CancellationToken cancellationToken)
16+
{
17+
cancellationToken.ThrowIfCancellationRequested();
18+
return new ValueTask<string>(reader.ReadLineAsync());
19+
}
20+
21+
public static async Task<Stream> GetStreamAsync(HttpClient client, string url, CancellationToken cancellationToken)
22+
{
23+
HttpResponseMessage response = await client.GetAsync(url, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false);
24+
response.EnsureSuccessStatusCode();
25+
return await response.Content.ReadAsStreamAsync().ConfigureAwait(false);
26+
}
27+
1128
public static byte[] FromBase64String(string base64String, int offset, int length) => Convert.FromBase64String(base64String.Substring(offset, length));
1229

1330
// Not support signed number

0 commit comments

Comments
 (0)