@@ -201,6 +201,33 @@ def test_atrig_ops(self):
201
201
_ = tf .identity (op_ , name = _TFOUTPUT )
202
202
self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
203
203
204
+ def test_map_fn (self ):
205
+ def fn0 (elem ):
206
+ res = elem + elem * elem
207
+ return res
208
+
209
+ def fn1 (elem ):
210
+ res = elem [0 ] * elem [1 ] + elem [0 ]
211
+ return res
212
+
213
+ x_val = 100 * np .random .random_sample ([2 , 10 ]).astype (np .float32 )
214
+ y_val = 100 * np .random .random_sample ([2 , 10 ]).astype (np .float32 )
215
+
216
+ # test fn0
217
+ x = tf .placeholder (tf .float32 , shape = x_val .shape , name = _TFINPUT )
218
+ res_ = tf .map_fn (fn0 , x , dtype = tf .float32 )
219
+ _ = tf .identity (res_ , name = _TFOUTPUT1 )
220
+ self ._run_test_case ([_OUTPUT1 ], {_INPUT : x_val }, rtol = 0 )
221
+ tf .reset_default_graph ()
222
+
223
+ # test fn1
224
+ x = tf .placeholder (tf .float32 , shape = x_val .shape , name = _TFINPUT )
225
+ y = tf .placeholder (tf .float32 , shape = y_val .shape , name = _TFINPUT1 )
226
+ res_ = tf .map_fn (fn1 , (x , y ), dtype = tf .float32 )
227
+ _ = tf .identity (res_ , name = _TFOUTPUT1 )
228
+ self ._run_test_case ([_OUTPUT1 ], {_INPUT : x_val , _INPUT1 : y_val }, rtol = 0 )
229
+ tf .reset_default_graph ()
230
+
204
231
@unittest .skipIf (BACKEND in ["caffe2" ], "not supported correctly in caffe2" )
205
232
@unittest .skipIf (* support_op_conversion_since (7 , "multinomial" ))
206
233
def test_multinomial (self ):
@@ -213,7 +240,6 @@ def test_multinomial(self):
213
240
self ._run_test_case ([_OUTPUT ], {_INPUT : x_val }, check_value = False ,
214
241
check_shape = True , check_dtype = True )
215
242
216
-
217
243
@unittest .skipIf (BACKEND in ["caffe2" ], "not supported correctly in caffe2" )
218
244
@unittest .skipIf (* support_op_conversion_since (7 , "multinomial" ))
219
245
def test_multinomial1 (self ):
0 commit comments