Skip to content

Commit bad0c27

Browse files
committed
add test_lookup_sparse_table_op
1 parent 8d205c8 commit bad0c27

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

python/paddle/fluid/tests/unittests/test_lookup_sparse_table_op.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,33 @@ def check_with_place(self, place):
8080
assert (result_array2[3] == w_array[6]).all()
8181
assert (result_array2[4] == w_array[7]).all()
8282

83+
# create and run lookup_table operator
84+
test_lookup_table = Operator(
85+
"lookup_sparse_table",
86+
W='W',
87+
Ids='Ids',
88+
Out='Out',
89+
min=-5.0,
90+
max=10.0,
91+
seed=10,
92+
is_test=True)
93+
94+
ids = scope.var("Ids").get_tensor()
95+
unknown_id = [44, 22, 33]
96+
ids_array2 = np.array([4, 2, 3, 7, 100000] + unknown_id).astype("int64")
97+
ids.set(ids_array2, place)
98+
test_lookup_table.run(scope, place)
99+
100+
result_array2 = np.array(out_tensor)
101+
assert (result_array2[0] == w_array[5]).all()
102+
assert (result_array2[1] == w_array[1]).all()
103+
assert (result_array2[2] == w_array[2]).all()
104+
assert (result_array2[3] == w_array[6]).all()
105+
assert (result_array2[4] == w_array[7]).all()
106+
107+
for i in [5, 6, 7]:
108+
assert np.all(result_array2[i] == 0)
109+
83110
def test_w_is_selected_rows(self):
84111
places = [core.CPUPlace()]
85112
# currently only support CPU

0 commit comments

Comments
 (0)