Skip to content

Commit b1af7e5

Browse files
committed
Add unittests for lookup_table_op
1 parent 7efdf05 commit b1af7e5

File tree

3 files changed

+76
-14
lines changed

3 files changed

+76
-14
lines changed

paddle/fluid/operators/lookup_table_op.cu

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,10 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
139139

140140
auto *d_table_data = d_table_value->data<T>();
141141
auto *d_output_data = d_output->data<T>();
142-
PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims());
142+
auto d_output_dims = d_output->dims();
143+
PADDLE_ENFORCE_EQ(
144+
d_table_value->dims(),
145+
framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1));
143146
memory::Copy(gpu_place, d_table_data, gpu_place, d_output_data,
144147
d_output->numel() * sizeof(T), stream);
145148

paddle/fluid/operators/lookup_table_op.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,10 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
127127
auto *d_output_data = d_output->data<T>();
128128
auto *d_table_data = d_table_value->data<T>();
129129

130-
PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims());
130+
auto d_output_dims = d_output->dims();
131+
PADDLE_ENFORCE_EQ(
132+
d_table_value->dims(),
133+
framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1));
131134
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
132135
} else {
133136
auto *ids = context.Input<LoDTensor>("Ids");
@@ -137,7 +140,7 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
137140
auto *ids_data = ids->data<int64_t>();
138141

139142
int N = table_dim[0];
140-
int D = d_output->dims()[1];
143+
int D = table_dim[1];
141144

142145
auto *d_output_data = d_output->data<T>();
143146
auto *d_table_data = d_table->mutable_data<T>(context.GetPlace());

python/paddle/fluid/tests/unittests/test_lookup_table_op.py

Lines changed: 67 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,22 @@ def test_check_grad(self):
3535
self.check_grad(['W'], 'Out', no_grad_set=set('Ids'))
3636

3737

38+
class TestLookupTableOpWithTensorIds(OpTest):
39+
def setUp(self):
40+
self.op_type = "lookup_table"
41+
table = np.random.random((17, 31)).astype("float32")
42+
ids = np.random.randint(
43+
low=0, high=17, size=(2, 4, 5, 1)).astype("int64")
44+
self.inputs = {'W': table, 'Ids': ids}
45+
self.outputs = {'Out': table[ids.flatten()].reshape((2, 4, 5, 31))}
46+
47+
def test_check_output(self):
48+
self.check_output()
49+
50+
def test_check_grad(self):
51+
self.check_grad(['W'], 'Out', no_grad_set=set('Ids'))
52+
53+
3854
class TestLookupTableOpWithPadding(TestLookupTableOp):
3955
def test_check_output(self):
4056
ids = np.squeeze(self.inputs['Ids'])
@@ -44,21 +60,34 @@ def test_check_output(self):
4460
self.check_output()
4561

4662
def test_check_grad(self):
47-
# Since paddings are not trainable and fixed in forward, the gradient of
63+
# Since paddings are not trainable and fixed in forward, the gradient of
4864
# paddings makes no sense and we don't test the gradient here.
4965
pass
5066

5167

52-
class TestLookupTableWIsSelectedRows(OpTest):
53-
def check_with_place(self, place):
54-
scope = core.Scope()
68+
class TestLookupTableOpWithTensorIdsAndPadding(TestLookupTableOpWithTensorIds):
69+
def test_check_output(self):
70+
ids = self.inputs['Ids']
71+
flatten_idx = ids.flatten()
72+
padding_idx = np.random.choice(flatten_idx, 1)[0]
73+
self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31)
74+
self.attrs = {'padding_idx': long(padding_idx)}
75+
self.check_output()
76+
77+
def test_check_grad(self):
78+
# Since paddings are not trainable and fixed in forward, the gradient of
79+
# paddings makes no sense and we don't test the gradient here.
80+
pass
5581

56-
# create and initialize Id Variable
82+
83+
class TestLookupTableWIsSelectedRows(OpTest):
84+
def prepare_ids(self, scope, place):
5785
ids_tensor = scope.var('Ids').get_tensor()
5886
ids_array = np.array([[0], [4], [3], [5]]).astype("int64")
5987
ids_tensor.set(ids_array, place)
88+
return ids_array
6089

61-
# create and initialize W Variable
90+
def prepare_w(self, scope, place):
6291
rows = [0, 1, 2, 3, 4, 5, 6]
6392
row_numel = 12
6493

@@ -71,18 +100,31 @@ def check_with_place(self, place):
71100
w_tensor = w_selected_rows.get_tensor()
72101
w_tensor.set(w_array, place)
73102

74-
# create Out Variable
75-
out_tensor = scope.var('Out').get_tensor()
103+
def create_out_tensor(self, scope, place):
104+
return scope.var('Out').get_tensor()
105+
106+
def check_result(self, ids_array, result_array):
107+
# all(): return True if all elements of the iterable are true (or if the iterable is empty)
108+
for idx, row in enumerate(ids_array):
109+
assert (row[0] == result_array[idx]).all()
110+
111+
def check_with_place(self, place):
112+
scope = core.Scope()
113+
114+
ids_array = self.prepare_ids(scope, place)
115+
116+
self.prepare_w(scope, place)
117+
118+
out_tensor = self.create_out_tensor(scope, place)
76119

77120
# create and run lookup_table operator
78121
lookup_table = Operator("lookup_table", W='W', Ids='Ids', Out='Out')
79122
lookup_table.run(scope, place)
80123

81124
# get result from Out
82125
result_array = np.array(out_tensor)
83-
# all(): return True if all elements of the iterable are true (or if the iterable is empty)
84-
for idx, row in enumerate(ids_array):
85-
assert (row[0] == result_array[idx]).all()
126+
127+
self.check_result(ids_array, result_array)
86128

87129
def test_w_is_selected_rows(self):
88130
places = [core.CPUPlace()]
@@ -91,5 +133,19 @@ def test_w_is_selected_rows(self):
91133
self.check_with_place(place)
92134

93135

136+
class TestLookupTableWithTensorIdsWIsSelectedRows(
137+
TestLookupTableWIsSelectedRows):
138+
def prepare_ids(self, scope, place):
139+
ids_tensor = scope.var('Ids').get_tensor()
140+
ids_array = np.random.randint(
141+
low=0, high=6, size=(2, 4, 3, 1)).astype("int64")
142+
ids_tensor.set(ids_array, place)
143+
return ids_array
144+
145+
def check_result(self, ids_array, result_array):
146+
for idx, row in np.ndenumerate(ids_array):
147+
assert (row == result_array[idx]).all()
148+
149+
94150
if __name__ == "__main__":
95151
unittest.main()

0 commit comments

Comments
 (0)