Skip to content

Commit 3282f44

Browse files
authored
Tweak Tiktoken's BytePairEncode for improved perf (#7017)
- Stackalloc the indices/ranks when feasible - Use a span to eliminate bounds checks and allow for directly updating ranks
1 parent eb66d73 commit 3282f44

File tree

1 file changed

+37
-17
lines changed

1 file changed

+37
-17
lines changed

src/Microsoft.ML.Tokenizers/Utils/BytePairEncoder.cs

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Buffers;
67
using System.Collections.Generic;
78

89
namespace Microsoft.ML.Tokenizers
@@ -19,63 +20,82 @@ public static int[] BytePairEncode(ReadOnlyMemory<byte> mergingBytes, Dictionary
1920
return [ranks[mergingBytes]];
2021
}
2122

22-
var byteIndicesAndRanks = new List<(int Index, int Rank)>();
23-
for (int i = 0; i < mergingBytes.Length + 1; i++)
23+
(int Index, int Rank)[]? arrayPoolArray = null;
24+
int requiredLength = mergingBytes.Length + 1;
25+
Span<(int Index, int Rank)> byteIndicesAndRanks = requiredLength <= 64 ?
26+
stackalloc (int, int)[64] :
27+
(arrayPoolArray = ArrayPool<(int, int)>.Shared.Rent(requiredLength));
28+
byteIndicesAndRanks = byteIndicesAndRanks.Slice(0, requiredLength);
29+
30+
for (int i = 0; i < byteIndicesAndRanks.Length; i++)
2431
{
25-
byteIndicesAndRanks.Add((i, int.MaxValue));
32+
byteIndicesAndRanks[i] = (i, int.MaxValue);
2633
}
27-
int GetRank(int startIndex, int skip = 0)
34+
35+
int GetRank(Span<(int Index, int Rank)> byteIndicesAndRanks, int startIndex, int skip = 0)
2836
{
29-
if (startIndex + skip + 2 < byteIndicesAndRanks.Count)
37+
if (startIndex + skip + 2 < byteIndicesAndRanks.Length)
3038
{
3139
var slice = mergingBytes.SliceStartEnd(byteIndicesAndRanks[startIndex].Index, byteIndicesAndRanks[startIndex + skip + 2].Index);
3240
if (ranks.TryGetValue(slice, out var rank))
3341
{
3442
return rank;
3543
}
3644
}
45+
3746
return int.MaxValue;
3847
}
39-
for (int i = 0; i < byteIndicesAndRanks.Count - 2; i++)
48+
49+
for (int i = 0; i < byteIndicesAndRanks.Length - 2; i++)
4050
{
41-
var rank = GetRank(i);
51+
int rank = GetRank(byteIndicesAndRanks, i);
4252
if (rank != int.MaxValue)
4353
{
44-
byteIndicesAndRanks[i] = (byteIndicesAndRanks[i].Index, rank);
54+
byteIndicesAndRanks[i].Rank = rank;
4555
}
4656
}
47-
while (byteIndicesAndRanks.Count > 1)
57+
58+
while (byteIndicesAndRanks.Length > 1)
4859
{
4960
var minRank = (Index: 0, Rank: int.MaxValue);
50-
for (int i = 0; i < byteIndicesAndRanks.Count - 1; i++)
61+
for (int i = 0; i < byteIndicesAndRanks.Length - 1; i++)
5162
{
5263
if (byteIndicesAndRanks[i].Rank < minRank.Rank)
5364
{
5465
minRank = (i, byteIndicesAndRanks[i].Rank);
5566
}
5667
}
68+
5769
if (minRank.Rank != int.MaxValue)
5870
{
5971
int j = minRank.Index;
60-
byteIndicesAndRanks[j] = (byteIndicesAndRanks[j].Index, GetRank(j, 1));
72+
byteIndicesAndRanks[j].Rank = GetRank(byteIndicesAndRanks, j, 1);
6173
if (j > 0)
6274
{
63-
byteIndicesAndRanks[j - 1] = (byteIndicesAndRanks[j - 1].Index, GetRank(j - 1, 1));
75+
byteIndicesAndRanks[j - 1].Rank = GetRank(byteIndicesAndRanks, j - 1, 1);
6476
}
65-
byteIndicesAndRanks.RemoveAt(j + 1);
77+
78+
byteIndicesAndRanks.Slice(j + 2).CopyTo(byteIndicesAndRanks.Slice(j + 1));
79+
byteIndicesAndRanks = byteIndicesAndRanks.Slice(0, byteIndicesAndRanks.Length - 1);
6680
}
6781
else
6882
{
6983
break;
7084
}
7185
}
7286

73-
var outList = new int[byteIndicesAndRanks.Count - 1];
74-
for (int i = 0; i < byteIndicesAndRanks.Count - 1; i++)
87+
var result = new int[byteIndicesAndRanks.Length - 1];
88+
for (int i = 0; i < result.Length; i++)
7589
{
76-
outList[i] = ranks[mergingBytes.SliceStartEnd(byteIndicesAndRanks[i].Index, byteIndicesAndRanks[i + 1].Index)];
90+
result[i] = ranks[mergingBytes.SliceStartEnd(byteIndicesAndRanks[i].Index, byteIndicesAndRanks[i + 1].Index)];
7791
}
78-
return outList;
92+
93+
if (arrayPoolArray is not null)
94+
{
95+
ArrayPool<(int, int)>.Shared.Return(arrayPoolArray);
96+
}
97+
98+
return result;
7999
}
80100

81101
private static ReadOnlyMemory<byte> SliceStartEnd(this ReadOnlyMemory<byte> memory, int start, int end) => memory.Slice(start, end - start);

0 commit comments

Comments
 (0)