classdef tblock < matlab.unittest.TestCase
    % tblock Unit tests for gpt2.layer.block
    
    % Copyright 2020 The MathWorks, Inc.
    
    properties(Constant, Access = private)
        block = @gpt2.layer.block
    end
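    
    % The tests below assume the block function has the signature
    %   [y, present] = gpt2.layer.block(x, pasts, weights, hyperParameters)
    % where x is a dlarray of activations, pasts is [] or cached key/value
    % matrices, weights is a struct of ln_*/attn_*/mlp_* parameters, and
    % hyperParameters carries NumHeads. This is inferred from how the tests
    % call the block, not from separate documentation of the layer.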
    
    properties(TestParameter)
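        % Each parameter exercises a different input shape; the feature
        % dimension C used in the tests is size(Input,1), i.e. 1, 1 and 3
        % for the scalar, vector and matrix cases respectively.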
        Input = struct(...
            'Scalar', 1, ...
            'Vector', 1:5, ...
            'Matrix', reshape(1:6, [3,2]))
    end
    
    methods(Test)
        function outputHasInputSize(test, Input)
            % The block is a composition of other layers. Simply verify
            % that the output of a block is the same size as its input, so
            % that blocks can be stacked (see the sketch after this
            % method).
            x = dlarray(Input);
            C = size(Input, 1);
            weights = test.randomWeights(C);
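            % A single attention head keeps the feature dimension C (1 or 3
            % for these inputs) trivially divisible by the head count.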
            hyperParameters.NumHeads = 1;
            y = test.block(x, [], weights, hyperParameters);
            test.verifySize(y, size(x));
        end
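        
        % A minimal stacking sketch (not executed here), assuming each
        % layer has its own weights struct, e.g. a hypothetical cell array
        % blockWeights with one element per layer:
        %   for i = 1:numel(blockWeights)
        %       x = gpt2.layer.block(x, [], blockWeights{i}, hyperParameters);
        %   end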
        
        function outputHasInputSizeWithPasts(test, Input)
            % As above but using "pasts" - a concatenation of key and value
            % matrices.
            x = dlarray(Input);
            C = size(Input, 1);
            weights = test.randomWeights(C);
            hyperParameters.NumHeads = 1;
            % Provide a fake past of sequence length 1
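            % The resulting past has size [C,1,1,2]: keys are stacked in
            % past(:,:,:,1) and values in past(:,:,:,2), each C-by-1 here
            % since there is a single head and a past sequence length of 1.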
            K_fake = dlarray(rand(C, 1));
            V_fake = dlarray(rand(C, 1));
            past = cat(4, K_fake, V_fake);
            [y, present] = test.block(x, past, weights, hyperParameters);
            test.verifySize(y, size(x));
            % The size of present is the size of past, except the sequence
            % dimension is extended by the sequence length of y.
            exp_present_size = size(past);
            exp_present_size(2) = exp_present_size(2) + size(y, 2);
            test.verifySize(present, exp_present_size);
        end
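        
        % Hedged sketch (not executed here) of how present would typically
        % be fed back as the next call's past during incremental decoding;
        % x_t and numSteps are illustrative names only:
        %   past = [];
        %   for t = 1:numSteps
        %       [y, past] = gpt2.layer.block(x_t, past, weights, hyperParameters);
        %   end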
    end
    
    methods(Access = private)
        function weights = randomWeights(test, C)
            % C is the number of features, i.e. the latent dimension of the
            % block.
            g1 = dlarray(rand(C, 1));
            b1 = dlarray(rand(C, 1));
            g2 = dlarray(rand(C, 1));
            b2 = dlarray(rand(C, 1));
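            % The attn_c_attn projection maps C features to 3*C, which
            % presumably packs the query, key and value projections
            % (following the GPT-2 weight layout); attn_c_proj maps back
            % from C to C.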
            W_A1 = dlarray(rand(3*C, C));
            W_A2 = dlarray(rand(C));
            b_A1 = dlarray(rand(3*C, 1));
            b_A2 = dlarray(rand(C, 1));
            W_P1 = dlarray(rand(C));
            b_P1 = dlarray(rand(C, 1));
            W_P2 = dlarray(rand(C));
            b_P2 = dlarray(rand(C, 1));
            weights = test.prepareBlockWeightsStruct(g1, b1, W_A1, b_A1, W_A2, b_A2, g2, b2, W_P1, b_P1, W_P2, b_P2);
        end
        
        function s = prepareBlockWeightsStruct(test, g1, b1, W_A1, b_A1, W_A2, b_A2, g2, b2, W_P1, b_P1, W_P2, b_P2)
            % Merge various structs that have the appropriate weight naming
            % syntax.
            s_ln = test.prepareLayerNormWeightsStruct(g1, b1, g2, b2);
            s_attn = test.prepareAttentionWeightsStruct(W_A1, b_A1, W_A2, b_A2);
            s_mlp = test.prepareMLPWeightsStruct(W_P1, b_P1, W_P2, b_P2);
            c = {s_ln, s_attn, s_mlp};
            fn = cellfun(@fieldnames, c, 'UniformOutput', false);
            fn = cat(1, fn{:});
            fv = cellfun(@struct2cell, c, 'UniformOutput', false);
            fv = cat(1, fv{:});
            s = struct();
            for i = 1:numel(fn)
                s.(fn{i}) = fv{i};
            end
        end
        
        function s = prepareAttentionWeightsStruct(~, W1, b1, W2, b2)
            % Prepare a struct compatible with the weights input of
            % attention. These are for the fully connected layers.
            s = struct(...
                'attn_c_attn_w_0', W1, ...
                'attn_c_attn_b_0', b1, ...
                'attn_c_proj_w_0', W2, ...
                'attn_c_proj_b_0', b2);
        end
        
        function s = prepareLayerNormWeightsStruct(~, g1, b1, g2, b2)
            % Prepare a struct of weights compatible with the two layer
            % norm calls in block.
            s = struct(...
                'ln_1_g_0', g1, ...
                'ln_1_b_0', b1, ...
                'ln_2_g_0', g2, ...
                'ln_2_b_0', b2);
        end
        
        function s = prepareMLPWeightsStruct(~, W1, b1, W2, b2)
            % Create a struct of weights to be consumed by
            % transformer.layer.multiLayerPerceptron
            s = struct(...
                'mlp_c_fc_w_0', W1, ...
                'mlp_c_fc_b_0', b1, ...
                'mlp_c_proj_w_0', W2, ...
                'mlp_c_proj_b_0', b2);
        end
    end
end