16
16
17
17
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
18
18
19
+
19
20
class LoopTests (Tf2OnnxBackendTestBase ):
20
21
21
22
def test_simple_while_loop (self ):
@@ -31,7 +32,6 @@ def test_simple_while_loop(self):
31
32
output_names_with_port = ["output:0" ]
32
33
self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 )
33
34
34
-
35
35
def test_simple_while_loop_2 (self ):
36
36
i = tf .placeholder (tf .int32 , (), name = "input_1" )
37
37
c = lambda i : tf .logical_and (tf .less (i , 10 ), tf .greater_equal (i , 3 ))
@@ -45,7 +45,6 @@ def test_simple_while_loop_2(self):
45
45
output_names_with_port = ["output:0" ]
46
46
self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 )
47
47
48
-
49
48
def test_while_loop_with_ta_write (self ):
50
49
i = tf .placeholder (tf .int32 , (), name = "input_1" )
51
50
output_ta = tf .TensorArray (dtype = tf .int32 , size = 0 , dynamic_size = True )
@@ -68,7 +67,6 @@ def b(i, out_ta):
68
67
output_names_with_port = ["output:0" , "i:0" ]
69
68
self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 )
70
69
71
-
72
70
def test_while_loop_with_ta_read (self ):
73
71
i = tf .placeholder (tf .int32 , (), name = "input_1" )
74
72
inputs = tf .placeholder (tf .float32 , (10 ,), name = "input_2" )
@@ -82,6 +80,7 @@ def test_while_loop_with_ta_read(self):
82
80
c = lambda i , * _ : tf .logical_and (tf .less (i , 10 ), i >= 0 )
83
81
res = tf .constant (0. )
84
82
res2 = tf .constant (1. )
83
+
85
84
def b (i , res , res2 ):
86
85
new_i = tf .add (i , 1 )
87
86
x = input_ta .read (i )
@@ -113,6 +112,7 @@ def test_while_loop_with_ta_read_reference_outer_input_directly(self):
113
112
c = lambda i , * _ : tf .logical_and (tf .less (i , 10 ), i >= 0 )
114
113
res = tf .constant (0. )
115
114
res2 = tf .constant (1. )
115
+
116
116
def b (i , res , res2 ):
117
117
new_i = tf .add (i , 1 )
118
118
x = input_ta .read (i )
@@ -132,7 +132,6 @@ def b(i, res, res2):
132
132
output_names_with_port = ["i:0" , "x:0" , "y:0" ]
133
133
self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 )
134
134
135
-
136
135
def test_while_loop_with_ta_read_and_write (self ):
137
136
i = tf .placeholder (tf .int32 , (), name = "input_1" )
138
137
inputs = tf .placeholder (tf .float32 , (10 ,), name = "input_2" )
@@ -160,5 +159,42 @@ def b(i, out_ta):
160
159
output_names_with_port = ["i:0" , "output_ta:0" ]
161
160
self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 )
162
161
162
+ def test_map_fn (self ):
163
+ def fn0 (elem ):
164
+ res = elem + elem * elem
165
+ return res
166
+
167
+ def fn1 (elem ):
168
+ res = elem [0 ] * elem [1 ] + elem [0 ]
169
+ return res
170
+
171
+ x_val = 100 * np .random .random_sample ([2 , 10 ]).astype (np .float32 )
172
+ y_val = 100 * np .random .random_sample ([2 , 10 ]).astype (np .float32 )
173
+
174
+ # test fn0
175
+ x = tf .placeholder (tf .float32 , shape = x_val .shape , name = "input_0" )
176
+ x_ = tf .identity (x )
177
+ res_ = tf .map_fn (fn0 , x_ , dtype = tf .float32 )
178
+ _ = tf .identity (res_ , name = "output_0" )
179
+ feed_dict = {"input_0:0" : x_val }
180
+ input_names_with_port = ["input_0:0" ]
181
+ output_names_with_port = ["output_0:0" ]
182
+ self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-5 )
183
+ tf .reset_default_graph ()
184
+
185
+ # test fn1
186
+ x = tf .placeholder (tf .float32 , shape = x_val .shape , name = "input_0" )
187
+ y = tf .placeholder (tf .float32 , shape = y_val .shape , name = "input_1" )
188
+ x_ = tf .identity (x )
189
+ y_ = tf .identity (y )
190
+ res_ = tf .map_fn (fn1 , (x_ , y_ ), dtype = tf .float32 )
191
+ _ = tf .identity (res_ , name = "output_0" )
192
+ feed_dict = {"input_0:0" : x_val , "input_1:0" : y_val }
193
+ input_names_with_port = ["input_0:0" , "input_1:0" ]
194
+ output_names_with_port = ["output_0:0" ]
195
+ self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-5 )
196
+ tf .reset_default_graph ()
197
+
198
+
163
199
if __name__ == '__main__' :
164
200
Tf2OnnxBackendTestBase .trigger (LoopTests )
0 commit comments