Skip to content

Commit a43eee4

Browse files
committed
follow comments
1 parent 92e2207 commit a43eee4

File tree

1 file changed

+14
-15
lines changed

1 file changed

+14
-15
lines changed

python/paddle/fluid/tests/unittests/test_lookup_table_op.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -53,39 +53,38 @@ class TestLookupTableIdsIsSelectedRows(OpTest):
5353
def check_with_place(self, place):
5454
scope = core.Scope()
5555

56-
# create and initialize Grad Variable
56+
# create and initialize Variable
5757
height = 10
5858
rows = [0, 4, 4, 7]
5959
row_numel = 12
6060

61-
ids_selected_rows = scope.var('Ids').get_selected_rows()
62-
ids_selected_rows.set_height(height)
63-
ids_selected_rows.set_rows(rows)
64-
np_array = np.ones((len(rows), row_numel)).astype("float32")
65-
ids_tensor = ids_selected_rows.get_tensor()
66-
ids_tensor.set(np_array, place)
67-
6861
# create and initialize W Variable
6962
W = scope.var('W').get_tensor()
7063
W_array = np.full((height, row_numel), 1.0).astype("float32")
7164
for i in range(height):
7265
W_array[i] *= i
7366
W.set(W_array, place)
7467

68+
# create and initialize Ids Variable
69+
ids_selected_rows = scope.var('Ids').get_selected_rows()
70+
ids_selected_rows.set_height(len(rows))
71+
ids_selected_rows.set_rows(rows)
72+
np_array = np.ones((len(rows), row_numel)).astype("float32")
73+
ids_tensor = ids_selected_rows.get_tensor()
74+
ids_tensor.set(np_array, place)
75+
76+
# create Out Variable
7577
Out = scope.var('Out').get_selected_rows()
76-
Out_array = np.full((len(rows), row_numel), -1.0).astype("float32")
77-
Out.set_height(height)
78-
Out.set_rows(rows)
79-
Out_tensor = Out.get_tensor()
80-
Out_tensor.set(Out_array, place)
8178

82-
# create and run concat_rows_op operator
79+
# create and run lookup_table operator
8380
concat_rows_op = Operator("lookup_table", W='W', Ids='Ids', Out='Out')
8481
concat_rows_op.run(scope, place)
8582

86-
# get and compare result
83+
# get result from Out
84+
Out_tensor = Out.get_tensor()
8785
result_array = np.array(Out_tensor)
8886

87+
# all(): return True if all elements of the iterable are true (or if the iterable is empty)
8988
for idx, row in enumerate(rows):
9089
assert (row == result_array[idx]).all()
9190

0 commit comments

Comments
 (0)