Commit e022eaf

MaximilianSchreff authored and phaniarnab committed

[SYSTEMDS-3829] BERT layer forward pass

This patch introduces the forward pass of the BERT layer from the BERT transformer architecture as a built-in. Closes #2184

1 parent 344ca0b commit e022eaf

42 files changed: +647 -1 lines changed

scripts/nn/layers/bert_layer.dml

Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

source("nn/layers/affine.dml") as affine
source("nn/layers/multi_attention.dml") as attention
source("nn/layers/dropout.dml") as dropout
source("nn/layers/batch_norm1d.dml") as batch_norm
source("nn/layers/tanh.dml") as tanh
source("nn/layers/gelu.dml") as gelu

linear_tensor_forward = function(matrix[double] X, matrix[double] W, matrix[double] b, int B, int C)
    return (matrix[double] out) {
  /*
   * Helper function for computing a linear layer with tensor input, of shape (A, B*C).
   */
  A = nrow(X)
  C_new = ncol(W)
  out = affine::forward(matrix(X, rows=A*B, cols=C), W, b)
  out = matrix(out, rows=A, cols=B*C_new)
}

layer_norm_forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta, double epsilon, int B, int C)
    return (matrix[double] out, matrix[double] cache_mean, matrix[double] cache_var, matrix[double] cache_norm) {
  /*
   * Helper function for computing layer norm via 1D batch norm with tensor input, of shape (A, B*C).
   */
  A = nrow(X)
  batch_norm_input = t(matrix(X, rows=A*B, cols=C))
  # EMA matrices are unused and thus empty matrices will be provided
  emas_mat = matrix(0, rows=1, cols=A*B)
  [batch_norm_out, unused1, unused2, cache_mean, cache_var, cache_norm] = batch_norm::forward(
    batch_norm_input, t(gamma), t(beta), "train", emas_mat, emas_mat, 0.0, epsilon)
  out = matrix(t(batch_norm_out), rows=A, cols=B*C)
}

forward = function(matrix[double] states,
      int H, int T, int d, int I,
      matrix[double] W_Q, matrix[double] b_Q,
      matrix[double] W_K, matrix[double] b_K,
      matrix[double] W_V, matrix[double] b_V,
      matrix[double] W_context, matrix[double] b_context,
      matrix[double] W_intermediate, matrix[double] b_intermediate,
      matrix[double] W_out, matrix[double] b_out,
      double dropout_p_attention,
      double dropout_p_output,
      double epsilon_ln,
      matrix[double] gamma_ln1, matrix[double] beta_ln1,
      matrix[double] gamma_ln2, matrix[double] beta_ln2,
      string activation)
    return (matrix[double] out_states, matrix[double] attention,
      list[unknown] outputs,
      matrix[double] dropout_mask_attention,
      matrix[double] dropout_mask_output_1,
      matrix[double] dropout_mask_output_2,
      matrix[double] cache_mean_ln1, matrix[double] cache_var_ln1, matrix[double] cache_norm_ln1,
      matrix[double] cache_mean_ln2, matrix[double] cache_var_ln2, matrix[double] cache_norm_ln2) {
  /*
   * Computes the forward pass for a layer of the BERT transformer architecture.
   *
   * Inputs (B: Batch size, T: Sequence length, D: Embedding length, H: Heads):
   * - states: Hidden states, of shape (B, T*D).
   * - H: Head count.
   * - T: Sequence length.
   * - d: Embedding length of single token per head with d*H = D.
   * - I: Intermediate embedding length.
   * - W_Q: Weights for linear query layer, of shape (D, D).
   * - b_Q: Biases for linear query layer, of shape (1, D).
   * - W_K: Weights for linear key layer, of shape (D, D).
   * - b_K: Biases for linear key layer, of shape (1, D).
   * - W_V: Weights for linear value layer, of shape (D, D).
   * - b_V: Biases for linear value layer, of shape (1, D).
   * - W_context: Weights for linear output layer on context, of shape (D, D).
   * - b_context: Biases for linear output layer on context, of shape (1, D).
   * - W_intermediate: Weights for intermediate linear layer, of shape (D, I).
   * - b_intermediate: Biases for intermediate linear layer, of shape (1, I).
   * - W_out: Weights for last linear output layer, of shape (I, D).
   * - b_out: Biases for last linear output layer, of shape (1, D).
   * - dropout_p_attention: Probability for dropout on attention.
   * - dropout_p_output: Probability for dropout on output.
   * - epsilon_ln: Epsilon value for layer norm.
   * - gamma_ln1: Gamma params for layer norm 1, of shape (1, D).
   * - beta_ln1: Beta params for layer norm 1, of shape (1, D).
   * - gamma_ln2: Gamma params for layer norm 2, of shape (1, D).
   * - beta_ln2: Beta params for layer norm 2, of shape (1, D).
   * - activation: String specifying type of activation to use.
   *     Can be tanh or gelu.
   *
   * Outputs:
   * - out_states: Token output states, of shape (B, T*D).
   * - attention: Attention values for keys & queries, of shape (B, H*T*T).
   * - outputs: List of relevant outputs for backward pass with following
   *     order/content:
   *   -> 1: Output of linear query layer, of shape (B, T*D).
   *   -> 2: Output of linear key layer, of shape (B, T*D).
   *   -> 3: Output of linear value layer, of shape (B, T*D).
   *   -> 4: Output context of attention layer, of shape (B, T*D).
   *   -> 5: Output attention of attention layer, of shape (B, H*T*T).
   *   -> 6: Output of residual pass 1, of shape (B, T*D).
   *   -> 7: Output of layer norm 1, of shape (B, T*D).
   *   -> 8: Output of intermediate linear layer, of shape (B, T*I).
   *   -> 9: Output of activation layer, of shape (B, T*I).
   *   -> 10: Output of residual pass 2, of shape (B, T*D).
   * - dropout_mask_attention: Dropout mask used on attention, of shape (B, H*T*T).
   * - dropout_mask_output_1: Dropout mask used on attention output, of shape (B, T*D).
   * - dropout_mask_output_2: Dropout mask used on feed-forward output, of shape (B, T*D).
   * - cache_mean_ln1: Cached mean from layer norm 1, of shape (1, B*T).
   * - cache_var_ln1: Cached variance from layer norm 1, of shape (1, B*T).
   * - cache_norm_ln1: Cached normalized values from layer norm 1, of shape (1, B*T).
   * - cache_mean_ln2: Cached mean from layer norm 2, of shape (1, B*T).
   * - cache_var_ln2: Cached variance from layer norm 2, of shape (1, B*T).
   * - cache_norm_ln2: Cached normalized values from layer norm 2, of shape (1, B*T).
   */
  # Embedding dim
  D = d * H

  # Linear layers for Q, K, V
  Q = linear_tensor_forward(states, W_Q, b_Q, T, D)  # Shape (B, T*D)
  K = linear_tensor_forward(states, W_K, b_K, T, D)  # Shape (B, T*D)
  V = linear_tensor_forward(states, W_V, b_V, T, D)  # Shape (B, T*D)

  # Multi-head self attention
  [context, attention, dropout_mask_attention] = attention::forward(Q, K, V, H, T, d, dropout_p_attention)
  # Shapes (B, T*D), (B, H*T*T), (B, H*T*T)
  outputs = list(Q, K, V, context, attention)

  # Linear layer on attention output (output layer)
  out_states = linear_tensor_forward(context, W_context, b_context, T, D)  # Shape (B, T*D)
  # Dropout on output 1
  dropout_mask_output_1 = matrix(0, 1, 1)
  if (dropout_p_output > 0.0) {
    [out_states, dropout_mask_output_1] = dropout::forward(out_states, dropout_p_output, -1)
  }

  # Residual pass 1
  out_states = out_states + states  # Shape (B, T*D)
  outputs = append(outputs, out_states)
  # Layer norm 1 for each token
  [out_states, cache_mean_ln1, cache_var_ln1, cache_norm_ln1] = layer_norm_forward(
    out_states, gamma_ln1, beta_ln1, epsilon_ln, T, D)
  outputs = append(outputs, out_states)

  # Save out_states for residual pass
  out_states_identity = out_states
  # Linear layer of intermediate part
  out_states = linear_tensor_forward(out_states, W_intermediate, b_intermediate, T, D)  # Shape (B, T*I)
  outputs = append(outputs, out_states)
  # Activation
  if (activation == "gelu") {
    out_states = gelu::forward(out_states)
  } else if (activation == "tanh") {
    out_states = tanh::forward(out_states)
  }
  outputs = append(outputs, out_states)

  # Final linear output layer
  out_states = linear_tensor_forward(out_states, W_out, b_out, T, I)  # Shape (B, T*D)
  # Dropout on output 2
  dropout_mask_output_2 = matrix(0, 1, 1)
  if (dropout_p_output > 0.0) {
    [out_states, dropout_mask_output_2] = dropout::forward(out_states, dropout_p_output, -1)
  }
  # Residual pass 2
  out_states = out_states + out_states_identity
  outputs = append(outputs, out_states)
  # Layer norm 2 for each token
  [out_states, cache_mean_ln2, cache_var_ln2, cache_norm_ln2] = layer_norm_forward(
    out_states, gamma_ln2, beta_ln2, epsilon_ln, T, D)
}
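
A minimal usage sketch (not part of this commit): it calls the new layer with toy dimensions and random placeholder weights whose shapes follow the forward() documentation above; the dropout probabilities, epsilon, and all values below are arbitrary stand-ins.

# Hypothetical driver script, for illustration only
source("nn/layers/bert_layer.dml") as bert_layer

B = 2  # batch size
T = 4  # sequence length
H = 2  # heads
d = 3  # per-head embedding length, D = d*H = 6
I = 8  # intermediate embedding length
D = d * H

states = rand(rows=B, cols=T*D)
W_Q = rand(rows=D, cols=D)
b_Q = rand(rows=1, cols=D)
W_K = rand(rows=D, cols=D)
b_K = rand(rows=1, cols=D)
W_V = rand(rows=D, cols=D)
b_V = rand(rows=1, cols=D)
W_context = rand(rows=D, cols=D)
b_context = rand(rows=1, cols=D)
W_intermediate = rand(rows=D, cols=I)
b_intermediate = rand(rows=1, cols=I)
W_out = rand(rows=I, cols=D)
b_out = rand(rows=1, cols=D)
gamma_ln1 = rand(rows=1, cols=D)
beta_ln1 = rand(rows=1, cols=D)
gamma_ln2 = rand(rows=1, cols=D)
beta_ln2 = rand(rows=1, cols=D)

[out_states, attention, outputs, dm_att, dm_out1, dm_out2, cm1, cv1, cn1, cm2, cv2, cn2] = bert_layer::forward(
  states, H, T, d, I, W_Q, b_Q, W_K, b_K, W_V, b_V, W_context, b_context,
  W_intermediate, b_intermediate, W_out, b_out, 0.0, 0.0, 1e-12,
  gamma_ln1, beta_ln1, gamma_ln2, beta_ln2, "gelu")

print("out_states: " + nrow(out_states) + " x " + ncol(out_states))  # 2 x 24, i.e. (B, T*D)
print("attention:  " + nrow(attention) + " x " + ncol(attention))    # 2 x 32, i.e. (B, H*T*T)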
src/test/java/org/apache/sysds/test/applications/nn/transformers/BertLayerTest.java

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysds.test.applications.nn.transformers;

import org.apache.sysds.common.Types;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Test;

public class BertLayerTest extends AutomatedTestBase {
	private static final String TEST_NAME_FORWARD = "bert_layer_forward";
	private static final String TEST_DIR = "applications/nn/component/";
	private static final String RESOURCE_DIR = "src/test/resources/component/transformers/bert_layer/";

	@Override
	public void setUp() {
		TestUtils.clearAssertionInformation();
		addTestConfiguration(TEST_NAME_FORWARD, new TestConfiguration(TEST_DIR, TEST_NAME_FORWARD));
	}

	@Test
	public void testBertLayerForwardNormalTanh() {
		runBertLayerTest("test1", 5, 4, 6, 2, 3, 7, "tanh", 0, TEST_NAME_FORWARD, 1e-5, true);
	}

	@Test
	public void testBertLayerForwardNormalGelu() {
		runBertLayerTest("test2", 4, 4, 8, 2, 4, 7, "gelu", 0, TEST_NAME_FORWARD, 1e-5, true);
	}

	private void runBertLayerTest(String testSuffix, int batchSize, int seqLength, int embeddingDim, int numHeads,
			int perHeadEmbeddingDim, int intermediateEmbeddingDim, String activation, int debug, String testname,
			double precision, boolean isForward) {
		// Set execution platform
		Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE);

		try {
			// Load test configuration
			getAndLoadTestConfiguration(testname);
			fullDMLScriptName = getScript();

			// Program arguments
			if (isForward) {
				programArgs = new String[] {
					"-stats", "-args",
					String.valueOf(debug), String.valueOf(batchSize),
					String.valueOf(seqLength), String.valueOf(embeddingDim),
					String.valueOf(numHeads), String.valueOf(perHeadEmbeddingDim),
					String.valueOf(intermediateEmbeddingDim), activation,
					RESOURCE_DIR + "input_states_" + testSuffix + ".csv",
					RESOURCE_DIR + "input_W_Q_" + testSuffix + ".csv",
					RESOURCE_DIR + "input_b_Q_" + testSuffix + ".csv",
					RESOURCE_DIR + "input_W_K_" + testSuffix + ".csv",
					RESOURCE_DIR + "input_b_K_" + testSuffix + ".csv",
					RESOURCE_DIR + "input_W_V_" + testSuffix + ".csv",
					RESOURCE_DIR + "input_b_V_" + testSuffix + ".csv",
					RESOURCE_DIR + "input_W_context_" + testSuffix + ".csv",
					RESOURCE_DIR + "input_b_context_" + testSuffix + ".csv",
					RESOURCE_DIR + "input_W_intermediate_" + testSuffix + ".csv",
					RESOURCE_DIR + "input_b_intermediate_" + testSuffix + ".csv",
					RESOURCE_DIR + "input_W_out_" + testSuffix + ".csv",
					RESOURCE_DIR + "input_b_out_" + testSuffix + ".csv",
					RESOURCE_DIR + "input_gamma_ln1_" + testSuffix + ".csv",
					RESOURCE_DIR + "input_beta_ln1_" + testSuffix + ".csv",
					RESOURCE_DIR + "input_gamma_ln2_" + testSuffix + ".csv",
					RESOURCE_DIR + "input_beta_ln2_" + testSuffix + ".csv",
					RESOURCE_DIR + "output_states_" + testSuffix + ".csv",
					RESOURCE_DIR + "output_attention_" + testSuffix + ".csv",
					output("states_error"),
					output("attention_error"),
				};
			}

			// Run the test
			runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);

			// Compare results
			if (isForward) {
				double statesMaxError = (Double) readDMLScalarFromOutputDir("states_error").values().toArray()[0];
				assert statesMaxError < precision;
				double attentionMaxError = (Double) readDMLScalarFromOutputDir("attention_error").values().toArray()[0];
				assert attentionMaxError < precision;
			} else {
				double dqueryMaxError = (Double) readDMLScalarFromOutputDir("dquery_error").values().toArray()[0];
				assert dqueryMaxError < precision;
				double dkeyMaxError = (Double) readDMLScalarFromOutputDir("dkey_error").values().toArray()[0];
				assert dkeyMaxError < precision;
				double dvalueMaxError = (Double) readDMLScalarFromOutputDir("dvalue_error").values().toArray()[0];
				assert dvalueMaxError < precision;
			}
		} catch (Throwable ex) {
			ex.printStackTrace(System.out); // Log or debug all exceptions or errors
			throw new RuntimeException(ex);
		} finally {
			resetExecMode(platformOld);
		}
	}
}
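
The DML test script bert_layer_forward.dml resolved by getScript() is not included in this excerpt. As a hedged sketch of the convention these nn component tests appear to follow, the error scalars read back above would be maximum absolute deviations between the computed and expected matrices, written to the output paths passed as program arguments, for example:

# Sketch only, assuming the usual nn component-test pattern; the real script is not shown.
# "output_states_test1.csv" and "states_error" stand in for the positional $-arguments
# that BertLayerTest passes via -args and output("states_error").
expected_states = read("output_states_test1.csv", format="csv")
states_error = max(abs(out_states - expected_states))
write(states_error, "states_error", format="text")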

src/test/java/org/apache/sysds/test/applications/nn/transformers/MultiAttentionLayerTest.java

Lines changed: 1 addition & 1 deletion
@@ -119,7 +119,7 @@ private void runMultiAttentionTest(String testSuffix, int batchSize, int seqLeng
 			if (isForward) {
 				double contextMaxError = (Double) readDMLScalarFromOutputDir("context_error").values().toArray()[0];
 				assert contextMaxError < precision;
-				double attentionMaxError = (Double) readDMLScalarFromOutputDir("context_error").values().toArray()[0];
+				double attentionMaxError = (Double) readDMLScalarFromOutputDir("attention_error").values().toArray()[0];
 				assert attentionMaxError < precision;
 			} else {
 				double dqueryMaxError = (Double) readDMLScalarFromOutputDir("dquery_error").values().toArray()[0];
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
-0.366342,-0.367447,-0.368428,0.043807,-0.098173,-0.287969
-0.352497,-0.027552,0.258298,-0.085474,-0.085857,0.018208
-0.063882,0.359021,-0.047110,0.291535,-0.336430,-0.287788
0.005280,-0.166521,-0.182245,0.113960,0.221207,-0.224734
-0.185457,0.368649,0.326457,0.196166,0.324140,-0.237889
0.153787,0.147849,-0.329904,0.144177,0.279334,0.139517
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
0.014873,0.127848,-0.276551,-0.306393,0.044137,0.116741,0.004873,-0.350424
0.227909,0.044769,-0.185308,0.175143,0.316675,0.265246,-0.060110,0.159592
-0.267258,-0.002632,0.285492,-0.251829,0.216273,-0.113814,-0.186207,-0.169799
-0.242719,-0.069891,-0.286925,-0.100361,-0.223521,0.000566,0.046730,-0.235940
-0.205295,0.044359,-0.025387,-0.118623,0.158570,0.182018,0.292360,-0.203683
0.247464,-0.080732,0.349749,-0.052357,-0.249925,-0.341919,-0.103351,0.203278
-0.127090,-0.002484,0.127717,0.003867,-0.149845,0.255612,-0.209903,0.187233
0.298218,0.045111,0.010010,0.291613,0.103988,-0.292361,-0.130758,0.271360
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
0.160208,-0.102197,-0.265960,0.082622,-0.366738,-0.382060
0.024454,-0.198863,-0.020951,0.259580,-0.193541,-0.344565
-0.199196,-0.142819,0.292245,0.386712,0.277978,-0.082808
0.193179,-0.334609,-0.041968,0.259260,-0.002646,0.223886
-0.391612,-0.086841,0.011346,0.387596,-0.202918,0.220716
-0.241971,0.087266,-0.035219,-0.029525,-0.312845,-0.393728
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
-0.123685,0.009826,-0.317605,-0.071714,-0.068100,-0.318219,-0.040798,0.169884
-0.289780,-0.030501,-0.167612,0.193891,-0.069418,-0.023860,-0.157829,0.124861
-0.075206,0.071553,0.240736,0.191146,-0.317261,0.310922,0.282720,-0.085021
0.075574,0.224803,-0.002292,-0.340978,-0.305271,-0.144212,-0.285705,-0.074354
-0.230328,0.334902,-0.175732,0.220540,-0.055324,0.319260,0.037938,-0.291357
-0.018144,0.224526,-0.270932,-0.276659,0.004572,0.128041,-0.074023,0.191571
0.253091,0.335668,-0.330874,-0.074745,-0.160610,-0.319068,0.252477,0.280714
-0.036345,-0.025570,-0.298402,-0.143356,0.133183,0.223692,0.098693,0.241910
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
-0.237054,-0.003039,-0.319334,0.147474,-0.136974,0.249731
0.285747,-0.080703,-0.213976,0.011559,-0.060456,-0.258100
-0.146751,0.051221,0.329658,-0.353792,0.004466,0.183101
0.344353,-0.093221,-0.331313,0.202237,0.336726,-0.288589
0.147626,-0.002869,-0.029315,-0.290787,0.050965,-0.173026
0.051695,0.052090,0.403855,-0.115887,0.365665,0.120075
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
-0.334274,0.011264,-0.080849,0.244498,0.069431,0.122792,0.029533,0.165114
0.076385,-0.237900,-0.015723,-0.263194,-0.262510,-0.129004,0.044147,-0.171997
-0.198408,-0.285785,-0.215330,0.144839,0.058866,0.134202,-0.277945,-0.292986
-0.315220,0.281811,0.119572,-0.118884,0.150589,0.235453,0.027785,-0.304028
0.310023,0.057572,0.111782,-0.170578,0.139947,-0.184608,0.244825,0.352708
-0.229602,0.293317,-0.007293,0.063514,-0.044505,0.003487,0.318592,0.224432
-0.040221,-0.118525,-0.079515,-0.183656,-0.289839,0.146194,0.207801,-0.244388
0.101291,0.104141,-0.217941,0.081460,-0.054502,0.027711,0.047377,0.138325
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
0.295156,0.337588,-0.196067,0.148081,-0.193161,0.357983
-0.337590,-0.119339,-0.272441,-0.136338,-0.191908,-0.265121
0.005627,-0.242375,-0.235192,-0.114084,-0.385986,-0.046443
-0.069409,-0.150986,0.234725,0.120609,0.088201,0.116961
-0.215013,-0.404635,0.216198,0.335594,-0.229102,0.013006
0.053959,0.184281,0.313339,0.111000,-0.363984,-0.274703
