Skip to content

Commit 7a63960

Browse files
authored
Release/2.0 beta fix emb (#27150)
* fix weight * fix weight and fix doc * fix embeeding padding idx * add UT * fix interval
1 parent 2986184 commit 7a63960

File tree

4 files changed

+104
-44
lines changed

4 files changed

+104
-44
lines changed

python/paddle/fluid/tests/unittests/test_nn_functional_embedding_dygraph.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,49 @@
1616

1717
import unittest
1818

19+
import paddle
20+
import paddle.nn as nn
21+
import numpy as np
22+
23+
paddle.disable_static()
24+
1925

2026
class EmbeddingDygraph(unittest.TestCase):
2127
def test_1(self):
22-
import paddle
23-
import paddle.nn as nn
24-
import numpy as np
25-
paddle.disable_static()
28+
x_data = np.arange(3, 6).reshape((3, 1)).astype(np.int64)
29+
y_data = np.arange(6, 12).reshape((3, 2)).astype(np.float32)
30+
paddle.disable_static(paddle.CPUPlace())
31+
x = paddle.to_tensor(x_data, stop_gradient=False)
32+
y = paddle.to_tensor(y_data, stop_gradient=False)
33+
34+
embedding = paddle.nn.Embedding(10, 3, sparse=True)
35+
36+
w0 = np.full(shape=(10, 3), fill_value=2).astype(np.float32)
37+
embedding.weight.set_value(w0)
38+
39+
adam = paddle.optimizer.Adam(
40+
parameters=[embedding.weight], learning_rate=0.01)
41+
adam.clear_grad()
42+
43+
out = embedding(x)
44+
out.backward()
45+
adam.step()
46+
47+
def test_2(self):
48+
x_data = np.arange(3, 6).reshape((3, 1)).astype(np.int64)
49+
y_data = np.arange(6, 12).reshape((3, 2)).astype(np.float32)
50+
paddle.disable_static(paddle.CPUPlace())
51+
x = paddle.to_tensor(x_data, stop_gradient=False)
52+
y = paddle.to_tensor(y_data, stop_gradient=False)
53+
54+
with self.assertRaises(ValueError):
55+
embedding = paddle.nn.Embedding(10, 3, padding_idx=11, sparse=True)
2656

27-
# example 1
28-
inp_word = np.array([[2, 3, 5], [4, 2, 1]]).astype('int64')
29-
inp_word.shape # [2, 3]
30-
dict_size = 20
57+
with self.assertRaises(ValueError):
58+
embedding = paddle.nn.Embedding(-1, 3, sparse=True)
3159

32-
emb = nn.Embedding(dict_size, 32, weight_attr='emb.w', sparse=False)
60+
with self.assertRaises(ValueError):
61+
embedding = paddle.nn.Embedding(10, -3, sparse=True)
3362

3463

3564
if __name__ == '__main__':

python/paddle/fluid/tests/unittests/test_nn_functional_embedding_static.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,13 @@ def test_bad_x():
7373
dtype="int32")
7474

7575
emb = functional.embedding(
76-
x=label, weight=weight, sparse=True, name="embedding")
76+
x=label,
77+
weight=weight,
78+
padding_idx=129,
79+
sparse=True,
80+
name="embedding")
7781

82+
with self.assertRaises(ValueError):
7883
test_bad_x()
7984

8085

python/paddle/nn/functional/input.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -113,17 +113,18 @@ def one_hot(x, num_classes, name=None):
113113

