Commit 6f55525

Introducing Tiktoken Tokenizer (#6981)
* Introducing Tiktoken Tokenizer
* Address the feedback
* file renaming
1 parent 902102e commit 6f55525

29 files changed: +2,247 −66 lines

src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj

Lines changed: 9 additions & 1 deletion
@@ -2,11 +2,19 @@
   <Import Project="$(RepoRoot)eng/pkg/Pack.props" />

   <PropertyGroup>
-    <TargetFramework>netstandard2.0</TargetFramework>
+    <TargetFrameworks>netstandard2.0;net8.0</TargetFrameworks>
     <Nullable>enable</Nullable>
     <PackageDescription>Microsoft.ML.Tokenizers contains the implmentation of the tokenization used in the NLP transforms.</PackageDescription>
   </PropertyGroup>

+  <ItemGroup Condition="'$(TargetFramework)' == 'netstandard2.0'">
+    <Compile Remove="Utils/Helpers.netcoreapp.cs" />
+  </ItemGroup>
+
+  <ItemGroup Condition="'$(TargetFramework)' != 'netstandard2.0'">
+    <Compile Remove="Utils/Helpers.netstandard.cs" />
+  </ItemGroup>
+
   <ItemGroup>
     <PackageReference Include="System.Text.Json" Version="$(SystemTextJsonVersion)" />
   </ItemGroup>
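Note on this change: the project now multi-targets netstandard2.0 and net8.0, and the two conditional ItemGroups pick exactly one of the two helper source files per target framework, so newer framework-specific APIs can live in Utils/Helpers.netcoreapp.cs while Utils/Helpers.netstandard.cs supplies fallbacks. The contents of those files are not part of this hunk; the sketch below shows the usual partial-class pattern with a hypothetical EncodeToUtf8 member invented for illustration, written as two separate files of which only one is ever compiled for a given target framework.

using System;
using System.Text;

namespace Microsoft.ML.Tokenizers
{
    // Utils/Helpers.netstandard.cs (hypothetical content) -- compiled only when
    // TargetFramework == netstandard2.0, per the first conditional ItemGroup.
    internal static partial class Helpers
    {
        internal static int EncodeToUtf8(string text, byte[] destination)
        {
            // Allocating fallback: the span-based Encoding overloads are not available here.
            byte[] bytes = Encoding.UTF8.GetBytes(text);
            Array.Copy(bytes, destination, bytes.Length);
            return bytes.Length;
        }
    }
}

namespace Microsoft.ML.Tokenizers
{
    // Utils/Helpers.netcoreapp.cs (hypothetical content) -- compiled for net8.0,
    // where the second conditional ItemGroup removes the netstandard file instead.
    internal static partial class Helpers
    {
        internal static int EncodeToUtf8(string text, byte[] destination)
            // Span-based overload avoids the intermediate byte[] allocation.
            => Encoding.UTF8.GetBytes(text.AsSpan(), destination);
    }
}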

src/Microsoft.ML.Tokenizers/Model/BPE.cs

Lines changed: 5 additions & 5 deletions
@@ -36,7 +36,7 @@ public string? UnknownToken

                if (value is null)
                {
-                    if (VocabReverse.TryGetValue(0, out string v))
+                    if (VocabReverse.TryGetValue(0, out string? v))
                    {
                        VocabReverse.Remove(0);
                        if (Vocab.TryGetValue(v, out int id))
@@ -103,7 +103,7 @@ public Bpe(string vocabFile, string? mergesFile, string? unknownToken = null, st
                VocabReverse.Add(kvp.Value, kvp.Key);
            }

-            if (unknownToken is null && VocabReverse.TryGetValue(0, out string unkToken))
+            if (unknownToken is null && VocabReverse.TryGetValue(0, out string? unkToken))
            {
                unknownToken = unkToken;
            }
@@ -187,7 +187,7 @@ public override IReadOnlyList<Token> Tokenize(string sequence)
        /// <returns>The mapped token of the Id.</returns>
        public override string? IdToToken(int id, bool skipSpecialTokens = false)
        {
-            if (VocabReverse.TryGetValue(id, out string value))
+            if (VocabReverse.TryGetValue(id, out string? value))
            {
                return value;
            }
@@ -253,7 +253,7 @@ public override string[] Save(string path, string? prefix = null)
        }

        /// Read the given files to extract the vocab and merges
-        internal static (Dictionary<string, int>?, Vec<(string, string)>) ReadFile(string? vocab, string? merges)
+        internal static (Dictionary<string, int>?, Vec<(string, string)>) ReadFile(string vocab, string? merges)
        {
            Dictionary<string, int>? dic;
            using (Stream stream = File.OpenRead(vocab))
@@ -320,7 +320,7 @@ internal static (Dictionary<string, int>?, Vec<(string, string)>) ReadFile(strin
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        internal string CharToString(char c)
        {
-            if (_charToString.TryGetValue(c, out string v))
+            if (_charToString.TryGetValue(c, out string? v))
            {
                return v;
            }
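Note on this change: every hunk above is the same nullable-annotation fix. With <Nullable>enable</Nullable>, Dictionary<TKey, TValue>.TryGetValue marks its out parameter [MaybeNullWhen(false)], so declaring the local as a non-nullable string warns (CS8600); declaring it as string? and only using it on the true branch is warning-free. A standalone illustration, not code from the repository:

#nullable enable
using System.Collections.Generic;

internal static class NullableTryGetValueExample
{
    private static readonly Dictionary<int, string> VocabReverse = new() { [0] = "<unk>" };

    internal static string IdToToken(int id)
    {
        // "out string value" would warn: the out argument may be null when
        // TryGetValue returns false. "out string? value" matches the annotation.
        if (VocabReverse.TryGetValue(id, out string? value))
        {
            return value; // non-null here: the compiler narrows it on the true branch
        }

        return string.Empty;
    }
}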

src/Microsoft.ML.Tokenizers/Model/BpeTrainer.cs

Lines changed: 9 additions & 4 deletions
@@ -83,7 +83,12 @@ public BpeTrainer(
            MinFrequency = minFrequency;
            VocabSize = vocabSize;
            Progress = progress;
-            SpecialTokens = new List<AddedToken>(specialTokens);
+
+            if (specialTokens is not null)
+            {
+                SpecialTokens = new List<AddedToken>(specialTokens);
+            }
+
            LimitAlphabet = limitAlphabet;
            InitialAlphabet = initialAlphabet;
            ContinuingSubwordPrefix = continuingSubwordPrefix;
@@ -172,7 +177,7 @@ private void ComputeAlphabet(Dictionary<string, int> wc, Dictionary<string, int>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        internal string CharToString(char c)
        {
-            if (_charToString.TryGetValue(c, out string v))
+            if (_charToString.TryGetValue(c, out string? v))
            {
                return v;
            }
@@ -259,7 +264,7 @@ internal string CharToString(char c)
            // Then update counts
            int count = counts[i];

-            if (!whereToUpdate.TryGetValue(curPair, out HashSet<int> h))
+            if (!whereToUpdate.TryGetValue(curPair, out HashSet<int>? h))
            {
                h = new HashSet<int>();
                whereToUpdate[curPair] = h;
@@ -398,7 +403,7 @@ internal string CharToString(char c)

            if (change > 0)
            {
-                if (!whereToUpdate.TryGetValue(p, out HashSet<int> h))
+                if (!whereToUpdate.TryGetValue(p, out HashSet<int>? h))
                {
                    h = new();
                    whereToUpdate[p] = h;

src/Microsoft.ML.Tokenizers/Model/Cache.cs

Lines changed: 4 additions & 4 deletions
@@ -9,7 +9,7 @@

 namespace Microsoft.ML.Tokenizers
 {
-    internal sealed class Cache<TKey, TValue>
+    internal sealed class Cache<TKey, TValue> where TKey : notnull
    {
        internal Cache() : this(Bpe.DefaultCacheCapacity) { }

@@ -39,13 +39,13 @@ internal void Clear()

        internal List<TValue> GetValues(IEnumerable<TKey> keys)
        {
-            List<TValue>? values = new();
+            List<TValue> values = new();
            _cacheLock.EnterReadLock();
            try
            {
                foreach (TKey key in keys)
                {
-                    if (Map.TryGetValue(key, out TValue value))
+                    if (Map.TryGetValue(key, out TValue? value))
                    {
                        values.Add(value);
                    }
@@ -61,7 +61,7 @@ internal List<TValue> GetValues(IEnumerable<TKey> keys)
            _cacheLock.EnterReadLock();
            try
            {
-                if (Map.TryGetValue(key, out TValue value))
+                if (Map.TryGetValue(key, out TValue? value))
                {
                    return value;
                }
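Note on this change: the new where TKey : notnull constraint matches the constraint that Dictionary<TKey, TValue> itself carries, so a dictionary keyed by the cache's TKey no longer triggers warning CS8714 under nullable reference types. A standalone sketch of the pattern; the _map field shown here is assumed for illustration, not the repository's actual declaration:

#nullable enable
using System.Collections.Generic;

// Without "where TKey : notnull", the Dictionary field below warns (CS8714)
// because Dictionary<TKey, TValue> requires a non-nullable key type.
internal sealed class MiniCache<TKey, TValue> where TKey : notnull
{
    private readonly Dictionary<TKey, TValue> _map = new();

    internal void Set(TKey key, TValue value) => _map[key] = value;

    internal TValue? GetOrDefault(TKey key)
        => _map.TryGetValue(key, out TValue? value) ? value : default;
}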

src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs

Lines changed: 6 additions & 2 deletions
@@ -429,7 +429,7 @@ private Dictionary<string, int> GetVocabulary(Stream vocabularyStream)
                using StreamReader reader = new StreamReader(mergeStream);
                while (reader.Peek() >= 0)
                {
-                    splitContents.Add(reader.ReadLine());
+                    splitContents.Add(reader.ReadLine()!);
                }
            }
            catch (Exception e)
@@ -761,7 +761,11 @@ public void AddFromStream(Stream stream)

            while (reader.Peek() >= 0)
            {
-                string line = reader.ReadLine();
+                string? line = reader.ReadLine();
+                if (line is null)
+                {
+                    continue;
+                }

                var splitLine = line.Trim().Split(' ');
                if (splitLine.Length != 2)
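Note on this change: StreamReader.ReadLine() returns string?, and the two call sites resolve the resulting nullable warning differently. The first uses the null-forgiving operator (!) because Peek() >= 0 guarantees another line; the second adds an explicit null check and skips. A standalone illustration of the two styles, not the repository's code:

#nullable enable
using System.Collections.Generic;
using System.IO;

internal static class ReadLineNullHandling
{
    // Style 1: assert non-null with '!'; Peek() >= 0 implies a line is available.
    internal static List<string> ReadAllAsserted(StreamReader reader)
    {
        var lines = new List<string>();
        while (reader.Peek() >= 0)
        {
            lines.Add(reader.ReadLine()!);
        }
        return lines;
    }

    // Style 2: check explicitly and skip the (theoretically unreachable) null case.
    internal static List<string> ReadAllChecked(StreamReader reader)
    {
        var lines = new List<string>();
        while (reader.Peek() >= 0)
        {
            string? line = reader.ReadLine();
            if (line is null)
            {
                continue;
            }
            lines.Add(line);
        }
        return lines;
    }
}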

src/Microsoft.ML.Tokenizers/Model/Model.cs

Lines changed: 38 additions & 0 deletions
@@ -20,13 +20,51 @@ public abstract class Model
        /// <returns>The list of tokens generated from the sequence tokenization.</returns>
        public abstract IReadOnlyList<Token> Tokenize(string sequence);

+        /// <summary>
+        /// Tokenize a split sequence string to a list of tokens.
+        /// </summary>
+        /// <param name="sequence">The text to tokenize.</param>
+        /// <param name="isSpecialToken">Indicate if the token is a special token.</param>
+        /// <returns>The list of tokens generated from the sequence tokenization.</returns>
+        public virtual IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialToken) => Tokenize(sequence);
+
+        /// <summary>
+        /// Tokenize a split sequence string to a list of Ids and add them to the accumulatedIds list.
+        /// </summary>
+        /// <param name="sequence">The sequence to split.</param>
+        /// <param name="isSpecialToken">Indicate if the token is a special token.</param>
+        /// <param name="accumulatedIds">The list of accumulated tokenized Ids.</param>
+        /// <returns>True if the operation succeeded, false otherwise.</returns>
+        public virtual bool TokenizeToIds(string sequence, bool isSpecialToken, IList<int> accumulatedIds)
+        {
+            if (accumulatedIds is null)
+            {
+                throw new ArgumentNullException(nameof(accumulatedIds));
+            }
+
+            var tokens = Tokenize(sequence);
+            foreach (var token in tokens)
+            {
+                accumulatedIds.Add(token.Id);
+            }
+            return true;
+        }
+
        /// <summary>
        /// Map the token to tokenized Id.
        /// </summary>
        /// <param name="token">The token to map to the Id.</param>
        /// <returns>The mapped Id of the token.</returns>
        public abstract int? TokenToId(string token);

+        /// <summary>
+        /// Map the token to tokenized id with the option to skip the special tokens.
+        /// </summary>
+        /// <param name="token">The token to map to Id</param>
+        /// <param name="skipSpecialTokens">Indicate if want to skip the special tokens during the encoding.</param>
+        /// <returns>The mapped Id of the token.</returns>
+        public virtual int? TokenToId(string token, bool skipSpecialTokens) => TokenToId(token);
+
        /// <summary>
        /// Map the tokenized Id to the token.
        /// </summary>
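Note on this change: the new virtual members let a Model implementation provide special-token-aware overloads and a TokenizeToIds fast path that appends ids straight into a caller-supplied list instead of materializing Token objects, while the base implementations simply delegate to the existing abstract methods. Below is a hedged sketch of that pattern using stand-in types, since a real subclass would also have to implement the rest of Model's abstract surface, which is not shown in this hunk:

using System;
using System.Collections.Generic;

// Stand-ins for illustration only; not the library's Token/Model types.
public sealed record ToyToken(int Id, string Value);

public abstract class ToyModel
{
    public abstract IReadOnlyList<ToyToken> Tokenize(string sequence);

    // Mirrors the new virtual: by default, tokenize and copy the ids over.
    public virtual bool TokenizeToIds(string sequence, bool isSpecialToken, IList<int> accumulatedIds)
    {
        if (accumulatedIds is null)
        {
            throw new ArgumentNullException(nameof(accumulatedIds));
        }

        foreach (ToyToken token in Tokenize(sequence))
        {
            accumulatedIds.Add(token.Id);
        }
        return true;
    }
}

public sealed class WhitespaceToyModel : ToyModel
{
    public override IReadOnlyList<ToyToken> Tokenize(string sequence)
    {
        var tokens = new List<ToyToken>();
        foreach (string piece in sequence.Split(' ', StringSplitOptions.RemoveEmptyEntries))
        {
            tokens.Add(new ToyToken(piece.Length, piece)); // toy id: the piece length
        }
        return tokens;
    }

    // Fast path enabled by the new virtual: append ids without allocating tokens.
    public override bool TokenizeToIds(string sequence, bool isSpecialToken, IList<int> accumulatedIds)
    {
        foreach (string piece in sequence.Split(' ', StringSplitOptions.RemoveEmptyEntries))
        {
            accumulatedIds.Add(piece.Length);
        }
        return true;
    }
}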
