
Commit e4e7b8b (0 parents): First commit

20 files changed: +1068 -0 lines

+gpt2/+layer/block.m

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
function [X, present] = block(X, past, weights, hyperParameters)
% block   Transformer block for GPT-2
%
%   [X, present] = block(X, past, weights, hyperParameters) computes a
%   GPT-2 style transformer block on the input X as described in [1] (see
%   Section 2.3). One difference between this style of transformer block
%   and others is that layer normalization is applied at the beginning of
%   each sub-block.
%
%   Inputs:
%       X               - A (numFeatures*numHeads)-by-numInputSubwords
%                         input array.
%       past            - A numFeatures-by-numPastSubwords-by-numHeads-by-2
%                         array. This contains the 'keys' and 'values' for
%                         past subwords. These are needed to predict future
%                         outputs in an autoregressive manner. 'keys' are
%                         stored in past(:,:,:,1) and 'values' are stored
%                         in past(:,:,:,2).
%       weights         - The weights for the transformer block stored in a
%                         struct. In this block we have:
%                           - ln_1_g_0: Weight vector for the first layer
%                             normalization.
%                           - ln_1_b_0: Bias vector for the first layer
%                             normalization.
%                           - ln_2_g_0: Weight vector for the second layer
%                             normalization.
%                           - ln_2_b_0: Bias vector for the second layer
%                             normalization.
%                         In the attention sub-block:
%                           - attn_c_attn_w_0: A weight matrix for the
%                             first fully connected layer.
%                           - attn_c_attn_b_0: A bias vector for the first
%                             fully connected layer.
%                           - attn_c_proj_w_0: A weight matrix for the
%                             final fully connected layer.
%                           - attn_c_proj_b_0: A bias vector for the final
%                             fully connected layer.
%                         In the multi-layer perceptron block:
%                           - mlp_c_fc_w_0: A weight matrix for the first
%                             fully connected layer.
%                           - mlp_c_fc_b_0: A bias vector for the first
%                             fully connected layer.
%                           - mlp_c_proj_w_0: A weight matrix for the
%                             second fully connected layer.
%                           - mlp_c_proj_b_0: A bias vector for the second
%                             fully connected layer.
%       hyperParameters - A struct of hyper-parameters. The hyper-parameter
%                         used by this block is the number of attention
%                         heads.
%
%   Outputs:
%       X               - A (numFeatures*numHeads)-by-numInputSubwords
%                         output array.
%       present         - A numFeatures-by-numAllSubwords-by-numHeads-by-2
%                         array. This contains the 'keys' and 'values' that
%                         are created from the inputs. These need to be
%                         passed back in as the 'past' input if we want to
%                         predict future outputs in an autoregressive
%                         manner. 'keys' are stored in present(:,:,:,1) and
%                         'values' are stored in present(:,:,:,2).
%
%   References:
%
%   [1] Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei,
%       Ilya Sutskever, "Language Models are Unsupervised Multitask
%       Learners",
%       https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf

% First sub-block: layer normalization, self-attention, residual connection.
XNorm1 = transformer.layer.normalization(X, ...
    weights.ln_1_g_0, weights.ln_1_b_0);

[A, present] = transformer.layer.attention(XNorm1, past, weights, hyperParameters);

X = X + A;

% Second sub-block: layer normalization, multi-layer perceptron, residual
% connection.
XNorm2 = transformer.layer.normalization(X, ...
    weights.ln_2_g_0, weights.ln_2_b_0);

M = transformer.layer.multiLayerPerceptron(XNorm2, weights);

X = X + M;

end
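
A minimal calling sketch (not part of this commit) may make the interface concrete. The sizes below are invented for illustration, the weights struct would come from the parameters.mat file fetched by download.m, the package-qualified name follows the +gpt2/+layer folder above, and the exact contents of hyperParameters are an assumption here.

% Hypothetical sizes; real values depend on the downloaded model.
numFeatures = 64; numHeads = 12; numSubwords = 8;
X = rand(numFeatures*numHeads, numSubwords, 'single');
past = [];   % assume no cached keys/values on the first call

% 'weights' holds the ln_1_*, attn_* and mlp_* fields listed above, and
% 'hyperParameters' is assumed to carry the number of attention heads.
% [X, present] = gpt2.layer.block(X, past, weights, hyperParameters);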

+gpt2/+tokenizer/GPT2Tokenizer.m

Lines changed: 249 additions & 0 deletions
@@ -0,0 +1,249 @@
classdef GPT2Tokenizer < handle
    % GPT2Tokenizer   Object for encoding text so it can be fed to GPT2

    properties(SetAccess = private)
        % Encoding
        Encoding

        % BPERanks
        BPERanks

        % Cache
        Cache = containers.Map()
    end

    properties(Constant)
        % TokenizationExpression   Regular expression used for tokenization
        %
        % This is the regular expression used for the first stage of
        % tokenization. It was hard-coded by the creators of GPT-2. It
        % appears to apply a tokenization rule that can be summarised as
        % follows:
        %
        % A token is one of the following things:
        %
        % - An exact string match for 's, 't, 're, 've, 'm, 'll, or 'd.
        %   This means common contractions in words like don't and you'll
        %   will get split into their own tokens.
        % - Zero or one spaces followed by one or more Unicode letters.
        % - Zero or one spaces followed by one or more Unicode numbers.
        % - Zero or one spaces followed by one or more things that are
        %   not whitespace, a Unicode letter or a Unicode number.
        % - One or more whitespace characters not followed by a
        %   non-whitespace character. This is tricky to understand, but
        %   basically it means that a word preceded by several spaces, for
        %   example '   Hello', will get split into the leading run of
        %   spaces and ' Hello' (a single space stays attached to the word).
        % - One or more whitespace characters.
        %
        % Note that we have had to modify the original expression, which
        % is shown below:
        %
        %   '''s|''t|''re|''ve|''m|''ll|''d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+'
        %
        % MATLAB's regexp function does not support the \p flag, so we
        % have replaced it with something with equivalent functionality.
        TokenizationExpression = '''s|''t|''re|''ve|''m|''ll|''d| ?((?![\d_])\w)+| ?\d+| ?(_|[^\s\w])+|\s+(?!\S)|\s+';

        % ByteEncoder   Encodes bytes into a set of 256 Unicode characters
        %
        % The size of the output vocabulary from this encoder determines
        % the size of the embedding needed by the GPT-2 transformer
        % model. The creators of GPT-2 wanted to keep this at around
        % 50,000. However, they wanted to be able to encode any Unicode
        % string. Unicode has potentially hundreds of thousands of
        % characters. So to keep the overall vocabulary low, we go
        % through an extra encoding stage:
        %
        % - The raw Unicode string (which can contain any Unicode
        %   character) is converted into UTF-8 bytes. Note that UTF-8 is
        %   a variable-length encoding scheme, so each character can get
        %   mapped to between 1 and 4 bytes.
        % - These individual bytes are then mapped to a restricted
        %   vocabulary of 256 Unicode characters. ByteEncoder defines
        %   this mapping.
        ByteEncoder = iBytesToUnicode()
    end

    methods
        function this = GPT2Tokenizer(modelName, modelsDirectory)
            % Read in the vocabulary. The UTF-8 part is really important to
            % make this work on Windows.
            fid = fopen([modelsDirectory filesep() modelName filesep() 'vocab.bpe'], 'r', 'n', 'UTF-8');
            bpeData = textscan(fid, '%s', 'Delimiter', '\n');
            fclose(fid);

            bpeData = bpeData{1}; % textscan always reads everything in a cell
            bpeData(1) = [];      % Delete the first line we read in (it's a comment)

            % Split the bpe data into two columns.
            this.BPERanks = split(string(bpeData));

            % Read in the encoding data. The UTF-8 part is really important
            % to make this work on Windows.
            fid = fopen([modelsDirectory filesep() modelName filesep() 'encoder.txt'], 'r', 'n', 'UTF-8');
            encoderData = textscan(fid, '%s', 'Delimiter', '\n');
            fclose(fid);

            encoderData = encoderData{1};

            % Set the encoding
            this.Encoding = string(encoderData);
        end

        function numericTokens = encode(this, text)

            % Note that this function returns tokens with indices that
            % begin at 1. The Python implementation indexes from 0.

            % Step 1: Apply regular expression to split text into words.
            % See the comment for 'TokenizationExpression' for more detail
            % on what is going on here.
            [inputTokens, ~] = regexp( ...
                text, ...
                this.TokenizationExpression, ...
                'match', 'split');

            % Step 2: The incoming text is Unicode. Unicode has a huge set
            % of characters. We do not want our BPE algorithm to deal with
            % a huge set of characters, because that will inflate the BPE
            % vocabulary. So we need to reduce the set of characters. We do
            % this by converting the Unicode text to the UTF-8 encoding,
            % and then we replace each UTF-8 byte with another Unicode
            % character, out of a set of 256 Unicode characters. This will
            % mean that our original Unicode string which could have
            % contained any Unicode character will now contain only one of
            % 256 characters.
            encodedTokens = cellfun( @(x)unicode2native(x, 'UTF-8'), ...
                inputTokens, 'UniformOutput', false );
            encodedTokens = cellfun( @(x)this.ByteEncoder(x+1), ...
                encodedTokens, 'UniformOutput', false );

            % Step 3: Do the BPE encoding on a per word basis. Words are
            % either left as they are, or for rare words we split them into
            % word fragments.
            bpeTokens = cellfun(@(x)this.bpe(x), encodedTokens, 'UniformOutput', false);

            % Step 4: Look up each word or word fragment and replace it
            % with a number.
            numericTokens = [];
            for i = 1:numel(bpeTokens)
                bpeTokensSplit = split(bpeTokens{i});
                for j = 1:numel(bpeTokensSplit)
                    numericTokens = [numericTokens find(this.Encoding == bpeTokensSplit(j))]; %#ok<AGROW>
                end
            end
        end

        function text = decode(this, numericTokens)

            % Note that this function expects tokens that begin at 1!

            % Step 1: Turn tokens into text
            text = join(this.Encoding(numericTokens), '');

            % Step 2: Replace characters with byte values
            [~, text] = max( char(text) == this.ByteEncoder' );
            text = text - 1;

            % Step 3: Decode byte values as UTF-8
            text = native2unicode(text, 'UTF-8');
        end
    end

methods(Access = private)
155+
function word = bpe(this, token)
156+
if this.Cache.isKey(token)
157+
word = this.Cache(token);
158+
elseif isempty(token)
159+
word = token;
160+
else
161+
wordFragments = string(num2cell(token));
162+
pairs = iGetPairs(wordFragments);
163+
164+
while true
165+
matches = [];
166+
for i = 1:numel(pairs)
167+
match = find(sum(pairs{i} == this.BPERanks, 2) == 2);
168+
matches = [matches match]; %#ok<AGROW>
169+
end
170+
minIndex = min(matches);
171+
if isempty(minIndex)
172+
break;
173+
end
174+
bigram = this.BPERanks(minIndex,:);
175+
176+
first = bigram(1);
177+
second = bigram(2);
178+
newWordFragments = [];
179+
i = 1;
180+
while i < length(wordFragments)+1
181+
j = find( ...
182+
wordFragments == first & ...
183+
[zeros(1,(i-1)) ones(1,length(wordFragments)-i+1)]);
184+
if isempty(j)
185+
newWordFragments = [newWordFragments wordFragments(i:end)]; %#ok<AGROW>
186+
break
187+
else
188+
newWordFragments = [newWordFragments wordFragments(i:(j(1)-1))]; %#ok<AGROW>
189+
i = j(1);
190+
end
191+
192+
if wordFragments(i) == first && ...
193+
i < length(wordFragments) && ...
194+
wordFragments(i+1) == second
195+
newWordFragments = [newWordFragments first+second]; %#ok<AGROW>
196+
i = i + 2;
197+
else
198+
newWordFragments = [newWordFragments wordFragments(i)]; %#ok<AGROW>
199+
i = i + 1;
200+
end
201+
end
202+
203+
% We have a new word because we have merged some of the
204+
% word fragments. If there is only one element in
205+
% 'wordFragments', we have merges all of the fragments,
206+
% and can stop now, so we break. Otherwise, we generate
207+
% pairs again, and start the process again.
208+
wordFragments = newWordFragments;
209+
if numel(wordFragments) == 1
210+
break;
211+
else
212+
pairs = iGetPairs(wordFragments);
213+
end
214+
end
215+
216+
word = join(wordFragments, ' ');
217+
this.Cache(token) = word;
218+
end
219+
end
220+
end
221+
end
222+
function cs = iBytesToUnicode()
% Note that the third character here is not the letter i! It is the
% extended Unicode character corresponding to the number 161.
%cs = ['!':'~' '¡':'¬' '®':'ÿ'];
cs = char([33:126 161:172 174:255]);
bs = double(cs);
n = 0;
for b = 0:255
    if ~any(b == bs)
        bs = [bs b];      %#ok<AGROW>
        cs = [cs 256+n];  %#ok<AGROW>
        n = n + 1;
    end
end
[~, sortedIndices] = sort(bs);
cs = cs(sortedIndices);
end

function pairs = iGetPairs(wordFragments)
numLetters = length(wordFragments);
pairIndices = [1:(numLetters-1); 2:numLetters]';
pairIndices = mat2cell(pairIndices, ones(numLetters-1,1), 2);
pairs = cellfun(@(x)wordFragments(x), pairIndices, ...
    'UniformOutput', false);
pairs = cellfun(@(x)[string(x(1)) string(x(2))], pairs, ...
    'UniformOutput', false);
end
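
A short usage sketch (assumptions flagged): it presumes the model files have already been fetched by download.m below, that modelsDirectory is the folder containing the 'gpt2-355M' subfolder, and that the class is reached through the +gpt2/+tokenizer package path shown above.

modelsDirectory = 'path/to/models';   % hypothetical; must contain a gpt2-355M folder
tok = gpt2.tokenizer.GPT2Tokenizer('gpt2-355M', modelsDirectory);
numericTokens = tok.encode('GPT-2 tokenizes this sentence.');
text = tok.decode(numericTokens);     % recovers the input text

As the comments in encode note, the returned indices are 1-based, so they are offset by one from the token ids produced by the original Python implementation.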

+gpt2/download.m

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
function download()
% download   Download all of the necessary weight files for the transformer model
%
%   download() will download all of the files that define the pretrained
%   GPT-2 355M model.

% Create the directory for the model.
modelType = 'gpt2-355M';
modelDirectory = fullfile(fileparts(mfilename('fullpath')),'..',modelType);
iCreateDirectoryIfItDoesNotExist(modelDirectory);

% Download 'encoder.txt'. This is equivalent to 'encoder.json' from the
% original GPT-2.
iDownloadFileIfItDoesNotExist( ...
    fullfile(modelDirectory,'encoder.txt'), ...
    'https://ssd.mathworks.com/supportfiles/nnet/data/networks/gpt2_encoder.txt' );

% Download 'vocab.bpe'. This file contains the BPE ranks for the encoder.
% This file is identical to the one used by the original OpenAI repo.
iDownloadFileIfItDoesNotExist( ...
    fullfile(modelDirectory,'vocab.bpe'), ...
    'https://ssd.mathworks.com/supportfiles/nnet/data/networks/gpt2_vocab.bpe' );

% Download 'parameters.mat'. This contains all of the weights in the GPT-2
% model. They have been exported from the original TensorFlow
% implementation.
iDownloadFileIfItDoesNotExist( ...
    fullfile(modelDirectory,'parameters.mat'), ...
    'https://ssd.mathworks.com/supportfiles/nnet/data/networks/gpt2_355M_params.mat' );
end

function iCreateDirectoryIfItDoesNotExist(directory)
if ~exist(directory, 'dir')
    fprintf('Creating directory ''%s''...\n', directory);
    mkdir(directory);
else
    fprintf('Skipped creating directory ''%s'' as it already exists\n', directory);
end
end

function iDownloadFileIfItDoesNotExist(destination, source)
if ~exist(destination, 'file')
    fprintf('Downloading file ''%s'' ...\n', destination);
    websave(destination, source);
else
    fprintf('Skipped downloading file ''%s'' as it already exists\n', destination);
end
end
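
Assuming the folder containing the +gpt2 package is on the MATLAB path, fetching the model files is a single call, and re-running it skips anything that already exists:

gpt2.download();

The three files (encoder.txt, vocab.bpe and parameters.mat) end up in a gpt2-355M folder next to the +gpt2 package, which matches the modelName ('gpt2-355M') and modelsDirectory (the folder containing it) arguments taken by GPT2Tokenizer above.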
