@@ -6,14 +6,14 @@
 % 2 in [1]. See below for details of inputs and outputs.
 %
 % Inputs:
-%   X       - A (numFeatures*numHeads)-by-numInputSubwords
+%   X       - A (numFeatures*numHeads)-by-numInputSubwords-by-numObs
 %             input array.
-%   past    - A numFeatures-by-numPastSubwords-by-numHeads-by-2
+%   past    - A numFeatures-by-numPastSubwords-by-numHeads-by-numObs-by-2
 %             array. This contains the 'keys' and 'values' for
 %             past subwords. These are needed to predict future
 %             outputs in an autoregressive manner. 'keys' are
-%             stored in past(:,:,:,1) and 'values' are stored
-%             in past(:,:,:,2).
+%             stored in past(:,:,:,:,1) and 'values' are stored
+%             in past(:,:,:,:,2).
 %   weights - The weights for the full multi-head attention
 %             block stored in a struct. This includes:
 %             - attn_c_attn_w_0: A weight matrix for the
@@ -28,15 +28,15 @@
 %             hyper-parameter.
 %
 % Outputs:
-%   Z       - A (numFeatures*numHeads)-by-numInputSubwords
+%   Z       - A (numFeatures*numHeads)-by-numInputSubwords-by-numObs
 %             output array.
-%   present - A numFeatures-by-numAllSubwords-by-numHeads-by-2
+%   present - A numFeatures-by-numAllSubwords-by-numHeads-by-numObs-by-2
 %             array. This contains the 'keys' and 'values' that
 %             are created from inputs. These need to be passed
 %             back in as the 'past' input if we want to predict
 %             future outputs in an autoregressive manner. 'keys'
-%             are stored in present(:,:,:,1) and 'values' are
-%             stored in present(:,:,:,2).
+%             are stored in present(:,:,:,:,1) and 'values' are
+%             stored in present(:,:,:,:,2).
 %
 % References:
 %
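To make the new batched layout concrete, here is a minimal shape sketch of the inputs and outputs; every size below is hypothetical, chosen only for illustration:

% All sizes here are assumptions for illustration, not values from the commit.
numFeatures = 64; numHeads = 12;
numInputSubwords = 5; numPastSubwords = 8; numObs = 3;

X    = rand(numFeatures*numHeads, numInputSubwords, numObs);
past = rand(numFeatures, numPastSubwords, numHeads, numObs, 2);

% Z comes back the same size as X, and present is
% numFeatures-by-(numPastSubwords+numInputSubwords)-by-numHeads-by-numObs-by-2.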
@@ -52,9 +52,9 @@
 
 % Split the results into Q (Query), K (Keys) and V (Values).
 splitSize = size(C,1)/3;
-Q = C(1:splitSize,:);
-K = C((splitSize+1):(2*splitSize),:);
-V = C((2*splitSize+1):(3*splitSize),:);
+Q = C(1:splitSize,:,:);
+K = C((splitSize+1):(2*splitSize),:,:);
+V = C((2*splitSize+1):(3*splitSize),:,:);
 
 % Split heads
 Q = iSplitHeads(Q, splitSize, hyperParameters.NumHeads);
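A quick sanity check on the split (again with hypothetical sizes): the three index ranges carve the first dimension of C into equal thirds, so stacking Q, K and V back together recovers C:

C = rand(3*768, 5, 2);   % assumed output of the joint attn_c_attn projection
splitSize = size(C,1)/3;
Q = C(1:splitSize,:,:);
K = C((splitSize+1):(2*splitSize),:,:);
V = C((2*splitSize+1):(3*splitSize),:,:);
assert(isequal(cat(1,Q,K,V), C));   % the three blocks tile C exactly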
@@ -63,16 +63,16 @@
 
 % Use the past
 if ~isempty(past)
-    PK = past(:,:,:,1);
-    PV = past(:,:,:,2);
+    PK = past(:,:,:,:,1);
+    PV = past(:,:,:,:,2);
     K = cat(2,PK,K);
     V = cat(2,PV,V);
 end
 
 % Set present. Note that this is done differently from the original
 % implementation which sets the value of present before the previous if
 % statement.
-present = cat(4,K,V);
+present = cat(5,K,V);
 
 A = transformer.layer.multiheadAttention(Q,K,V);
 
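The cache is what makes autoregressive generation cheap: each call processes only the new subwords, and the stored keys and values are prepended along the subword dimension (dim 2), so size(present,2) grows by numInputSubwords per call. A minimal sketch of the feedback loop, where attentionFcn stands in for this function and nextSubwordFeatures is a hypothetical helper that produces features for the next predicted subword:

past = [];                        % empty cache on the first call
for step = 1:numSteps             % numSteps, X, weights, hyperParameters assumed in scope
    [Z, present] = attentionFcn(X, past, weights, hyperParameters);
    past = present;               % cache now covers all subwords seen so far
    X = nextSubwordFeatures(Z);   % hypothetical: features for the next subword only
end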
@@ -81,23 +81,22 @@
 A = transformer.layer.convolution1d(A, ...
     weights.attn_c_proj_w_0, ...
     weights.attn_c_proj_b_0);
-
 end
 
 function Z = iSplitHeads(X, splitSize, numHeads)
 % We permute the data to put the dimension for the heads last, so that we
 % can use batched matrix multiplication to compute attention for all of the
 % heads at once.
 %
-% X - A (numFeatures*numHeads)-by-numSubwords array.
-% Z - A numFeatures-by-numSubwords-by-numHeads array.
-X = reshape(X, splitSize/numHeads, numHeads, []);
-Z = permute(X,[1 3 2]);
+% X - A (numFeatures*numHeads)-by-numSubwords-by-numObs array.
+% Z - A numFeatures-by-numSubwords-by-numHeads-by-numObs array.
+X = reshape(X, splitSize/numHeads, numHeads, [], size(X,3));
+Z = permute(X,[1 3 2 4]);
 end
 
 function Z = iMergeHeads(X)
-% X - A numFeatures-by-numSubwords-by-numHeads array.
-% Z - A (numFeatures*numHeads)-by-numSubwords array.
-X = permute(X, [1 3 2]);
-Z = reshape(X, size(X,1)*size(X,2), []);
+% X - A numFeatures-by-numSubwords-by-numHeads-by-numObs array.
+% Z - A (numFeatures*numHeads)-by-numSubwords-by-numObs array.
+X = permute(X, [1 3 2 4]);
+Z = reshape(X, size(X,1)*size(X,2), [], size(X,4));
 end
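Since iSplitHeads and iMergeHeads only reshape and permute, merging after splitting should be an exact round trip. A small check (illustrative sizes, run where the two local functions are in scope):

numFeatures = 4; numHeads = 3; numSubwords = 5; numObs = 2;
X = rand(numFeatures*numHeads, numSubwords, numObs);
Y = iSplitHeads(X, numFeatures*numHeads, numHeads);
assert(isequal(size(Y), [numFeatures numSubwords numHeads numObs]));
assert(isequal(iMergeHeads(Y), X));   % lossless round trip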