3
3
// See the LICENSE file in the project root for more information.
4
4
5
5
using System ;
6
+ using System . Buffers ;
6
7
using System . Collections . Generic ;
7
8
8
9
namespace Microsoft . ML . Tokenizers
@@ -19,63 +20,82 @@ public static int[] BytePairEncode(ReadOnlyMemory<byte> mergingBytes, Dictionary
19
20
return [ ranks [ mergingBytes ] ] ;
20
21
}
21
22
22
- var byteIndicesAndRanks = new List < ( int Index , int Rank ) > ( ) ;
23
- for ( int i = 0 ; i < mergingBytes . Length + 1 ; i ++ )
23
+ ( int Index , int Rank ) [ ] ? arrayPoolArray = null ;
24
+ int requiredLength = mergingBytes . Length + 1 ;
25
+ Span < ( int Index , int Rank ) > byteIndicesAndRanks = requiredLength <= 64 ?
26
+ stackalloc ( int , int ) [ 64 ] :
27
+ ( arrayPoolArray = ArrayPool < ( int , int ) > . Shared . Rent ( requiredLength ) ) ;
28
+ byteIndicesAndRanks = byteIndicesAndRanks . Slice ( 0 , requiredLength ) ;
29
+
30
+ for ( int i = 0 ; i < byteIndicesAndRanks . Length ; i ++ )
24
31
{
25
- byteIndicesAndRanks . Add ( ( i , int . MaxValue ) ) ;
32
+ byteIndicesAndRanks [ i ] = ( i , int . MaxValue ) ;
26
33
}
27
- int GetRank ( int startIndex , int skip = 0 )
34
+
35
+ int GetRank ( Span < ( int Index , int Rank ) > byteIndicesAndRanks , int startIndex , int skip = 0 )
28
36
{
29
- if ( startIndex + skip + 2 < byteIndicesAndRanks . Count )
37
+ if ( startIndex + skip + 2 < byteIndicesAndRanks . Length )
30
38
{
31
39
var slice = mergingBytes . SliceStartEnd ( byteIndicesAndRanks [ startIndex ] . Index , byteIndicesAndRanks [ startIndex + skip + 2 ] . Index ) ;
32
40
if ( ranks . TryGetValue ( slice , out var rank ) )
33
41
{
34
42
return rank ;
35
43
}
36
44
}
45
+
37
46
return int . MaxValue ;
38
47
}
39
- for ( int i = 0 ; i < byteIndicesAndRanks . Count - 2 ; i ++ )
48
+
49
+ for ( int i = 0 ; i < byteIndicesAndRanks . Length - 2 ; i ++ )
40
50
{
41
- var rank = GetRank ( i ) ;
51
+ int rank = GetRank ( byteIndicesAndRanks , i ) ;
42
52
if ( rank != int . MaxValue )
43
53
{
44
- byteIndicesAndRanks [ i ] = ( byteIndicesAndRanks [ i ] . Index , rank ) ;
54
+ byteIndicesAndRanks [ i ] . Rank = rank ;
45
55
}
46
56
}
47
- while ( byteIndicesAndRanks . Count > 1 )
57
+
58
+ while ( byteIndicesAndRanks . Length > 1 )
48
59
{
49
60
var minRank = ( Index : 0 , Rank : int . MaxValue ) ;
50
- for ( int i = 0 ; i < byteIndicesAndRanks . Count - 1 ; i ++ )
61
+ for ( int i = 0 ; i < byteIndicesAndRanks . Length - 1 ; i ++ )
51
62
{
52
63
if ( byteIndicesAndRanks [ i ] . Rank < minRank . Rank )
53
64
{
54
65
minRank = ( i , byteIndicesAndRanks [ i ] . Rank ) ;
55
66
}
56
67
}
68
+
57
69
if ( minRank . Rank != int . MaxValue )
58
70
{
59
71
int j = minRank . Index ;
60
- byteIndicesAndRanks [ j ] = ( byteIndicesAndRanks [ j ] . Index , GetRank ( j , 1 ) ) ;
72
+ byteIndicesAndRanks [ j ] . Rank = GetRank ( byteIndicesAndRanks , j , 1 ) ;
61
73
if ( j > 0 )
62
74
{
63
- byteIndicesAndRanks [ j - 1 ] = ( byteIndicesAndRanks [ j - 1 ] . Index , GetRank ( j - 1 , 1 ) ) ;
75
+ byteIndicesAndRanks [ j - 1 ] . Rank = GetRank ( byteIndicesAndRanks , j - 1 , 1 ) ;
64
76
}
65
- byteIndicesAndRanks . RemoveAt ( j + 1 ) ;
77
+
78
+ byteIndicesAndRanks . Slice ( j + 2 ) . CopyTo ( byteIndicesAndRanks . Slice ( j + 1 ) ) ;
79
+ byteIndicesAndRanks = byteIndicesAndRanks . Slice ( 0 , byteIndicesAndRanks . Length - 1 ) ;
66
80
}
67
81
else
68
82
{
69
83
break ;
70
84
}
71
85
}
72
86
73
- var outList = new int [ byteIndicesAndRanks . Count - 1 ] ;
74
- for ( int i = 0 ; i < byteIndicesAndRanks . Count - 1 ; i ++ )
87
+ var result = new int [ byteIndicesAndRanks . Length - 1 ] ;
88
+ for ( int i = 0 ; i < result . Length ; i ++ )
75
89
{
76
- outList [ i ] = ranks [ mergingBytes . SliceStartEnd ( byteIndicesAndRanks [ i ] . Index , byteIndicesAndRanks [ i + 1 ] . Index ) ] ;
90
+ result [ i ] = ranks [ mergingBytes . SliceStartEnd ( byteIndicesAndRanks [ i ] . Index , byteIndicesAndRanks [ i + 1 ] . Index ) ] ;
77
91
}
78
- return outList ;
92
+
93
+ if ( arrayPoolArray is not null )
94
+ {
95
+ ArrayPool < ( int , int ) > . Shared . Return ( arrayPoolArray ) ;
96
+ }
97
+
98
+ return result ;
79
99
}
80
100
81
101
/// <summary>
/// Slices <paramref name="memory"/> using a start/end index pair instead of a start/length pair.
/// </summary>
/// <param name="memory">The buffer to take the slice from.</param>
/// <param name="start">Index of the first element included in the slice.</param>
/// <param name="end">Index one past the last element included in the slice.</param>
/// <returns>The region of <paramref name="memory"/> covering <c>end - start</c> elements beginning at <paramref name="start"/>.</returns>
private static ReadOnlyMemory<byte> SliceStartEnd(this ReadOnlyMemory<byte> memory, int start, int end)
{
    // Slice takes (start, length), so convert the exclusive end index into a length.
    int length = end - start;
    return memory.Slice(start, length);
}
0 commit comments