Skip to content

Commit 1d7b500

Browse files
Backward pass
1 parent c8d9969 commit 1d7b500

File tree

1 file changed

+36
-5
lines changed

1 file changed

+36
-5
lines changed

scripts/nn/layers/embedding.dml

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,50 @@ forward = function(matrix[double] indices, matrix[double] embedding_dict)
3636
* embedding_dict[indices[i]].
3737
*/
3838
n = nrow(indices)
39-
d = ncol(embedding_dict)
39+
v = nrow(embedding_dict)
4040

4141
# Construct permutation-like matrix (one '1' per row, rest '0')
42-
permutation = matrix(0, rows=n, cols=d)
42+
permutation = matrix(0, rows=n, cols=v)
4343
for (i in 1:n) {
4444
permutation[i, as.integer(as.scalar(indices[i]))] = 1
4545
}
4646

4747
embeddings = permutation %*% embedding_dict
4848
}
4949

50-
backward = function()
51-
return () {}
50+
backward = function(matrix[double] dout, matrix[double] indices, int v,
51+
int padding_idx = -1)
52+
return (matrix[double] dembedding_dict) {
53+
/*
54+
* Backward pass of embedding layer computes the gradients of the embedding
55+
* dictionary.
56+
*
57+
* Inputs:
58+
* - dout: Gradient of the output.
59+
* - indices: Indices referring to embedding vectors of embedding dictionary
60+
* of shape n x 1 with each value in {1, ..., v}.
61+
* - v: Embedding dictionary size.
62+
* - padding_idx: Index of embedding vector of embedding dictionary which
63+
* should not be updated (i.e. gradients are 0). Use -1 if
64+
* there is no padding vector.
65+
*
66+
* Outputs:
67+
* - dembedding_dict: Gradients of the dictionary of embedding vectors of
68+
* shape v x d.
69+
*/
70+
n = nrow(indices)
71+
72+
# Construct permutation-like matrix (one '1' per row, rest '0')
73+
permutation = matrix(0, rows=n, cols=v)
74+
for (i in 1:n) {
75+
permutation[i, as.integer(as.scalar(indices[i]))] = 1
76+
}
77+
78+
dembedding_dict = t(permutation) %*% dout
79+
if (padding_idx != -1) {
80+
dembedding_dict[padding_idx] = matrix(0, rows=1, cols=ncol(dout))
81+
}
82+
}
5283

5384
init = function(int v, int d, int seed = -1)
5485
return (matrix[double] embedding_dict) {
@@ -58,7 +89,7 @@ init = function(int v, int d, int seed = -1)
5889
* Inputs:
5990
* - v: Embedding dictionary size.
6091
* - d: Embedding vector dimension.
61-
* - seed: Optional random generation seed.
92+
* - seed: Random generation seed.
6293
*
6394
* Output:
6495
* - embedding_dict: Embedding dictionary matrix of shape v x d.

0 commit comments

Comments
 (0)