@@ -37,6 +37,17 @@ linear_tensor_forward = function(matrix[double] X, matrix[double] W, matrix[doub
3737 out = matrix(out, rows=A, cols=B*C_new)
3838}
3939
40+ linear_tensor_backward = function(matrix[double] dout, matrix[double] X, matrix[double] W, matrix[double] b, int B,
41+ int C_out, int C_in)
42+ return (matrix[double] dX, matrix[double] dW, matrix[double] db) {
43+ /*
44+ * Helper function for computing the backward pass of a linear layer with tensor input, of shape (A, B*C_in)
45+ */
46+ A = nrow(X)
47+ [dX, dW, db] = affine::backward(matrix(dout, rows=A*B, cols=C_out), matrix(X, rows=A*B, cols=C_in), W, b)
48+ dX = matrix(dX, rows=A, cols=B*C_in)
49+ }
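# Shape flow of the backward helper above: dout (A, B*C_out) and X (A, B*C_in) are viewed as
# (A*B, C_out) and (A*B, C_in), so affine::backward aggregates dW (C_in, C_out) and db (1, C_out)
# over both the A and B dimensions; only dX has to be reshaped back to the tensor layout (A, B*C_in).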
50+
4051layer_norm_forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta, double epsilon, int B, int C)
4152 return (matrix[double] out, matrix[double] cache_mean, matrix[double] cache_var, matrix[double] cache_norm) {
4253 /*
@@ -51,6 +62,27 @@ layer_norm_forward = function(matrix[double] X, matrix[double] gamma, matrix[dou
5162 out = matrix(t(batch_norm_out), rows=A, cols=B*C)
5263}
5364
65+ layer_norm_backward = function(matrix[double] dout, matrix[double] cache_mean, matrix[double] cache_var,
66+ matrix[double] cache_norm, matrix[double] X, matrix[double] gamma, matrix[double] beta, double epsilon, int B, int C)
67+ return (matrix[double] dX, matrix[double] dgamma, matrix[double] dbeta) {
68+ /*
69+ * Helper function for computing the backward pass of layer norm, implemented via 1D batch norm, with tensor input of shape (A, B*C)
70+ */
71+ A = nrow(X)
72+ batch_norm_input = t(matrix(X, rows=A*B, cols=C))
73+ batch_norm_doutput = t(matrix(dout, rows=A*B, cols=C))
74+ # EMA matrices, updated EMA matrices and the out matrix are unused, so empty 1x1 placeholders are passed; the dgamma/dbeta returned by batch_norm::backward refer to the transposed per-token view and are discarded, with the per-feature gradients computed directly below
75+ empty_mat = matrix(0, rows=1, cols=1)
76+ [batch_norm_dX, unused1, unused2] = batch_norm::backward(
77+ batch_norm_doutput,
78+ empty_mat, empty_mat, empty_mat,
79+ cache_mean, cache_var, cache_norm,
80+ batch_norm_input, t(gamma), t(beta), "train", empty_mat, empty_mat, 0.0, epsilon)
81+ dX = matrix(t(batch_norm_dX), rows=A, cols=B*C)
82+ dgamma = t(rowSums(batch_norm_doutput * cache_norm))
83+ dbeta = t(rowSums(batch_norm_doutput))
84+ }
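# A minimal finite-difference check for the two layer-norm helpers above; an illustrative
# sketch for a separate driver script, not part of this layer (the source() path and the
# alias `bert` are assumptions):
#
#   source("nn/layers/bert_layer.dml") as bert
#   A = 2; B = 3; C = 4; eps = 1e-5; h = 1e-5
#   X = rand(rows=A, cols=B*C)
#   gamma = rand(rows=1, cols=C)
#   beta = rand(rows=1, cols=C)
#   dout = rand(rows=A, cols=B*C)
#   [out, cm, cv, cn] = bert::layer_norm_forward(X, gamma, beta, eps, B, C)
#   [dX, dgamma, dbeta] = bert::layer_norm_backward(dout, cm, cv, cn, X, gamma, beta, eps, B, C)
#   # perturb a single cell and compare the analytic gradient with the centered difference
#   Xp = X; Xp[1, 1] = Xp[1, 1] + h
#   Xm = X; Xm[1, 1] = Xm[1, 1] - h
#   [outp, m1, v1, n1] = bert::layer_norm_forward(Xp, gamma, beta, eps, B, C)
#   [outm, m2, v2, n2] = bert::layer_norm_forward(Xm, gamma, beta, eps, B, C)
#   print("dX[1,1] analytic: " + as.scalar(dX[1, 1]) + ", numeric: " + (sum(dout * (outp - outm)) / (2 * h)))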
85+
5486forward = function(matrix[double] states,
5587 int H, int T, int d, int I,
5688 matrix[double] W_Q, matrix[double] b_Q,
@@ -184,3 +216,153 @@ forward = function(matrix[double] states,
184216 [out_states, cache_mean_ln2, cache_var_ln2, cache_norm_ln2] = layer_norm_forward(
185217 out_states, gamma_ln2, beta_ln2, epsilon_ln, T, D)
186218}
219+
220+ backward = function(matrix[double] dout_states,
221+ matrix[double] dropout_mask_attention,
222+ matrix[double] dropout_mask_output_1,
223+ matrix[double] dropout_mask_output_2,
224+ matrix[double] cache_mean_ln1, matrix[double] cache_var_ln1, matrix[double] cache_norm_ln1,
225+ matrix[double] cache_mean_ln2, matrix[double] cache_var_ln2, matrix[double] cache_norm_ln2,
226+ list[unknown] outputs,
227+ matrix[double] states,
228+ int H, int T, int d, int I,
229+ matrix[double] W_Q, matrix[double] b_Q,
230+ matrix[double] W_K, matrix[double] b_K,
231+ matrix[double] W_V, matrix[double] b_V,
232+ matrix[double] W_context, matrix[double] b_context,
233+ matrix[double] W_intermediate, matrix[double] b_intermediate,
234+ matrix[double] W_out, matrix[double] b_out,
235+ double dropout_p_attention,
236+ double dropout_p_output,
237+ double epsilon_ln,
238+ matrix[double] gamma_ln1, matrix[double] beta_ln1,
239+ matrix[double] gamma_ln2, matrix[double] beta_ln2,
240+ string activation)
241+ return (matrix[double] din_states,
242+ matrix[double] dW_Q, matrix[double] db_Q,
243+ matrix[double] dW_K, matrix[double] db_K,
244+ matrix[double] dW_V, matrix[double] db_V,
245+ matrix[double] dW_context, matrix[double] db_context,
246+ matrix[double] dW_intermediate, matrix[double] db_intermediate,
247+ matrix[double] dW_out, matrix[double] db_out,
248+ matrix[double] dgamma_ln1, matrix[double] dbeta_ln1,
249+ matrix[double] dgamma_ln2, matrix[double] dbeta_ln2) {
250+ /*
251+ * Computes the backward pass for a layer of the BERT transformer architecture.
252+ *
253+ * Inputs (B: Batch size, T: Sequence length, D: Embedding length, H: Heads):
254+ * - dout_states: Gradients w.r.t. output states, of shape (B, T*D)
255+ * - dropout_mask_attention: Dropout mask used on attention, of shape (B, H*T*T)
256+ * - dropout_mask_output_1: Dropout mask used on attention output, of shape (B, T*D)
257+ * - dropout_mask_output_2: Dropout mask used on the output of the final linear layer, of shape (B, T*D)
258+ * - cache_mean_ln1: Cached mean from layer norm 1, of shape (1, B*T)
259+ * - cache_var_ln1: Cached variance from layer norm 1, of shape (1, B*T)
260+ * - cache_norm_ln1: Cached normalized input from layer norm 1, of shape (D, B*T)
261+ * - cache_mean_ln2: Cached mean from layer norm 2, of shape (1, B*T)
262+ * - cache_var_ln2: Cached variance from layer norm 2, of shape (1, B*T)
263+ * - cache_norm_ln2: Cached normalized input from layer norm 2, of shape (D, B*T)
264+ * - outputs: list of relevant outputs from forward pass
265+ * with the following order/content:
266+ * -> 1: Output of linear query layer, of shape (B, T*D).
267+ * -> 2: Output of linear key layer, of shape (B, T*D).
268+ * -> 3: Output of linear value layer, of shape (B, T*D).
269+ * -> 4: Output context of attention layer, of shape (B, T*D).
270+ * -> 5: Output attention of attention layer, of shape (B, H*T*T).
271+ * -> 6: Output of residual pass 1, of shape (B, T*D).
272+ * -> 7: Output of layer norm 1, of shape (B, T*D).
273+ * -> 8: Output of intermediate linear layer, of shape (B, T*I).
274+ * -> 9: Output of activation layer, of shape (B, T*I).
275+ * -> 10: Output of residual pass 2, of shape (B, T*D).
276+ * - states: Hidden states, of shape (B, T*D).
277+ * - H: Head count.
278+ * - T: Sequence length.
279+ * - d: Embedding length of a single token per head, with d*H = D.
280+ * - I: Intermediate embedding length.
281+ * - W_Q: Weights for linear query layer, of shape (D, D).
282+ * - b_Q: Biases for linear query layer, of shape (1, D).
283+ * - W_K: Weights for linear key layer, of shape (D, D).
284+ * - b_K: Biases for linear key layer, of shape (1, D).
285+ * - W_V: Weights for linear value layer, of shape (D, D).
286+ * - b_V: Biases for linear value layer, of shape (1, D).
287+ * - W_context: Weights for linear output layer on context, of shape (D, D).
288+ * - b_context: Biases for linear output layer on context, of shape (1, D).
289+ * - W_intermediate: Weights for intermediate linear layer, of shape (D, I).
290+ * - b_intermediate: Biases for intermediate linear layer, of shape (1, I).
291+ * - W_out: Weights for last linear output layer, of shape (I, D).
292+ * - b_out: Biases for last linear output layer, of shape (1, D).
293+ * - dropout_p_attention: Probability for dropout on attention.
294+ * - dropout_p_output: Probability for dropout on output.
295+ * - epsilon_ln: Epsilon value for layer norm.
296+ * - gamma_ln1: Gamma params for layer norm 1, of shape (1, D).
297+ * - beta_ln1: Beta params for layer norm 1, of shape (1, D).
298+ * - gamma_ln2: Gamma params for layer norm 2, of shape (1, D).
299+ * - beta_ln2: Beta params for layer norm 2, of shape (1, D).
300+ * - activation: String specifying type of activation to use.
301+ * Can be tanh or gelu.
302+ *
303+ * Outputs:
304+ * - din_states: Gradients w.r.t. hidden input states, of shape (B, T*D).
305+ * - dW_Q: Gradients w.r.t. weights for linear query layer, of shape (D, D).
306+ * - db_Q: Gradients w.r.t. biases for linear query layer, of shape (1, D).
307+ * - dW_K: Gradients w.r.t. weights for linear key layer, of shape (D, D).
308+ * - db_K: Gradients w.r.t. biases for linear key layer, of shape (1, D).
309+ * - dW_V: Gradients w.r.t. weights for linear value layer, of shape (D, D).
310+ * - db_V: Gradients w.r.t. biases for linear value layer, of shape (1, D).
311+ * - dW_context: Gradients w.r.t. weights for linear output layer on context, of shape (D, D).
312+ * - db_context: Gradients w.r.t. biases for linear output layer on context, of shape (1, D).
313+ * - dW_intermediate: Gradients w.r.t. weights for intermediate linear layer, of shape (D, I).
314+ * - db_intermediate: Gradients w.r.t. biases for intermediate linear layer, of shape (1, I).
315+ * - dW_out: Gradients w.r.t. weights for last linear output layer, of shape (I, D).
316+ * - db_out: Gradients w.r.t. biases for last linear output layer, of shape (1, D).
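 * - dgamma_ln1: Gradients w.r.t. gamma params for layer norm 1, of shape (1, D).
 * - dbeta_ln1: Gradients w.r.t. beta params for layer norm 1, of shape (1, D).
 * - dgamma_ln2: Gradients w.r.t. gamma params for layer norm 2, of shape (1, D).
 * - dbeta_ln2: Gradients w.r.t. beta params for layer norm 2, of shape (1, D).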
317+ */
318+ # Embedding dim
319+ D = d * H
320+
321+ # Layer norm 2 for each token
322+ [dout_states, dgamma_ln2, dbeta_ln2] = layer_norm_backward(
323+ dout_states, cache_mean_ln2, cache_var_ln2, cache_norm_ln2, as.matrix(outputs[10]), gamma_ln2, beta_ln2, epsilon_ln, T, D)
324+ # Save dout_states for residual pass
325+ dout_states_identity_2 = dout_states
326+ # Dropout on output 2
327+ if (dropout_p_output > 0.0) {
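    # A 1x1 placeholder is passed in place of the layer input (empty-matrix convention as above); the stored dropout mask together with dout and p determines the gradient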
328+ dout_states = dropout::backward(dout_states, matrix(0, 1, 1), dropout_p_output, dropout_mask_output_2)
329+ }
330+ # Final linear output layer
331+ [dout_states, dW_out, db_out] = linear_tensor_backward(dout_states, as.matrix(outputs[9]), W_out, b_out, T, D, I)
332+
333+ # Activation
334+ if (activation == "gelu") {
335+ dout_states = gelu::backward(dout_states, as.matrix(outputs[8]))
336+ } else if (activation == "tanh") {
337+ dout_states = tanh::backward(dout_states, as.matrix(outputs[8]))
338+ }
339+ # Linear layer of intermediate part
340+ [dout_states, dW_intermediate, db_intermediate] = linear_tensor_backward(dout_states, as.matrix(outputs[7]), W_intermediate,
341+ b_intermediate, T, I, D)
342+ # Residual pass 2
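  # (the identity branch of the residual sum passes its gradient through unchanged, so the
  #  gradient saved before the feed-forward block is simply added back)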
343+ dout_states = dout_states + dout_states_identity_2
344+
345+ # Layer norm 1 for each token
346+ [dout_states, dgamma_ln1, dbeta_ln1] = layer_norm_backward(
347+ dout_states, cache_mean_ln1, cache_var_ln1, cache_norm_ln1, as.matrix(outputs[6]), gamma_ln1, beta_ln1, epsilon_ln, T, D)
348+ # Save dout_states for residual pass
349+ dout_states_identity_1 = dout_states
350+
351+ # Dropout on output 1
352+ if (dropout_p_output > 0.0) {
353+ dout_states = dropout::backward(dout_states, matrix(0, 1, 1), dropout_p_output, dropout_mask_output_1)
354+ }
355+ # Linear layer on attention output (output layer)
356+ [dcontext, dW_context, db_context] = linear_tensor_backward(dout_states, as.matrix(outputs[4]), W_context, b_context, T, D, D)
357+
358+ # Multi-head self attention
359+ [dQ, dK, dV] = attention::backward(dcontext, dropout_mask_attention, as.matrix(outputs[5]), as.matrix(outputs[1]),
360+ as.matrix(outputs[2]), as.matrix(outputs[3]), H, T, d, dropout_p_attention)
361+
362+ # Linear layers for Q, K, V
363+ [dstates_Q, dW_Q, db_Q] = linear_tensor_backward(dQ, states, W_Q, b_Q, T, D, D)
364+ [dstates_K, dW_K, db_K] = linear_tensor_backward(dK, states, W_K, b_K, T, D, D)
365+ [dstates_V, dW_V, db_V] = linear_tensor_backward(dV, states, W_V, b_V, T, D, D)
366+ # Add paths + residual pass 1
367+ din_states = dstates_Q + dstates_K + dstates_V + dout_states_identity_1
368+ }