Skip to content

Commit ef33a6f

Browse files
authored
Merge pull request #2 from matlab-deep-learning/add_tests
Add tests and CircleCI config
2 parents e4e7b8b + cda9fe5 commit ef33a6f

File tree

18 files changed

+1188
-0
lines changed

18 files changed

+1188
-0
lines changed

.circleci/config.yml

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
version: 2.1
2+
orbs:
3+
matlab: mathworks/[email protected]
4+
5+
jobs:
6+
build:
7+
machine:
8+
image: ubuntu-1604:201903-01
9+
steps:
10+
- checkout
11+
- matlab/install:
12+
# We want to ensure the repo works in R2020a
13+
release: R2020a
14+
- matlab/run-tests:
15+
test-results-junit: artifacts/test_results/matlab/results.xml
16+
# Have to add test/tools to the path for certain tests.
17+
source-folder: .;test/tools
18+
- store_test_results:
19+
path: artifacts/test_results
20+
- store_artifacts:
21+
path: artifacts/

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
*.asv
2+
gpt2-355M

test/gpt2/layer/tblock.m

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
classdef tblock < matlab.unittest.TestCase
2+
% tblock Unit tests for gpt2.layer.block
3+
4+
% Copyright 2020 The MathWorks, Inc.
5+
6+
properties(Constant, Access=private)
7+
block = @gpt2.layer.block
8+
end
9+
10+
properties(TestParameter)
11+
Input = struct(...
12+
'Scalar', 1,...
13+
'Vector', 1:5,...
14+
'Matrix', reshape(1:6,[3,2]))
15+
end
16+
17+
methods(Test)
18+
function outputHasInputSize(test,Input)
19+
% The block is simply a composition of other layers. Simply
20+
% verify the output of a block is the same size as the input,
21+
% and as such the blocks can be stacked.
22+
x = dlarray(Input);
23+
C = size(Input,1);
24+
weights = test.randomWeights(C);
25+
hyperParameters.NumHeads = 1;
26+
y = test.block(x,[],weights,hyperParameters);
27+
test.verifySize(y,size(x));
28+
end
29+
30+
function outputHasInputSizeWithPasts(test,Input)
31+
% As above but using "pasts" - a concatenation of key and value
32+
% matrices.
33+
x = dlarray(Input);
34+
C = size(Input,1);
35+
weights = test.randomWeights(C);
36+
hyperParameters.NumHeads = 1;
37+
% Provide a fake past of sequence length 1
38+
K_fake = dlarray(rand(C,1));
39+
V_fake = dlarray(rand(C,1));
40+
past = cat(4,K_fake,V_fake);
41+
[y,present] = test.block(x,past,weights,hyperParameters);
42+
test.verifySize(y,size(x));
43+
% The size of presents is the size of past except the sequence
44+
% dimension gets extended by the sequence length of y
45+
exp_present_size = size(past);
46+
exp_present_size(2) = exp_present_size(2)+size(y,2);
47+
test.verifySize(present,exp_present_size);
48+
end
49+
end
50+
51+
methods(Access=private)
52+
function weights = randomWeights(test,C)
53+
% C is num features, or latent dimension of the block
54+
g1 = dlarray(rand(C,1));
55+
b1 = dlarray(rand(C,1));
56+
g2 = dlarray(rand(C,1));
57+
b2 = dlarray(rand(C,1));
58+
W_A1 = dlarray(rand(3*C,C));
59+
W_A2 = dlarray(rand(C));
60+
b_A1 = dlarray(rand(3*C,1));
61+
b_A2 = dlarray(rand(C,1));
62+
W_P1 = dlarray(rand(C));
63+
b_P1 = dlarray(rand(C,1));
64+
W_P2 = dlarray(rand(C));
65+
b_P2 = dlarray(rand(C,1));
66+
weights = test.prepareBlockWeightsStruct(g1,b1,W_A1,b_A1,W_A2,b_A2,g2,b2,W_P1,b_P1,W_P2,b_P2);
67+
end
68+
69+
function s = prepareBlockWeightsStruct(test,g1,b1,W_A1,b_A1,W_A2,b_A2,g2,b2,W_P1,b_P1,W_P2,b_P2)
70+
% Merge various structs that have the appropriate weight naming
71+
% syntax.
72+
s_ln = test.prepareLayerNormWeightsStruct(g1,b1,g2,b2);
73+
s_attn = test.prepareAttentionWeightsStruct(W_A1,b_A1,W_A2,b_A2);
74+
s_mlp = test.prepareMLPWeightsStruct(W_P1,b_P1,W_P2,b_P2);
75+
c = {s_ln,s_attn,s_mlp};
76+
fn = cellfun(@fieldnames,c,'UniformOutput',false);
77+
fn = cat(1,fn{:});
78+
fv = cellfun(@struct2cell,c,'UniformOutput',false);
79+
fv = cat(1,fv{:});
80+
s = struct();
81+
for i = 1:numel(fn)
82+
s.(fn{i}) = fv{i};
83+
end
84+
end
85+
86+
function s = prepareAttentionWeightsStruct(~,W1,b1,W2,b2)
87+
% Prepare a struct compatible with the weights input of
88+
% attention. These are for the fully connected layers.
89+
s = struct(...
90+
'attn_c_attn_w_0',W1,...
91+
'attn_c_attn_b_0',b1,...
92+
'attn_c_proj_w_0',W2,...
93+
'attn_c_proj_b_0',b2);
94+
end
95+
96+
function s = prepareLayerNormWeightsStruct(~,g1,b1,g2,b2)
97+
% Prepare a struct of weights compatible with the two layer
98+
% norm calls in block
99+
s = struct(...
100+
'ln_1_g_0',g1,...
101+
'ln_1_b_0',b1,...
102+
'ln_2_g_0',g2,...
103+
'ln_2_b_0',b2);
104+
end
105+
106+
function s = prepareMLPWeightsStruct(~,W1,b1,W2,b2)
107+
% Create a struct of weights to be consumed by
108+
% transformer.layer.multiLayerPerceptron
109+
s = struct(...
110+
'mlp_c_fc_w_0',W1,...
111+
'mlp_c_fc_b_0',b1,...
112+
'mlp_c_proj_w_0',W2,...
113+
'mlp_c_proj_b_0',b2);
114+
end
115+
end
116+
end

