
Commit 85331dc

Maximilian.S authored and phaniarnab committed
[SYSTEMDS-3811] Multi-head Attention Layer
This patch introduces a multi-head attention layer, with forward and backward passes, as a built-in layer. Multi-head attention is the base layer of most Transformer models, with many variations across models. This implementation is in line with the basic BERT attention layer. The functionality is currently kept to a minimum, without features like attention masking, head masking, or cross-attention. This patch is the first of a number of commits aimed at supporting the BERT model in SystemDS, and other transformer models in the future. Closes #2172
1 parent 5f360ef commit 85331dc
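For reference, each head computes standard scaled dot-product attention: with per-head queries, keys, and values Q_h, K_h, V_h of shape (T, D), the forward pass below evaluates

\mathrm{context}_h = \mathrm{dropout}\!\left(\mathrm{softmax}\!\left(\frac{Q_h K_h^{\top}}{\sqrt{D}}\right)\right) V_h

where the implementation folds the 1/sqrt(D) scaling into the keys (K_norm) and dropout is applied only when dropout_p > 0.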


43 files changed: +606, -0 lines changed
Lines changed: 205 additions & 0 deletions
@@ -0,0 +1,205 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

source("nn/layers/softmax.dml") as softmax
source("nn/layers/dropout.dml") as dropout
source("scripts/nn/util.dml") as util


forward = function(matrix[double] Q, matrix[double] K,
    matrix[double] V, int H, int T, int D, double dropout_p)
  return (matrix[double] context, matrix[double] attention, matrix[double] dropout_mask) {
  /*
   * Computes the forward pass for a multi-head attention layer.
   *
   * Inputs (B: Batch size, T: Sequence length, D: Embedding length, H: Heads):
   *  - Q: Input queries, of shape (B, T*H*D).
   *  - K: Input keys, of shape (B, T*H*D).
   *  - V: Input values, of shape (B, T*H*D).
   *  - H: Head count.
   *  - T: Sequence length.
   *  - D: Embedding length of a single query, key, or value.
   *  - dropout_p: Dropout probability.
   *
   * Outputs:
   *  - context: Token context embeddings, of shape (B, T*H*D).
   *  - attention: Attention on value(s) for given query(s), of shape (B, H*T*T).
   *  - dropout_mask: Dropout mask used on attention, of shape (B, H*T*T).
   */
  B = nrow(Q)

  # Transpose head and token dimension for per-head computation
  Q = util::transpose_ABCD_to_ACBD(Q, T, H)  # Shape (B, H*T*D)
  K = util::transpose_ABCD_to_ACBD(K, T, H)  # Shape (B, H*T*D)
  V = util::transpose_ABCD_to_ACBD(V, T, H)  # Shape (B, H*T*D)

  attention = matrix(0, rows=B, cols=H*T*T)
  dropout_mask = matrix(0, rows=B, cols=H*T*T)
  context = matrix(0, rows=B, cols=H*T*D)
  K_norm = K / sqrt(D)

  # For loops for tensor operations
  for (batch in 1:B) {
    attention_probs_b = matrix(0, rows=H, cols=T*T)
    if (dropout_p > 0.0) {
      dropout_mask_b = matrix(0, rows=H, cols=T*T)
    }
    context_b = matrix(0, rows=H, cols=T*D)
    Q_b = matrix(Q[batch], rows=H, cols=T*D)
    K_norm_b = matrix(K_norm[batch], rows=H, cols=T*D)
    V_b = matrix(V[batch], rows=H, cols=T*D)

    for (head in 1:H) {
      Q_h = matrix(Q_b[head], rows=T, cols=D)
      K_norm_h = matrix(K_norm_b[head], rows=T, cols=D)
      V_h = matrix(V_b[head], rows=T, cols=D)

      attention_scores = Q_h %*% t(K_norm_h)  # Shape (T, T)

      # TODO: Add support for attention mask here

      # Column-wise softmax
      attention_probs_h = softmax::forward(attention_scores)

      if (dropout_p > 0.0) {
        [attention_probs_h, dropout_mask_h] = dropout::forward(attention_probs_h, dropout_p, -1)
      }

      context_h = attention_probs_h %*% V_h  # Shape (T, D)

      attention_probs_b[head] = matrix(attention_probs_h, rows=1, cols=T*T)
      if (dropout_p > 0.0) {
        dropout_mask_b[head] = matrix(dropout_mask_h, rows=1, cols=T*T)
      }
      context_b[head] = matrix(context_h, rows=1, cols=T*D)
    }

    attention[batch] = matrix(attention_probs_b, rows=1, cols=H*T*T)
    if (dropout_p > 0.0) {
      dropout_mask[batch] = matrix(dropout_mask_b, rows=1, cols=H*T*T)
    }
    context[batch] = matrix(context_b, rows=1, cols=H*T*D)
  }

  # Swap head and token dimensions back to the original shape
  context = util::transpose_ABCD_to_ACBD(context, H, T)
}
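A minimal usage sketch of the forward pass (the source path and alias below are assumptions for illustration, not taken from this commit; adjust them to the actual location of this layer file):

# Illustrative only: the path is an assumption, not part of this commit.
source("nn/layers/multi_attention.dml") as mha

B = 2  # batch size
T = 3  # sequence length
H = 4  # number of heads
D = 5  # embedding length per head
Q = rand(rows=B, cols=T*H*D, min=-1, max=1)
K = rand(rows=B, cols=T*H*D, min=-1, max=1)
V = rand(rows=B, cols=T*H*D, min=-1, max=1)

# Forward pass with 10% attention dropout
[context, attention, dropout_mask] = mha::forward(Q, K, V, H, T, D, 0.1)
print("context: " + nrow(context) + " x " + ncol(context))  # B x T*H*D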


backward = function(matrix[double] dcontext,
    matrix[double] dropout_mask, matrix[double] attention, matrix[double] Q,
    matrix[double] K, matrix[double] V, int H, int T,
    int D, double dropout_p)
  return (matrix[double] dQ, matrix[double] dK, matrix[double] dV) {
  /*
   * Computes the backward pass for a multi-head attention layer.
   *
   * Inputs (B: Batch size, T: Sequence length, D: Embedding length, H: Heads):
   *  - dcontext: Gradient w.r.t. the context matrix, of shape (B, T*H*D).
   *  - dropout_mask: Dropout mask from the forward pass, of shape (B, H*T*T).
   *  - attention: Attention output from the forward pass, of shape (B, H*T*T).
   *  - Q: Input queries, of shape (B, T*H*D).
   *  - K: Input keys, of shape (B, T*H*D).
   *  - V: Input values, of shape (B, T*H*D).
   *  - H: Head count.
   *  - T: Sequence length.
   *  - D: Embedding length of a single query, key, or value.
   *  - dropout_p: Dropout probability.
   *
   * Outputs:
   *  - dQ: Gradient w.r.t. input queries, of shape (B, T*H*D).
   *  - dK: Gradient w.r.t. input keys, of shape (B, T*H*D).
   *  - dV: Gradient w.r.t. input values, of shape (B, T*H*D).
   */
  B = nrow(Q)

  # Transpose head and token dimension for per-head computation
  dcontext = util::transpose_ABCD_to_ACBD(dcontext, T, H)  # Shape (B, H*T*D)
  Q = util::transpose_ABCD_to_ACBD(Q, T, H)  # Shape (B, H*T*D)
  K = util::transpose_ABCD_to_ACBD(K, T, H)  # Shape (B, H*T*D)
  V = util::transpose_ABCD_to_ACBD(V, T, H)  # Shape (B, H*T*D)

  dQ = matrix(0, rows=B, cols=H*T*D)  # Shape (B, H*T*D)
  dK = matrix(0, rows=B, cols=H*T*D)  # Shape (B, H*T*D)
  dV = matrix(0, rows=B, cols=H*T*D)  # Shape (B, H*T*D)

  K_norm = K / sqrt(D)

  # For loops for tensor operations
  for (batch in 1:B) {
    dcontext_b = matrix(dcontext[batch], rows=H, cols=T*D)
    if (dropout_p > 0.0) {
      dropout_mask_b = matrix(dropout_mask[batch], rows=H, cols=T*T)
    }
    attention_b = matrix(attention[batch], rows=H, cols=T*T)

    Q_b = matrix(Q[batch], rows=H, cols=T*D)
    K_norm_b = matrix(K_norm[batch], rows=H, cols=T*D)
    V_b = matrix(V[batch], rows=H, cols=T*D)

    dQ_b = matrix(0, rows=H, cols=T*D)
    dK_b = matrix(0, rows=H, cols=T*D)
    dV_b = matrix(0, rows=H, cols=T*D)

    for (head in 1:H) {
      dcontext_h = matrix(dcontext_b[head], rows=T, cols=D)
      if (dropout_p > 0.0) {
        dropout_mask_h = matrix(dropout_mask_b[head], rows=T, cols=T)
      }
      attention_h = matrix(attention_b[head], rows=T, cols=T)

      # Compute dV early to release attention_h
      dV_h = t(attention_h) %*% dcontext_h

      Q_h = matrix(Q_b[head], rows=T, cols=D)
      K_norm_h = matrix(K_norm_b[head], rows=T, cols=D)
      V_h = matrix(V_b[head], rows=T, cols=D)

      dattention_probs = dcontext_h %*% t(V_h)

      if (dropout_p > 0.0) {
        # dropout::backward requires an X argument that is unused here; pass an empty matrix
        dattention_probs = dropout::backward(dattention_probs, matrix(0, rows=1, cols=1), dropout_p, dropout_mask_h)
      }
      attention_scores = Q_h %*% t(K_norm_h)  # Shape (T, T)
      dattention_scores = softmax::backward(dattention_probs, attention_scores)

      dQ_h = dattention_scores %*% K_norm_h
      dK_h = t(dattention_scores) %*% (Q_h / sqrt(D))

      # Append to batch matrices
      dK_b[head] = matrix(dK_h, rows=1, cols=T*D)
      dQ_b[head] = matrix(dQ_h, rows=1, cols=T*D)
      dV_b[head] = matrix(dV_h, rows=1, cols=T*D)
    }

    # Append to output matrices
    dK[batch] = matrix(dK_b, rows=1, cols=H*T*D)
    dQ[batch] = matrix(dQ_b, rows=1, cols=H*T*D)
    dV[batch] = matrix(dV_b, rows=1, cols=H*T*D)
  }

  # Swap head and token dimensions back to the original shape
  dK = util::transpose_ABCD_to_ACBD(dK, H, T)
  dQ = util::transpose_ABCD_to_ACBD(dQ, H, T)
  dV = util::transpose_ABCD_to_ACBD(dV, H, T)
}
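Continuing the sketch above, the backward pass consumes the upstream gradient of the context together with the attention probabilities and dropout mask returned by forward (again illustrative only):

# Illustrative continuation of the forward sketch above.
dcontext = rand(rows=B, cols=T*H*D, min=-1, max=1)  # stand-in for the upstream gradient
[dQ, dK, dV] = mha::backward(dcontext, dropout_mask, attention, Q, K, V, H, T, D, 0.1)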
Lines changed: 139 additions & 0 deletions
@@ -0,0 +1,139 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.sysds.test.applications.nn.transformers;

import org.apache.sysds.common.Types;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Test;

public class MultiAttentionLayerTest extends AutomatedTestBase {
	private static final String TEST_NAME_FORWARD = "multi_attention_forward";
	private static final String TEST_NAME_BACKWARD = "multi_attention_backward";
	private static final String TEST_DIR = "applications/nn/component/";
	private static final String RESOURCE_DIR = "src/test/resources/component/transformers/multi_attention_layer/";

	@Override
	public void setUp() {
		TestUtils.clearAssertionInformation();
		addTestConfiguration(TEST_NAME_FORWARD, new TestConfiguration(TEST_DIR, TEST_NAME_FORWARD));
		addTestConfiguration(TEST_NAME_BACKWARD, new TestConfiguration(TEST_DIR, TEST_NAME_BACKWARD));
	}

	@Test
	public void testMultiAttentionForwardSimple() {
		runMultiAttentionTest("test1", 2, 3, 4, 5, 0, TEST_NAME_FORWARD, 1e-5, true);
	}

	@Test
	public void testMultiAttentionForwardLarge() {
		runMultiAttentionTest("test2", 8, 12, 10, 4, 0, TEST_NAME_FORWARD, 1e-5, true);
	}

	@Test
	public void testMultiAttentionForwardSmall() {
		runMultiAttentionTest("test3", 1, 1, 1, 1, 0, TEST_NAME_FORWARD, 1e-5, true);
	}

	@Test
	public void testMultiAttentionBackwardSimple() {
		runMultiAttentionTest("test4", 2, 3, 4, 5, 0, TEST_NAME_BACKWARD, 1e-5, false);
	}

	@Test
	public void testMultiAttentionBackwardLarge() {
		runMultiAttentionTest("test5", 8, 12, 10, 5, 0, TEST_NAME_BACKWARD, 1e-5, false);
	}

	@Test
	public void testMultiAttentionBackwardSmall() {
		runMultiAttentionTest("test6", 1, 1, 1, 1, 0, TEST_NAME_BACKWARD, 1e-5, false);
	}

	private void runMultiAttentionTest(String testSuffix, int batchSize, int seqLength, int numHeads, int embeddingDim,
			int debug, String testname, double precision, boolean isForward) {
		// Set execution platform
		Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE);

		try {
			// Load test configuration
			getAndLoadTestConfiguration(testname);
			fullDMLScriptName = getScript();

			// Program arguments
			if (isForward) {
				programArgs = new String[] {
					"-stats", "-args",
					String.valueOf(batchSize), String.valueOf(seqLength),
					String.valueOf(numHeads), String.valueOf(embeddingDim),
					String.valueOf(debug),
					RESOURCE_DIR + "input_query_" + testSuffix + ".csv",
					RESOURCE_DIR + "input_key_" + testSuffix + ".csv",
					RESOURCE_DIR + "input_value_" + testSuffix + ".csv",
					RESOURCE_DIR + "output_context_" + testSuffix + ".csv",
					RESOURCE_DIR + "output_attention_" + testSuffix + ".csv",
					output("context_error"),
					output("attention_error"),
				};
			} else {
				programArgs = new String[] {
					"-stats", "-args",
					String.valueOf(batchSize), String.valueOf(seqLength),
					String.valueOf(numHeads), String.valueOf(embeddingDim),
					String.valueOf(debug),
					RESOURCE_DIR + "input_query_" + testSuffix + ".csv",
					RESOURCE_DIR + "input_key_" + testSuffix + ".csv",
					RESOURCE_DIR + "input_value_" + testSuffix + ".csv",
					RESOURCE_DIR + "input_dcontext_" + testSuffix + ".csv",
					RESOURCE_DIR + "input_attention_" + testSuffix + ".csv",
					RESOURCE_DIR + "output_dquery_" + testSuffix + ".csv",
					RESOURCE_DIR + "output_dkey_" + testSuffix + ".csv",
					RESOURCE_DIR + "output_dvalue_" + testSuffix + ".csv",
					output("dquery_error"),
					output("dkey_error"),
					output("dvalue_error"),
				};
			}

			// Run the test
			runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);

			// Compare results
			if (isForward) {
				double contextMaxError = (Double) readDMLScalarFromOutputDir("context_error").values().toArray()[0];
				assert contextMaxError < precision;
				double attentionMaxError = (Double) readDMLScalarFromOutputDir("attention_error").values().toArray()[0];
				assert attentionMaxError < precision;
			} else {
				double dqueryMaxError = (Double) readDMLScalarFromOutputDir("dquery_error").values().toArray()[0];
				assert dqueryMaxError < precision;
				double dkeyMaxError = (Double) readDMLScalarFromOutputDir("dkey_error").values().toArray()[0];
				assert dkeyMaxError < precision;
				double dvalueMaxError = (Double) readDMLScalarFromOutputDir("dvalue_error").values().toArray()[0];
				assert dvalueMaxError < precision;
			}
		} catch (Throwable ex) {
			ex.printStackTrace(System.out); // Log or debug all exceptions or errors
			throw new RuntimeException(ex);
		} finally {
			resetExecMode(platformOld);
		}
	}
}
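The DML test scripts (multi_attention_forward / multi_attention_backward) are not included in this excerpt; presumably they run the layer on the CSV inputs and write back a maximum-absolute-error scalar per output, roughly along the lines of the hypothetical sketch below (positional arguments follow the order assembled in runMultiAttentionTest):

# Hypothetical sketch of a forward test script; not the actual script from this commit.
source("nn/layers/multi_attention.dml") as mha  # path is an assumption
T = $2
H = $3
D = $4
Q = read($6, format="csv")
K = read($7, format="csv")
V = read($8, format="csv")
[context, attention, dropout_mask] = mha::forward(Q, K, V, H, T, D, 0.0)
expected_context = read($9, format="csv")
write(max(abs(context - expected_context)), $11)  # scalar read back as "context_error"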
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
0.328471,0.331907,0.339622,0.336846,0.335182,0.327972,0.335271,0.335710,0.329019,0.329675,0.336258,0.334066,0.342494,0.325649,0.331857,0.342283,0.335882,0.321835,0.331431,0.336706,0.331864,0.313857,0.340129,0.346014,0.346816,0.319861,0.333323,0.353853,0.299062,0.347086,0.330171,0.313161,0.356668,0.339396,0.324497,0.336107
0.300318,0.358489,0.341193,0.319460,0.342199,0.338342,0.321149,0.352661,0.326190,0.331912,0.333238,0.334850,0.324741,0.354198,0.321062,0.307494,0.367816,0.324689,0.357409,0.339341,0.303250,0.357259,0.333069,0.309672,0.359344,0.327826,0.312830,0.333608,0.334254,0.332138,0.343666,0.329218,0.327117,0.328783,0.338962,0.332255
