@@ -53,39 +53,38 @@ class TestLookupTableIdsIsSelectedRows(OpTest):
53
53
def check_with_place (self , place ):
54
54
scope = core .Scope ()
55
55
56
- # create and initialize Grad Variable
56
+ # create and initialize Variable
57
57
height = 10
58
58
rows = [0 , 4 , 4 , 7 ]
59
59
row_numel = 12
60
60
61
- ids_selected_rows = scope .var ('Ids' ).get_selected_rows ()
62
- ids_selected_rows .set_height (height )
63
- ids_selected_rows .set_rows (rows )
64
- np_array = np .ones ((len (rows ), row_numel )).astype ("float32" )
65
- ids_tensor = ids_selected_rows .get_tensor ()
66
- ids_tensor .set (np_array , place )
67
-
68
61
# create and initialize W Variable
69
62
W = scope .var ('W' ).get_tensor ()
70
63
W_array = np .full ((height , row_numel ), 1.0 ).astype ("float32" )
71
64
for i in range (height ):
72
65
W_array [i ] *= i
73
66
W .set (W_array , place )
74
67
68
+ # create and initialize Ids Variable
69
+ ids_selected_rows = scope .var ('Ids' ).get_selected_rows ()
70
+ ids_selected_rows .set_height (len (rows ))
71
+ ids_selected_rows .set_rows (rows )
72
+ np_array = np .ones ((len (rows ), row_numel )).astype ("float32" )
73
+ ids_tensor = ids_selected_rows .get_tensor ()
74
+ ids_tensor .set (np_array , place )
75
+
76
+ # create Out Variable
75
77
Out = scope .var ('Out' ).get_selected_rows ()
76
- Out_array = np .full ((len (rows ), row_numel ), - 1.0 ).astype ("float32" )
77
- Out .set_height (height )
78
- Out .set_rows (rows )
79
- Out_tensor = Out .get_tensor ()
80
- Out_tensor .set (Out_array , place )
81
78
82
- # create and run concat_rows_op operator
79
+ # create and run lookup_table operator
83
80
concat_rows_op = Operator ("lookup_table" , W = 'W' , Ids = 'Ids' , Out = 'Out' )
84
81
concat_rows_op .run (scope , place )
85
82
86
- # get and compare result
83
+ # get result from Out
84
+ Out_tensor = Out .get_tensor ()
87
85
result_array = np .array (Out_tensor )
88
86
87
+ # all(): return True if all elements of the iterable are true (or if the iterable is empty)
89
88
for idx , row in enumerate (rows ):
90
89
assert (row == result_array [idx ]).all ()
91
90
0 commit comments