@@ -36,19 +36,50 @@ forward = function(matrix[double] indices, matrix[double] embedding_dict)
3636 * embedding_dict[indices[i]].
3737 */
3838 n = nrow(indices)
39- d = ncol (embedding_dict)
39+ v = nrow (embedding_dict)
4040
4141 # Construct permutation-like matrix (one '1' per row, rest '0')
42- permutation = matrix(0, rows=n, cols=d )
42+ permutation = matrix(0, rows=n, cols=v )
4343 for (i in 1:n) {
4444 permutation[i, as.integer(as.scalar(indices[i]))] = 1
4545 }
4646
4747 embeddings = permutation %*% embedding_dict
4848}
4949
50- backward = function()
51- return () {}
backward = function(matrix[double] dout, matrix[double] indices, int v,
    int padding_idx = -1)
    return (matrix[double] dembedding_dict) {
  /*
   * Backward pass of the embedding layer. Computes the gradient of the
   * embedding dictionary: row j of the result is the sum of all rows i of
   * dout for which indices[i] == j.
   *
   * Inputs:
   *  - dout: Gradient of the output, of shape n x d.
   *  - indices: Indices referring to embedding vectors of the embedding
   *      dictionary, of shape n x 1 with each value in {1, ..., v}.
   *  - v: Embedding dictionary size.
   *  - padding_idx: Index of the embedding vector of the embedding dictionary
   *      which should not be updated (i.e. its gradients are 0). Use -1 if
   *      there is no padding vector.
   *
   * Outputs:
   *  - dembedding_dict: Gradients of the dictionary of embedding vectors, of
   *      shape v x d.
   */
  n = nrow(indices)

  # table() silently drops cells outside the requested output dimensions, so
  # reject out-of-range indices explicitly instead of losing their gradients.
  if (min(indices) < 1 | max(indices) > v) {
    stop("embedding::backward: all indices must lie in {1, ..., v}")
  }

  # Construct permutation-like matrix of shape n x v (one '1' per row, rest
  # '0') with a single vectorized contingency-table call rather than n scalar
  # left-indexing assignments. Assumes integer-valued indices.
  permutation = table(seq(1, n), indices, n, v)

  # Scatter-add the output gradients onto the dictionary rows they came from.
  dembedding_dict = t(permutation) %*% dout

  if (padding_idx != -1) {
    # The padding vector must not be updated: zero out its gradient row.
    dembedding_dict[padding_idx,] = matrix(0, rows=1, cols=ncol(dout))
  }
}
5283
5384init = function(int v, int d, int seed = -1)
5485 return (matrix[double] embedding_dict) {
@@ -58,7 +89,7 @@ init = function(int v, int d, int seed = -1)
5889 * Inputs:
5990 * - v: Embedding dictionary size.
6091 * - d: Embedding vector dimension.
61- * - seed: Optional random generation seed.
92+ * - seed: Random generation seed.
6293 *
6394 * Output:
6495 * - embedding_dict: Embedding dictionary matrix of shape v x d.
0 commit comments