Commit 9d354fd

More convenient pasts layout
1 parent e33baf2 commit 9d354fd

File tree

2 files changed: +13 -26 lines changed


+transformer/+layer/attention.m

Lines changed: 9 additions & 14 deletions
@@ -8,12 +8,12 @@
 % Inputs:
 %   X       - A (numFeatures*numHeads)-by-numInputSubwords-by-numObs
 %             input array.
-%   past    - A numFeatures-by-numPastSubwords-by-numHeads-by-2-by-numObs
+%   past    - A numFeatures-by-numPastSubwords-by-numHeads-by-numObs-by-2
 %             array. This contains the 'keys' and 'values' for
 %             past subwords. These are needed to predict future
 %             outputs in an autoregressive manner. 'keys' are
-%             stored in past(:,:,:,1,:) and 'values' are stored
-%             in past(:,:,:,2,:).
+%             stored in past(:,:,:,:,1) and 'values' are stored
+%             in past(:,:,:,:,2).
 %   weights - The weights for the full multi-head attention
 %             block stored in a struct. This includes:
 %             - attn_c_attn_w_0: A weight matrix for the
@@ -30,13 +30,13 @@
 % Outputs:
 %   Z       - A (numFeatures*numHeads)-by-numInputSubwords-by-numObs
 %             output array.
-%   present - A numFeatures-by-numAllSubwords-by-numHeads-by-2-by-numObs
+%   present - A numFeatures-by-numAllSubwords-by-numHeads-by-numObs-by-2
 %             array. This contains the 'keys' and 'values' that
 %             are created from inputs. These need to passed
 %             back in as the 'past' input if we want to predict
 %             future outputs in an autoregressive manner. 'keys'
-%             are stored in present(:,:,:,1,:) and 'values' are
-%             stored in present(:,:,:,2,:).
+%             are stored in present(:,:,:,:,1) and 'values' are
+%             stored in present(:,:,:,:,2).
 %
 % References:
 %
@@ -63,21 +63,16 @@
 
 % Use the past
 if ~isempty(past)
-    % Here we must squeeze out the singleton fourth dimensions after
-    % extracting the keys and values from past, since K, V have dimensions
-    % numFeatures-by-numPastSubwords-by-numHeads-by-numObs
-    PK = permute(past(:,:,:,1,:), [1 2 3 5 4]);
-    PV = permute(past(:,:,:,2,:), [1 2 3 5 4]);
+    PK = past(:,:,:,:,1);
+    PV = past(:,:,:,:,2);
     K = cat(2,PK,K);
     V = cat(2,PV,V);
 end
 
 % Set present. Note that this is done differently from the original
 % implementation which sets the value of present before the previous if
-% statement. Here we cat K, V along the fifth dimension, then permute to
-% recover the layout numFeatures-by-numPastSubwords-by-numHeads-by-2-by-numObs
+% statement
 present = cat(5,K,V);
-present = permute(present, [1 2 3 5 4]);
 
 A = transformer.layer.multiheadAttention(Q,K,V);
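The convenience of the new layout shows up in the deleted permute and reshape calls: with the 'keys'/'values' pair on the trailing dimension, indexing past(:,:,:,:,1) already yields the numFeatures-by-numPastSubwords-by-numHeads-by-numObs array that the attention code works with, because MATLAB drops trailing singleton dimensions. A minimal sketch contrasting the two layouts, using toy sizes that are illustrative rather than taken from the repo:

% Toy sizes (not from the repo) to contrast the old and new cache layouts.
numFeatures = 4; numPastSubwords = 3; numHeads = 2; numObs = 5;
K = rand(numFeatures, numPastSubwords, numHeads, numObs);
V = rand(numFeatures, numPastSubwords, numHeads, numObs);

% Old layout: keys/values on dim 4 leaves a singleton dimension that has to
% be permuted away before K and V can be reused.
pastOld = cat(4, reshape(K, [size(K,1:3) 1 size(K,4)]), ...
                 reshape(V, [size(V,1:3) 1 size(V,4)]));
KOld = permute(pastOld(:,:,:,1,:), [1 2 3 5 4]);

% New layout: keys/values on the trailing dimension, so plain indexing
% already gives a numFeatures-by-numPastSubwords-by-numHeads-by-numObs
% array (MATLAB drops the trailing singleton dimension).
pastNew = cat(5, K, V);
KNew = pastNew(:,:,:,:,1);

isequal(KOld, KNew)   % true: same keys, one permute fewer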

test/transformer/layer/tattention.m

Lines changed: 4 additions & 12 deletions
@@ -86,26 +86,18 @@ function checkPastPresentCaching(test,NumQueries,NumObs)
 % Verify the expected value of past - it is the key and values
 % concatenated on the 4th dimension.
 [~,K,V] = iSplitQKV(x,hyperParams.NumHeads,latentDim);
-K = reshape(K, [size(K, 1:3) 1 size(K, 4)]);
-V = reshape(V, [size(V, 1:3) 1 size(V, 4)]);
-test.verifyEqual(past,cat(4,K,V));
+test.verifyEqual(past,cat(5,K,V));
 % Now verify second call to attention is possible with the first
 % past as input - and verify the value of the attention output.
 [yAct,present] = test.attention(x,past,weights,hyperParams);
 [Q,K,V] = iSplitQKV(x,hyperParams.NumHeads,latentDim);
-Q = reshape(Q, [size(Q, 1:3) 1 size(Q, 4)]);
-K = reshape(K, [size(K, 1:3) 1 size(K, 4)]);
-V = reshape(V, [size(V, 1:3) 1 size(V, 4)]);
 % Verify the correct value for present.
-pastK = past(:,:,:,1,:);
-pastV = past(:,:,:,2,:);
-test.verifyEqual(extractdata(present),extractdata(cat(4,cat(2,pastK,K),cat(2,pastV,V))),'AbsTol',1e-5);
+pastK = past(:,:,:,:,1);
+pastV = past(:,:,:,:,2);
+test.verifyEqual(extractdata(present),extractdata(cat(5,cat(2,pastK,K),cat(2,pastV,V))),'AbsTol',1e-5);
 % To compute the expected value, concatenate the pasts
 K = cat(2,K,pastK);
 V = cat(2,V,pastV);
-Q = permute(Q, [1 2 3 5 4]);
-K = permute(K, [1 2 3 5 4]);
-V = permute(V, [1 2 3 5 4]);
 yExp = test.multiheadAttention(Q,K,V);
 yExp = iMergeHeads(yExp);
 test.verifyEqual(extractdata(yAct),extractdata(yExp),'AbsTol',1e-5);
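The caching pattern this test exercises looks roughly like the sketch below. The argument order mirrors the test's test.attention(x,past,weights,hyperParams) call and the documented [Z, present] outputs; x1, x2, weights and hyperParams are placeholders, not values from the repo.

% First call with no cache: pass [] as 'past'.
[z1, present] = transformer.layer.attention(x1, [], weights, hyperParams);
% present is numFeatures-by-numAllSubwords-by-numHeads-by-numObs-by-2, so it
% can be fed straight back in as 'past'; no permute is needed between calls.
[z2, present] = transformer.layer.attention(x2, present, weights, hyperParams);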