test/gpt2/tdownload.m

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
classdef(SharedTestFixtures = {DownloadGPT2Fixture}) tdownload < matlab.unittest.TestCase
2+
% tdownload Tests for gpt2.download
3+
4+
% Copyright 2020 The MathWorks, Inc.
5+
6+
% DownloadGPT2Fixture.setup calls gpt2.download so this test is just a
7+
% sanity check that the required files are downloaded.
8+
9+
properties(Constant)
10+
ExpectedDataDir = fullfile(getRepoRoot(),'gpt2-355M')
11+
ExpectedFiles = ["parameters.mat","vocab.bpe","encoder.txt"]
12+
end
13+
14+
methods(Test)
15+
function verifyFilesExist(test)
16+
test.assertEqual(exist(test.ExpectedDataDir,"dir"),7,...
17+
"Expected download directory for gpt2-355M not created.");
18+
files = dir(test.ExpectedDataDir);
19+
filenames = {files.name};
20+
import matlab.unittest.constraints.IsSupersetOf
21+
test.verifyThat(filenames,IsSupersetOf(test.ExpectedFiles),...
22+
"Expected files not downloaded for gpt2-355M.");
23+
import matlab.unittest.constraints.IsSameSetAs
24+
% dir picks up "." and ".." too.
25+
test.verifyThat(setdiff(filenames,test.ExpectedFiles), IsSameSetAs([".",".."]),...
26+
"Unexpected files downloaded for gpt2-355M.");
27+
end
28+
end
29+
end
30+