114114
def embedding(x, weight, padding_idx=None, sparse=False, name=None):
115115
"""
116-
The operator is used to lookup embeddings vector of ids provided by :attr:`input` .
116+
The operator is used to lookup embeddings vector of ids provided by :attr:`x` .
117117
118118
The shape of output Tensor is generated by appending the last dimension of the input Tensor shape
119119
with embedding size.
120-
**Note:** The id in :attr:`input` must satisfy :math:`0 =< id < weight.shape[0]` ,
120+
121+
**Note:** The id in :attr:`x` must satisfy :math:`0 =< id < weight.shape[0]` ,
121122
otherwise the program will throw an exception and exit.
122123
123124
.. code-block:: text
124125
125126
Case 1:
126-
input is a Tensor.
127+
x is a Tensor.
127128
padding_idx = -1
128129
x.data = [[1, 3], [2, 4], [4, 127]]
129130
x.shape = [3, 2]
@@ -138,7 +139,7 @@ def embedding(x, weight, padding_idx=None, sparse=False, name=None):
138139
[0.0, 0.0, ..., 0.0 ]]] # padding data
139140
140141
The input padding_idx is less than 0, it is automatically converted to padding_idx = -1 + 128 = 127
141-
It will pad all-zero data when ids is 127.
142+
It will pad all-zero data when id is 127.
142143
143144
Args:
144145
x(Tensor): A Tensor with type int32/int64, which contains the id information. The value of the input id should
@@ -151,18 +152,18 @@ def embedding(x, weight, padding_idx=None, sparse=False, name=None):
151152
such as :ref:`api_optimizer_AdadeltaOptimizer` , :ref:`api_optimizer_AdamaxOptimizer` ,
152153
:ref:`api_optimizer_DecayedAdagradOptimizer` , :ref:`api_optimizer_FtrlOptimizer` ,
153154
:ref:`api_optimizer_LambOptimizer` and :ref:`api_optimizer_LarsMomentumOptimizer` .
154-
In these cases, is_sparse must be False. Default: False.
155-
padding_idx(int|long|None): padding_idx needs to be in the interval [-vocab_size, vocab_size).
155+
In these cases, sparse must be False. Default: False.
156+
padding_idx(int|long|None): padding_idx needs to be in the interval [-weight.shape[0], weight.shape[0]).
156157
If :math:`padding\_idx < 0`, the :math:`padding\_idx` will automatically be converted
157-
to :math:`vocab\_size + padding\_idx` . It will output all-zero padding data whenever lookup
158+
to :math:`weight.shape[0] + padding\_idx` . It will output all-zero padding data whenever lookup
158159
encounters :math:`padding\_idx` in id. And the padding data will not be updated while training.
159160
If set None, it makes no effect to output. Default: None.
160161
name(str|None): For detailed information, please refer
161162
to :ref:`api_guide_Name`. Usually name is no need to set and
162163
None by default.
163164
164165
Returns:
165-
Tensor: Embedding Tensor mapped by input. The data type is the same as :attr:`weight`.
166+
Tensor: Embedding Tensor mapped by x. The data type is the same as :attr:`weight`.
166167
167168
Examples:
168169
@@ -209,6 +210,10 @@ def embedding(x, weight, padding_idx=None, sparse=False, name=None):
209210
padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else (
210211
weight.shape[0] + padding_idx)
211212

213+
if padding_idx >= weight.shape[0] or padding_idx < -weight.shape[0]:
214+
raise ValueError("padding_idx must be within [-{}, {})".format(
215+
weight.shape[0], weight.shape[0]))
216+
212217
helper.append_op(
213218
type='lookup_table_v2',
214219
inputs={'Ids': x,

python/paddle/nn/layer/common.py

Lines changed: 47 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1551,22 +1551,18 @@ def forward(self, x1, x2):
15511551

15521552
class Embedding(layers.Layer):
15531553
"""
1554-
:alias_main: paddle.nn.Embedding
1555-
:alias: paddle.nn.Embedding,paddle.nn.layer.Embedding,paddle.nn.layer.common.Embedding
1556-
:old_api: paddle.fluid.dygraph.Embedding
1557-
15581554
**Embedding Layer**
15591555
15601556
This interface is used to construct a callable object of the ``Embedding`` class.
15611557
For specific usage, refer to code examples. It implements the function of the Embedding Layer.
1562-
This layer is used to lookup embeddings vector of ids provided by :attr:`input` .
1558+
This layer is used to lookup embeddings vector of ids provided by :attr:`x` .
15631559
It automatically constructs a 2D embedding matrix based on the
1564-
input :attr:`size` (vocab_size, emb_size) and :attr:`dtype` .
1560+
input :attr:`num_embeddings` and attr:`embedding_dim`.
15651561
15661562
The shape of output Tensor is generated by appending an emb_size dimension to the
15671563
last dimension of the input Tensor shape.
15681564
1569-
**Note:** The id in :attr:`input` must satisfy :math:`0 =< id < size[0]` ,
1565+
**Note:** The id in :attr:`x` must satisfy :math:`0 =< id < num_embeddings` ,
15701566
otherwise the program will throw an exception and exit.
15711567
15721568
.. code-block:: text
@@ -1594,7 +1590,7 @@ class Embedding(layers.Layer):
15941590
num_embeddings (int): Just one element which indicate the size
15951591
of the dictionary of embeddings.
15961592
embedding_dim: Just one element which indicate the size of each embedding vector respectively.
1597-
padding_idx(int|long|None): padding_idx needs to be in the interval [-vocab_size, vocab_size).
1593+
padding_idx(int|long|None): padding_idx needs to be in the interval [-num_embeddings, num_embeddings).
15981594
If :math:`padding\_idx < 0`, the :math:`padding\_idx` will automatically be converted
15991595
to :math:`vocab\_size + padding\_idx` . It will output all-zero padding data whenever lookup
16001596
encounters :math:`padding\_idx` in id. And the padding data will not be updated while training.
@@ -1605,13 +1601,13 @@ class Embedding(layers.Layer):
16051601
such as :ref:`api_optimizer_AdadeltaOptimizer` , :ref:`api_optimizer_AdamaxOptimizer` ,
16061602
:ref:`api_optimizer_DecayedAdagradOptimizer` , :ref:`api_optimizer_FtrlOptimizer` ,
16071603
:ref:`api_optimizer_LambOptimizer` and :ref:`api_optimizer_LarsMomentumOptimizer` .
1608-
In these case, is_sparse must be False. Default: False.
1604+
In these case, sparse must be False. Default: False.
16091605
weight_attr(ParamAttr): To specify the weight parameter property. Default: None, which means the
1610-
default weight parameter property is used. See usage for details in :ref:`api_fluid_ParamAttr` . In addition,
1606+
default weight parameter property is used. See usage for details in :ref:`api_ParamAttr` . In addition,
16111607
user-defined or pre-trained word vectors can be loaded with the :attr:`param_attr` parameter.
16121608
The local word vector needs to be transformed into numpy format, and the shape of local word
1613-
vector should be consistent with :attr:`size` . Then :ref:`api_fluid_initializer_NumpyArrayInitializer`
1614-
is used to load custom or pre-trained word vectors. See code example 2 for details.
1609+
vector should be consistent with :attr:`num_embeddings` . Then :ref:`api_initializer_NumpyArrayInitializer`
1610+
is used to load custom or pre-trained word vectors. See code example for details.
16151611
name(str|None): For detailed information, please refer
16161612
to :ref:`api_guide_Name`. Usually name is no need to set and
16171613
None by default.
@@ -1626,20 +1622,34 @@ class Embedding(layers.Layer):
16261622
16271623
.. code-block:: python
16281624
1629-
import paddle
1630-
import paddle.nn as nn
1631-
import numpy as np
1632-
paddle.disable_static()
1625+
import paddle
1626+
import numpy as np
1627+
1628+
x_data = np.arange(3, 6).reshape((3, 1)).astype(np.int64)
1629+
y_data = np.arange(6, 12).reshape((3, 2)).astype(np.float32)
1630+
paddle.disable_static(paddle.CPUPlace())
1631+
x = paddle.to_tensor(x_data, stop_gradient=False)
1632+
y = paddle.to_tensor(y_data, stop_gradient=False)
1633+
1634+
embedding = paddle.nn.Embedding(10, 3, sparse=True)
1635+
1636+
w0=np.full(shape=(10, 3), fill_value=2).astype(np.float32)
1637+
embedding.weight.set_value(w0)
16331638
1634-
# example 1
1635-
inp_word = np.array([[2, 3, 5], [4, 2, 1]]).astype('int64')
1636-
inp_word.shape # [2, 3]
1637-
dict_size = 20
1639+
adam = paddle.optimizer.Adam(parameters=[embedding.weight], learning_rate=0.01)
1640+
adam.clear_grad()
1641+
1642+
# weight.shape = [10, 3]
1643+
1644+
# x.data = [[3],[4],[5]]
1645+
# x.shape = [3, 1]
1646+
1647+
# out.data = [[2,2,2], [2,2,2], [2,2,2]]
1648+
# out.shape = [3, 1, 3]
1649+
out=embedding(x)
1650+
out.backward()
1651+
adam.step()
16381652
1639-
emb = nn.Embedding(
1640-
dict_size,
1641-
32,
1642-
sparse=False)
16431653
"""
16441654

16451655
def __init__(self,
@@ -1656,13 +1666,24 @@ def __init__(self,
16561666
self._is_distributed = False
16571667
self._padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else (
16581668
num_embeddings + padding_idx)
1669+
1670+
if self._num_embeddings <= 0:
1671+
raise ValueError("num_embeddings must be gather than 0")
1672+
1673+
if self._embedding_dim <= 0:
1674+
raise ValueError("embedding_dim must be gather than 0")
1675+
1676+
if self._padding_idx >= num_embeddings or self._padding_idx < -num_embeddings:
1677+
raise ValueError("padding_idx must be within [-{}, {})".format(
1678+
num_embeddings, num_embeddings))
1679+
16591680
self._dtype = self._helper.get_default_dtype()
16601681
self._size = [self._num_embeddings, self._embedding_dim]
16611682

16621683
self._weight_attr = weight_attr
16631684
self._remote_prefetch = False
16641685
self._name = name
1665-
self._weight = self.create_parameter(
1686+
self.weight = self.create_parameter(
16661687
attr=self._weight_attr,
16671688
shape=self._size,
16681689
dtype=self._dtype,
@@ -1671,7 +1692,7 @@ def __init__(self,
16711692
def forward(self, x):
16721693
return F.embedding(
16731694
x,
1674-
weight=self._weight,
1695+
weight=self.weight,
16751696
padding_idx=self._padding_idx,
16761697
sparse=self._sparse,
16771698
name=self._name)

0 commit comments

Comments
 (0)