Skip to content

Commit 1db2c68

Browse files
committed
add test cases for tf.map_fn
1 parent 1e5e377 commit 1db2c68

File tree

1 file changed

+27
-1
lines changed

1 file changed

+27
-1
lines changed

tests/test_backend.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,33 @@ def test_atrig_ops(self):
201201
_ = tf.identity(op_, name=_TFOUTPUT)
202202
self._run_test_case([_OUTPUT], {_INPUT: x_val})
203203

204+
def test_map_fn(self):
205+
def fn0(elem):
206+
res = elem + elem * elem
207+
return res
208+
209+
def fn1(elem):
210+
res = elem[0] * elem[1] + elem[0]
211+
return res
212+
213+
x_val = 100 * np.random.random_sample([2, 10]).astype(np.float32)
214+
y_val = 100 * np.random.random_sample([2, 10]).astype(np.float32)
215+
216+
# test fn0
217+
x = tf.placeholder(tf.float32, shape=x_val.shape, name=_TFINPUT)
218+
res_ = tf.map_fn(fn0, x, dtype=tf.float32)
219+
_ = tf.identity(res_, name=_TFOUTPUT1)
220+
self._run_test_case([_OUTPUT1], {_INPUT: x_val}, rtol=0)
221+
tf.reset_default_graph()
222+
223+
# test fn1
224+
x = tf.placeholder(tf.float32, shape=x_val.shape, name=_TFINPUT)
225+
y = tf.placeholder(tf.float32, shape=y_val.shape, name=_TFINPUT1)
226+
res_ = tf.map_fn(fn1, (x, y), dtype=tf.float32)
227+
_ = tf.identity(res_, name=_TFOUTPUT1)
228+
self._run_test_case([_OUTPUT1], {_INPUT: x_val, _INPUT1: y_val}, rtol=0)
229+
tf.reset_default_graph()
230+
204231
@unittest.skipIf(BACKEND in ["caffe2"], "not supported correctly in caffe2")
205232
@unittest.skipIf(*support_op_conversion_since(7, "multinomial"))
206233
def test_multinomial(self):
@@ -213,7 +240,6 @@ def test_multinomial(self):
213240
self._run_test_case([_OUTPUT], {_INPUT: x_val}, check_value=False,
214241
check_shape=True, check_dtype=True)
215242

216-
217243
@unittest.skipIf(BACKEND in ["caffe2"], "not supported correctly in caffe2")
218244
@unittest.skipIf(*support_op_conversion_since(7, "multinomial"))
219245
def test_multinomial1(self):

0 commit comments

Comments
 (0)