classdef GPT2Tokenizer < handle
    % GPT2Tokenizer   Object for encoding text so it can be fed to GPT-2
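    %
    % A hypothetical usage sketch (the model name 'gpt2-355M' and the
    % 'models' directory below are illustrative assumptions, not values
    % fixed by this class):
    %
    %   tokenizer = GPT2Tokenizer('gpt2-355M', 'models');
    %   numericTokens = tokenizer.encode('Hello world!');
    %   text = tokenizer.decode(numericTokens);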

    properties (SetAccess = private)
        % Encoding   String array of vocabulary entries; the numeric
        % token for a subword is its index into this array
        Encoding

        % BPERanks   Two-column string array of byte pairs, in merge-rank
        % order
        BPERanks

        % Cache   Map from previously encoded tokens to their BPE
        % encodings
        Cache = containers.Map()
    end

    properties (Constant)
        % TokenizationExpression   Regular expression used for tokenization
        %
        % This is the regular expression used for the first stage of
        % tokenization. It was hard-coded by the creators of GPT-2. It
        % appears to apply a tokenization rule that can be summarised as
        % follows:
        %
        % A token is one of the following things:
        %
        %   - An exact string match for 's, 't, 're, 've, 'm, 'll, or 'd.
        %     This means common contractions in words like don't and
        %     you'll will get split into their own tokens.
        %   - Zero or one spaces followed by one or more Unicode letters.
        %   - Zero or one spaces followed by one or more Unicode numbers.
        %   - Zero or one spaces followed by one or more things that are
        %     not whitespace, a Unicode letter or a Unicode number.
        %   - One or more whitespace characters not followed by a
        %     non-whitespace character. This is tricky to understand, but
        %     it basically means that a word preceded by several spaces,
        %     like '   Hello', will get split into '  ' and ' Hello'.
        %   - One or more whitespace characters.
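        %
        % For example, under these rules the string "You'll pay 5
        % dollars!" splits into the tokens:
        %
        %   "You"   "'ll"   " pay"   " 5"   " dollars"   "!"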
        %
        % Note that we have had to modify the original expression, which
        % is shown below (as a MATLAB char literal, so '' stands for a
        % single quote):
        %
        %   '''s|''t|''re|''ve|''m|''ll|''d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+'
        %
        % MATLAB's regexp function does not support the \p{...} character
        % classes, so we have replaced them with constructs that have
        % equivalent functionality.
        TokenizationExpression = '''s|''t|''re|''ve|''m|''ll|''d| ?((?![\d_])\w)+| ?\d+| ?(_|[^\s\w])+|\s+(?!\S)|\s+';

        % ByteEncoder   Encodes bytes into a set of 256 Unicode characters
        %
        % The size of the output vocabulary from this encoder determines
        % the size of the embedding needed by the GPT-2 transformer
        % model. The creators of GPT-2 wanted to keep this at around
        % 50,000. However, they also wanted to be able to encode any
        % Unicode string, and Unicode has potentially hundreds of
        % thousands of characters. To keep the overall vocabulary small,
        % we therefore go through an extra encoding stage:
        %
        %   - The raw Unicode string (which can contain any Unicode
        %     character) is converted into UTF-8 bytes. Note that UTF-8
        %     is a variable-length encoding scheme, so each character can
        %     get mapped to between 1 and 4 bytes.
        %   - These individual bytes are then mapped to a restricted
        %     vocabulary of 256 Unicode characters. ByteEncoder defines
        %     this mapping.
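        %
        % For example, the space byte (32) is not one of the printable
        % characters kept by this mapping, so it is represented by the
        % stand-in character char(256+32), which renders as 'Ġ'. This is
        % why GPT-2 vocabulary entries that start a new word begin with
        % 'Ġ'.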
        ByteEncoder = iBytesToUnicode()
    end

    methods
        function this = GPT2Tokenizer(modelName, modelsDirectory)
            % Read in the vocabulary. The UTF-8 part is really important
            % to make this work on Windows.
            fid = fopen([modelsDirectory filesep() modelName filesep() 'vocab.bpe'], 'r', 'n', 'UTF-8');
            bpeData = textscan(fid, '%s', 'Delimiter', '\n');
            fclose(fid);

            bpeData = bpeData{1};   % textscan always wraps its output in a cell
            bpeData(1) = [];        % Delete the first line we read in (it's a comment)

            % Split the BPE data into two columns.
            this.BPERanks = split(string(bpeData));

            % Read in the encoding data. The UTF-8 part is really
            % important to make this work on Windows.
            fid = fopen([modelsDirectory filesep() modelName filesep() 'encoder.txt'], 'r', 'n', 'UTF-8');
            encoderData = textscan(fid, '%s', 'Delimiter', '\n');
            fclose(fid);

            encoderData = encoderData{1};

            % Set the encoding
            this.Encoding = string(encoderData);
        end

        function numericTokens = encode(this, text)

            % Note that this function returns tokens with indices that
            % begin at 1. The Python implementation indexes from 0.

            % Step 1: Apply the regular expression to split the text into
            % words. See the comment for 'TokenizationExpression' for
            % more detail on what is going on here.
            [inputTokens, ~] = regexp( ...
                text, ...
                this.TokenizationExpression, ...
                'match', 'split');

            % Step 2: The incoming text is Unicode, and Unicode has a
            % huge set of characters. We do not want our BPE algorithm to
            % deal with a huge set of characters, because that would
            % inflate the BPE vocabulary, so we need to reduce the set of
            % characters. We do this by converting the Unicode text to
            % the UTF-8 encoding, and then replacing each UTF-8 byte with
            % a character from a set of 256 Unicode characters. This
            % means that our original string, which could have contained
            % any Unicode character, will now contain only characters
            % from that set of 256.
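            % For example, the token ' Hello' becomes the UTF-8 bytes
            % [32 72 101 108 108 111], which are then mapped to the six
            % characters 'ĠHello'.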
            encodedTokens = cellfun(@(x)unicode2native(x, 'UTF-8'), ...
                inputTokens, 'UniformOutput', false);
            encodedTokens = cellfun(@(x)this.ByteEncoder(x + 1), ...
                encodedTokens, 'UniformOutput', false);

            % Step 3: Do the BPE encoding on a per-word basis. Words are
            % either left as they are, or, for rare words, split into
            % word fragments.
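            % For example, a common word will typically map to a single
            % vocabulary entry, while an unusual string will be broken
            % into several known subword fragments.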
            bpeTokens = cellfun(@(x)this.bpe(x), encodedTokens, 'UniformOutput', false);

            % Step 4: Look up each word or word fragment and replace it
            % with a number.
            numericTokens = [];
            for i = 1:numel(bpeTokens)
                bpeTokensSplit = split(bpeTokens{i});
                for j = 1:numel(bpeTokensSplit)
                    numericTokens = [numericTokens find(this.Encoding == bpeTokensSplit(j))]; %#ok<AGROW>
                end
            end
        end

        function text = decode(this, numericTokens)

            % Note that this function expects tokens that begin at 1!
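            % In other words, for any string str, decode(encode(str))
            % should reproduce str exactly.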

            % Step 1: Turn tokens into text
            text = join(this.Encoding(numericTokens), '');

            % Step 2: Replace characters with byte values
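            % Comparing against the transposed ByteEncoder produces a
            % 256-by-N logical matrix with exactly one true entry per
            % column; max then returns the row index of that entry for
            % each column, which is the byte value plus one.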
            [~, text] = max(char(text) == this.ByteEncoder');
            text = text - 1;

            % Step 3: Decode byte values as UTF-8
            text = native2unicode(text, 'UTF-8');
        end
    end

    methods (Access = private)
        function word = bpe(this, token)
            if this.Cache.isKey(token)
                word = this.Cache(token);
            elseif isempty(token)
                word = token;
            else
                wordFragments = string(num2cell(token));
                pairs = iGetPairs(wordFragments);
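                % An illustrative sketch (these merge ranks are invented,
                % not actual GPT-2 ranks): for the token 'abcd',
                % wordFragments starts as ["a" "b" "c" "d"] and pairs as
                % {["a" "b"]; ["b" "c"]; ["c" "d"]}. If ["b" "c"] is the
                % pair with the lowest row index in BPERanks, the loop
                % below merges it first, and the next iteration works on
                % ["a" "bc" "d"].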

                while true
                    matches = [];
                    for i = 1:numel(pairs)
                        match = find(sum(pairs{i} == this.BPERanks, 2) == 2);
                        matches = [matches match]; %#ok<AGROW>
                    end
                    minIndex = min(matches);
                    if isempty(minIndex)
                        break;
                    end
                    bigram = this.BPERanks(minIndex, :);

                    first = bigram(1);
                    second = bigram(2);
                    newWordFragments = [];
                    i = 1;
                    while i < length(wordFragments) + 1
                        j = find( ...
                            wordFragments == first & ...
                            [zeros(1, (i-1)) ones(1, length(wordFragments)-i+1)]);
                        if isempty(j)
                            newWordFragments = [newWordFragments wordFragments(i:end)]; %#ok<AGROW>
                            break
                        else
                            newWordFragments = [newWordFragments wordFragments(i:(j(1)-1))]; %#ok<AGROW>
                            i = j(1);
                        end

                        if wordFragments(i) == first && ...
                                i < length(wordFragments) && ...
                                wordFragments(i+1) == second
                            newWordFragments = [newWordFragments first + second]; %#ok<AGROW>
                            i = i + 2;
                        else
                            newWordFragments = [newWordFragments wordFragments(i)]; %#ok<AGROW>
                            i = i + 1;
                        end
                    end

                    % We have a new word because we have merged some of
                    % the word fragments. If there is only one element in
                    % 'wordFragments', we have merged all of the
                    % fragments and can stop, so we break. Otherwise we
                    % generate pairs again and start the process again.
                    wordFragments = newWordFragments;
                    if numel(wordFragments) == 1
                        break;
                    else
                        pairs = iGetPairs(wordFragments);
                    end
                end

                word = join(wordFragments, ' ');
                this.Cache(token) = word;
            end
        end
    end
end

function cs = iBytesToUnicode()
    % Note that the third character here is not the letter i! It is the
    % extended Unicode character corresponding to the number 161.
    % cs = ['!':'~' '¡':'¬' '®':'ÿ'];
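    % After the loop and sort below, cs is a 1-by-256 char array in which
    % cs(b+1) is the stand-in character for byte value b.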
    cs = char([33:126 161:172 174:255]);
    bs = double(cs);
    n = 0;
    for b = 0:255
        if ~any(b == bs)
            bs = [bs b];        %#ok<AGROW>
            cs = [cs 256 + n];  %#ok<AGROW>
            n = n + 1;
        end
    end
    [~, sortedIndices] = sort(bs);
    cs = cs(sortedIndices);
end

function pairs = iGetPairs(wordFragments)
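    % iGetPairs   Return every pair of adjacent word fragments
    %
    % For example, iGetPairs(["a" "b" "c"]) returns the cell array
    % {["a" "b"]; ["b" "c"]}.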
    numLetters = length(wordFragments);
    pairIndices = [1:(numLetters - 1); 2:numLetters]';
    pairIndices = mat2cell(pairIndices, ones(numLetters - 1, 1), 2);
    pairs = cellfun(@(x)wordFragments(x), pairIndices, ...
        'UniformOutput', false);
    pairs = cellfun(@(x)[string(x(1)) string(x(2))], pairs, ...
        'UniformOutput', false);
end