Skip to content

Commit f8d50de

Browse files
committed
Add embedding layer.
1 parent 2e869dc commit f8d50de

File tree

3 files changed

+45
-1
lines changed

3 files changed

+45
-1
lines changed

DeepFried2/layers/Embedding.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import DeepFried2 as df
2+
3+
4+
class Embedding(df.Module):
5+
def __init__(self, ntok, ndim, init=df.init.ortho_svd()):
6+
"""A layer that learns `ntok` embedding vectors of dimension `ndim`.
7+
8+
Note that this doesn't take care of any `unk`/`oov` token.
9+
If you need that, increase `ntok` by one and handle it from the outside.
10+
11+
The input to this layer is an array of indices of any arbitrary shape.
12+
The output will be an array of the same shape plus an added dimension
13+
for the embeddings at the end (as last dimension).
14+
"""
15+
df.Module.__init__(self)
16+
17+
self.ndim = ndim
18+
self.W = self._addparam((ntok, ndim), init, name='Wemb_{}x{}'.format(ntok, ndim))
19+
20+
def symb_forward(self, symb_input):
21+
return self.W.param[symb_input.flatten()].reshape(tuple(symb_input.shape) + (self.ndim,))

DeepFried2/layers/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from .AddConstant import AddConstant
22
from .BatchNormalization import BatchNormalization
3+
from .Bias import Bias
34
from .Dropout import Dropout
5+
from .Embedding import Embedding
46
from .Identity import Identity
57
from .Linear import Linear
6-
from .Bias import Bias
78
from .Log import Log
89
from .ReLU import ReLU
910
from .ELU import ELU
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#!/usr/bin/env python3
2+
3+
import DeepFried2 as df
4+
5+
import unittest
6+
import numpy as np
7+
8+
class TestEmbedding(unittest.TestCase):
9+
10+
def testForward(self):
11+
X = np.array([
12+
[0, 1, 2],
13+
[2, 1, 0],
14+
])
15+
16+
Z = np.array([
17+
[[1,0,0,0], [0,1,0,0], [0,0,1,0]],
18+
[[0,0,1,0], [0,1,0,0], [1,0,0,0]],
19+
])
20+
21+
Y = df.Embedding(ntok=3, ndim=4, init=df.init.eye()).forward(X)
22+
np.testing.assert_array_equal(Y, Z)

0 commit comments

Comments
 (0)