15
15
import unittest
16
16
import numpy as np
17
17
from op_test import OpTest
18
+ import paddle .fluid .core as core
19
+ from paddle .fluid .op import Operator
18
20
19
21
20
22
class TestSplitIdsOp (OpTest ):
@@ -31,5 +33,63 @@ def test_check_output(self):
31
33
self .check_output ()
32
34
33
35
36
+ class TestSpliteIds (unittest .TestCase ):
37
+ def get_places (self ):
38
+ places = [core .CPUPlace ()]
39
+ return places
40
+
41
+ def test_check_output (self ):
42
+ for place in self .get_places ():
43
+ self .check_with_place (place )
44
+
45
+ def check_with_place (self , place ):
46
+ scope = core .Scope ()
47
+ rows = [0 , 5 , 7 , 4 , 9 ]
48
+ height = 20
49
+ row_numel = 2
50
+
51
+ # initialize input variable X
52
+ x = scope .var ('X' ).get_selected_rows ()
53
+ x .set_rows (rows )
54
+ x .set_height (height )
55
+ np_array = np .ones ((len (rows ), row_numel )).astype ("float32" )
56
+ for i in range (len (rows )):
57
+ for j in range (row_numel ):
58
+ np_array [i , j ] = rows [i ] + j
59
+ x_tensor = x .get_tensor ()
60
+ x_tensor .set (np_array , place )
61
+
62
+ outs_name = ["out%d" % i for i in xrange (3 )]
63
+ outs = [
64
+ scope .var (var_name ).get_selected_rows () for var_name in outs_name
65
+ ]
66
+
67
+ # expected output selected rows
68
+ expected_out0_rows = [0 , 9 ]
69
+ expected_out1_rows = [7 , 4 ]
70
+ expected_out2_rows = [5 ]
71
+
72
+ op = Operator ("split_ids" , Ids = "X" , Out = outs_name )
73
+
74
+ op .run (scope , place )
75
+
76
+ self .assertEqual (outs [0 ].rows (), expected_out0_rows )
77
+ self .assertEqual (outs [1 ].rows (), expected_out1_rows )
78
+ self .assertEqual (outs [2 ].rows (), expected_out2_rows )
79
+
80
+ self .assertAlmostEqual (0.0 , np .array (outs [0 ].get_tensor ())[0 , 0 ])
81
+ self .assertAlmostEqual (1.0 , np .array (outs [0 ].get_tensor ())[0 , 1 ])
82
+ self .assertAlmostEqual (9.0 , np .array (outs [0 ].get_tensor ())[1 , 0 ])
83
+ self .assertAlmostEqual (10.0 , np .array (outs [0 ].get_tensor ())[1 , 1 ])
84
+
85
+ self .assertAlmostEqual (7.0 , np .array (outs [1 ].get_tensor ())[0 , 0 ])
86
+ self .assertAlmostEqual (8.0 , np .array (outs [1 ].get_tensor ())[0 , 1 ])
87
+ self .assertAlmostEqual (4.0 , np .array (outs [1 ].get_tensor ())[1 , 0 ])
88
+ self .assertAlmostEqual (5.0 , np .array (outs [1 ].get_tensor ())[1 , 1 ])
89
+
90
+ self .assertAlmostEqual (5.0 , np .array (outs [2 ].get_tensor ())[0 , 0 ])
91
+ self .assertAlmostEqual (6.0 , np .array (outs [2 ].get_tensor ())[0 , 1 ])
92
+
93
+
34
94
if __name__ == '__main__' :
35
95
unittest .main ()
0 commit comments