test/gpt2/tload.m

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
classdef(SharedTestFixtures = {DownloadGPT2Fixture}) tload < matlab.unittest.TestCase
2+
% tload Test for gpt2.load
3+
4+
% Copyright 2020 The MathWorks, Inc.
5+
6+
properties(Constant)
7+
ExpectedNumHeads = 16
8+
ExpectedNumLayers = 24
9+
ExpectedContext = 1024
10+
end
11+
12+
properties
13+
Parameters
14+
end
15+
16+
methods(TestClassSetup)
17+
function loadParameters(test)
18+
% Load the parameters once for all tests
19+
test.Parameters = gpt2.load(fullfile(getRepoRoot,"gpt2-355M","parameters.mat"));
20+
end
21+
end
22+
23+
methods(Test)
24+
function verifyLoadStructFields(test)
25+
% Verify the expected fieldnames of the loaded struct
26+
import matlab.unittest.constraints.IsSameSetAs
27+
expected = ["Hyperparameters","Weights"];
28+
test.verifyThat(fieldnames(test.Parameters), IsSameSetAs(expected));
29+
end
30+
31+
function verifyHyperparameters(test)
32+
% Verify the 355M config
33+
hyperParams = test.Parameters.Hyperparameters;
34+
test.verifyEqual(hyperParams.NumHeads,test.ExpectedNumHeads,...
35+
"Unexpected value for Hyperparameters.NumHeads");
36+
test.verifyEqual(hyperParams.NumLayers,test.ExpectedNumLayers,...
37+
"Unexpected value for Hyperparameters.NumLayers");
38+
test.verifyEqual(hyperParams.NumContext,test.ExpectedContext,...
39+
"Unexpected value for Hyperparameters.NumContext");
40+
end
41+
42+
function verifyWeights(test)
43+
% Verify the structure of the Weights field and check some
44+
% particular weight has the expected type.
45+
46+
% Here there is an implicit check that "model_" has been
47+
% removed from the weight names and the flat parameters.mat has
48+
% been organised into a hierarchy for each gpt2.block
49+
w = test.assertWarningFree(@() test.Parameters.Weights.h0.ln_1_g_0);
50+
import matlab.unittest.constraints.IsOfClass
51+
test.verifyThat(w,IsOfClass('dlarray'));
52+
test.verifyThat(extractdata(w),IsOfClass('single'));
53+
end
54+
end
55+
end

test/gpt2/tmodel.m

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
classdef(SharedTestFixtures = {DownloadGPT2Fixture}) tmodel < matlab.unittest.TestCase
2+
% tmodel Tests for gpt2.model
3+
4+
% Copyright 2020 The MathWorks, Inc.
5+
6+
properties(Constant)
7+
model = @gpt2.model
8+
end
9+
10+
methods(Test)
11+
function canUseModel(test)
12+
inputs = test.prepareInputs();
13+
test.verifyWarningFree(@() test.model(inputs{:}));
14+
end
15+
end
16+
17+
methods(Access=private)
18+
function inputs = prepareInputs(test)
19+
% Convenience method to set up inputs for
20+
% gpt2.model
21+
X = test.prepareX();
22+
parameters = test.prepareParameters();
23+
pasts = test.preparePasts(parameters.Hyperparameters.NumLayers);
24+
inputs = {X,pasts,parameters};
25+
end
26+
27+
function X = prepareX(~)
28+
X = dlarray(1);
29+
end
30+
31+
function pasts = preparePasts(~,numLayers)
32+
pasts = cell(numLayers,1);
33+
end
34+
35+
function parameters = prepareParameters(~)
36+
parametersFile = fullfile(getRepoRoot(),'gpt2-355M','parameters.mat');
37+
parameters = gpt2.load(parametersFile);
38+
end
39+
end
40+
end

0 commit comments

Comments
 (0)