@@ -65,30 +65,21 @@ def check_with_place(self, place):
65
65
]
66
66
67
67
# expected output selected rows
68
- expected_out0_rows = [0 , 9 ]
69
- expected_out1_rows = [7 , 4 ]
70
- expected_out2_rows = [5 ]
68
+ expected_out_rows = [[0 , 9 ], [7 , 4 ], [5 ]]
71
69
72
70
op = Operator ("split_ids" , Ids = "X" , Out = outs_name )
73
71
74
72
op .run (scope , place )
75
73
76
- self .assertEqual (outs [0 ].rows (), expected_out0_rows )
77
- self .assertEqual (outs [1 ].rows (), expected_out1_rows )
78
- self .assertEqual (outs [2 ].rows (), expected_out2_rows )
79
-
80
- self .assertAlmostEqual (0.0 , np .array (outs [0 ].get_tensor ())[0 , 0 ])
81
- self .assertAlmostEqual (1.0 , np .array (outs [0 ].get_tensor ())[0 , 1 ])
82
- self .assertAlmostEqual (9.0 , np .array (outs [0 ].get_tensor ())[1 , 0 ])
83
- self .assertAlmostEqual (10.0 , np .array (outs [0 ].get_tensor ())[1 , 1 ])
84
-
85
- self .assertAlmostEqual (7.0 , np .array (outs [1 ].get_tensor ())[0 , 0 ])
86
- self .assertAlmostEqual (8.0 , np .array (outs [1 ].get_tensor ())[0 , 1 ])
87
- self .assertAlmostEqual (4.0 , np .array (outs [1 ].get_tensor ())[1 , 0 ])
88
- self .assertAlmostEqual (5.0 , np .array (outs [1 ].get_tensor ())[1 , 1 ])
89
-
90
- self .assertAlmostEqual (5.0 , np .array (outs [2 ].get_tensor ())[0 , 0 ])
91
- self .assertAlmostEqual (6.0 , np .array (outs [2 ].get_tensor ())[0 , 1 ])
74
+ for i in range (len (outs )):
75
+ expected_rows = expected_out_rows [i ]
76
+ self .assertEqual (outs [i ].rows (), expected_rows )
77
+ for j in range (len (expected_rows )):
78
+ row = expected_rows [j ]
79
+ self .assertAlmostEqual (
80
+ float (row ), np .array (outs [i ].get_tensor ())[j , 0 ])
81
+ self .assertAlmostEqual (
82
+ float (row + 1 ), np .array (outs [i ].get_tensor ())[j , 1 ])
92
83
93
84
94
85
if __name__ == '__main__' :
0 commit comments