@@ -35,6 +35,22 @@ def test_check_grad(self):
35
35
self .check_grad (['W' ], 'Out' , no_grad_set = set ('Ids' ))
36
36
37
37
38
+ class TestLookupTableOpWithTensorIds (OpTest ):
39
+ def setUp (self ):
40
+ self .op_type = "lookup_table"
41
+ table = np .random .random ((17 , 31 )).astype ("float32" )
42
+ ids = np .random .randint (
43
+ low = 0 , high = 17 , size = (2 , 4 , 5 , 1 )).astype ("int64" )
44
+ self .inputs = {'W' : table , 'Ids' : ids }
45
+ self .outputs = {'Out' : table [ids .flatten ()].reshape ((2 , 4 , 5 , 31 ))}
46
+
47
+ def test_check_output (self ):
48
+ self .check_output ()
49
+
50
+ def test_check_grad (self ):
51
+ self .check_grad (['W' ], 'Out' , no_grad_set = set ('Ids' ))
52
+
53
+
38
54
class TestLookupTableOpWithPadding (TestLookupTableOp ):
39
55
def test_check_output (self ):
40
56
ids = np .squeeze (self .inputs ['Ids' ])
@@ -44,21 +60,34 @@ def test_check_output(self):
44
60
self .check_output ()
45
61
46
62
def test_check_grad (self ):
47
- # Since paddings are not trainable and fixed in forward, the gradient of
63
+ # Since paddings are not trainable and fixed in forward, the gradient of
48
64
# paddings makes no sense and we don't test the gradient here.
49
65
pass
50
66
51
67
52
- class TestLookupTableWIsSelectedRows (OpTest ):
53
- def check_with_place (self , place ):
54
- scope = core .Scope ()
68
+ class TestLookupTableOpWithTensorIdsAndPadding (TestLookupTableOpWithTensorIds ):
69
+ def test_check_output (self ):
70
+ ids = self .inputs ['Ids' ]
71
+ flatten_idx = ids .flatten ()
72
+ padding_idx = np .random .choice (flatten_idx , 1 )[0 ]
73
+ self .outputs ['Out' ][np .squeeze (ids == padding_idx )] = np .zeros (31 )
74
+ self .attrs = {'padding_idx' : long (padding_idx )}
75
+ self .check_output ()
76
+
77
+ def test_check_grad (self ):
78
+ # Since paddings are not trainable and fixed in forward, the gradient of
79
+ # paddings makes no sense and we don't test the gradient here.
80
+ pass
55
81
56
- # create and initialize Id Variable
82
+
83
+ class TestLookupTableWIsSelectedRows (OpTest ):
84
+ def prepare_ids (self , scope , place ):
57
85
ids_tensor = scope .var ('Ids' ).get_tensor ()
58
86
ids_array = np .array ([[0 ], [4 ], [3 ], [5 ]]).astype ("int64" )
59
87
ids_tensor .set (ids_array , place )
88
+ return ids_array
60
89
61
- # create and initialize W Variable
90
+ def prepare_w ( self , scope , place ):
62
91
rows = [0 , 1 , 2 , 3 , 4 , 5 , 6 ]
63
92
row_numel = 12
64
93
@@ -71,18 +100,31 @@ def check_with_place(self, place):
71
100
w_tensor = w_selected_rows .get_tensor ()
72
101
w_tensor .set (w_array , place )
73
102
74
- # create Out Variable
75
- out_tensor = scope .var ('Out' ).get_tensor ()
103
+ def create_out_tensor (self , scope , place ):
104
+ return scope .var ('Out' ).get_tensor ()
105
+
106
+ def check_result (self , ids_array , result_array ):
107
+ # all(): return True if all elements of the iterable are true (or if the iterable is empty)
108
+ for idx , row in enumerate (ids_array ):
109
+ assert (row [0 ] == result_array [idx ]).all ()
110
+
111
+ def check_with_place (self , place ):
112
+ scope = core .Scope ()
113
+
114
+ ids_array = self .prepare_ids (scope , place )
115
+
116
+ self .prepare_w (scope , place )
117
+
118
+ out_tensor = self .create_out_tensor (scope , place )
76
119
77
120
# create and run lookup_table operator
78
121
lookup_table = Operator ("lookup_table" , W = 'W' , Ids = 'Ids' , Out = 'Out' )
79
122
lookup_table .run (scope , place )
80
123
81
124
# get result from Out
82
125
result_array = np .array (out_tensor )
83
- # all(): return True if all elements of the iterable are true (or if the iterable is empty)
84
- for idx , row in enumerate (ids_array ):
85
- assert (row [0 ] == result_array [idx ]).all ()
126
+
127
+ self .check_result (ids_array , result_array )
86
128
87
129
def test_w_is_selected_rows (self ):
88
130
places = [core .CPUPlace ()]
@@ -91,5 +133,19 @@ def test_w_is_selected_rows(self):
91
133
self .check_with_place (place )
92
134
93
135
136
+ class TestLookupTableWithTensorIdsWIsSelectedRows (
137
+ TestLookupTableWIsSelectedRows ):
138
+ def prepare_ids (self , scope , place ):
139
+ ids_tensor = scope .var ('Ids' ).get_tensor ()
140
+ ids_array = np .random .randint (
141
+ low = 0 , high = 6 , size = (2 , 4 , 3 , 1 )).astype ("int64" )
142
+ ids_tensor .set (ids_array , place )
143
+ return ids_array
144
+
145
+ def check_result (self , ids_array , result_array ):
146
+ for idx , row in np .ndenumerate (ids_array ):
147
+ assert (row == result_array [idx ]).all ()
148
+
149
+
94
150
if __name__ == "__main__" :
95
151
unittest .main ()
0 commit comments