Skip to content

Commit c8d9969

Browse files
Forward pass
1 parent 5b541c9 commit c8d9969

File tree

1 file changed

+29
-4
lines changed

1 file changed

+29
-4
lines changed

scripts/nn/layers/embedding.dml

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,24 +19,49 @@
1919
#
2020
#-------------------------------------------------------------
2121

22-
forward = function()
23-
return () {}
22+
forward = function(matrix[double] indices, matrix[double] embedding_dict)
23+
return (matrix[double] embeddings) {
24+
/*
25+
* Forward pass of an embedding layer. An embedding matrix is constructed
26+
* from indices and corresponding embedding vectors from the embedding
27+
* dictionary.
28+
*
29+
* Inputs:
30+
* - indices: Indices referring to embedding vectors of embedding dictionary
31+
* of shape n x 1 with each value in {1, ..., v}.
32+
* - embedding_dict: Dictionary of embedding vectors of shape v x d.
33+
*
34+
* Outputs:
35+
* - embeddings: Embedding matrix where row i is equal to
36+
* embedding_dict[indices[i]].
37+
*/
38+
n = nrow(indices)
39+
d = ncol(embedding_dict)
40+
41+
# Construct permutation-like matrix (one '1' per row, rest '0')
42+
permutation = matrix(0, rows=n, cols=d)
43+
for (i in 1:n) {
44+
permutation[i, as.integer(as.scalar(indices[i]))] = 1
45+
}
46+
47+
embeddings = permutation %*% embedding_dict
48+
}
2449

2550
backward = function()
2651
return () {}
2752

2853
init = function(int v, int d, int seed = -1)
2954
return (matrix[double] embedding_dict) {
3055
/*
31-
* Initializes embedding dictionary matrix via N(0, 1)
56+
* Initializes embedding dictionary matrix via N(0, 1).
3257
*
3358
* Inputs:
3459
* - v: Embedding dictionary size.
3560
* - d: Embedding vector dimension.
3661
* - seed: Optional random generation seed.
3762
*
3863
* Output:
39-
* - embedding_dict: Embedding dictionary matrix of shape v x d
64+
* - embedding_dict: Embedding dictionary matrix of shape v x d.
4065
*/
4166
embedding_dict = rand(rows=v, cols=d, pdf="normal", seed=seed)
4267
}

0 commit comments

Comments
 (0)