Skip to content

Commit 3b4f375

Browse files
authored
Merge pull request #797 from jignparm/jignparm/spacetobatchnd_fix
Fix BatchToSpaceND and SpaceToBatchND operators
2 parents 18ac0e1 + 7c03cc4 commit 3b4f375

File tree

2 files changed

+225
-167
lines changed

2 files changed

+225
-167
lines changed

tests/test_backend.py

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2485,26 +2485,46 @@ def test_batch_to_spacend(self):
24852485
self._run_test_case([_OUTPUT], {_INPUT: input_val})
24862486

24872487
@check_opset_min_version(11, "BatchToSpaceND")
2488-
def test_batch_to_spacend_non_const(self):
2489-
input_x_val = np.random.random_sample([40, 3, 5, 100]).astype(np.float32) # NHWC
2490-
block_shape_val = np.array([2, 2]).astype(np.int64)
2491-
crops_val = np.array([[1, 0], [2, 1]]).astype(np.int64)
2492-
input_x = tf.placeholder(dtype=tf.float32, shape=input_x_val.shape, name=_TFINPUT)
2493-
block_shape = tf.placeholder(dtype=tf.int64, shape=block_shape_val.shape, name=_TFINPUT1)
2494-
crops = tf.placeholder(dtype=tf.int64, shape=crops_val.shape, name=_TFINPUT2)
2495-
_ = tf.batch_to_space_nd(input_x, block_shape, crops, name=_TFOUTPUT)
2496-
self._run_test_case([_OUTPUT], {_INPUT: input_x_val, _INPUT1: block_shape_val, _INPUT2: crops_val})
2488+
def test_batch_to_spacend_non_const_7d(self):
2489+
x_type, y_type, z_type = np.int64, np.int64, np.int64
2490+
# test 3D upto 7D input tensors
2491+
for x_shape in [[12, 4, 4], [12, 4, 8, 3], [12, 4, 8, 3, 2], [12, 4, 8, 3, 2, 3], [12, 4, 8, 3, 2, 1, 3]]:
2492+
# test 1D upto 2D block shapes
2493+
for block_shape in [[2, 3], [2]]:
2494+
tf.reset_default_graph()
2495+
# crop 1 layer at end of each dim
2496+
crops = [[0, 1] for dim in block_shape]
2497+
y_val = np.array(block_shape).astype(y_type)
2498+
x_val = np.array([x + 1 for x in range(0, np.prod(x_shape))], dtype=x_type).reshape(x_shape)
2499+
z_val = np.array(crops).astype(z_type)
2500+
# x and z can be dynamic.
2501+
# y = block_shape cannot be dynamic without change to Transpose op spec
2502+
x = tf.placeholder(dtype=x_type, shape=x_val.shape, name=_TFINPUT)
2503+
y = tf.constant(dtype=y_type, value=y_val, shape=y_val.shape, name=_TFINPUT1)
2504+
z = tf.placeholder(dtype=z_type, shape=z_val.shape, name=_TFINPUT2)
2505+
_ = tf.batch_to_space_nd(x, y, z, name=_TFOUTPUT)
2506+
self._run_test_case([_OUTPUT], {_INPUT: x_val, _INPUT2: z_val})
24972507

24982508
@check_opset_min_version(11, "SpaceToBatchND")
2499-
def test_space_to_batchnd_non_const(self):
2500-
input_x_val = np.random.random_sample([40, 5, 7, 66]).astype(np.float32) # NHWC
2501-
block_size_val = np.array([2, 2]).astype(np.int64)
2502-
pad_val = np.array([[0, 1], [2, 1]]).astype(np.int64)
2503-
input_x = tf.placeholder(dtype=tf.float32, shape=input_x_val.shape, name=_TFINPUT)
2504-
block_size = tf.placeholder(dtype=tf.int64, shape=block_size_val.shape, name=_TFINPUT1)
2505-
pad = tf.placeholder(dtype=tf.int64, shape=pad_val.shape, name=_TFINPUT2)
2506-
_ = tf.space_to_batch_nd(input_x, block_size, pad, name=_TFOUTPUT)
2507-
self._run_test_case([_OUTPUT], {_INPUT: input_x_val, _INPUT1: block_size_val, _INPUT2: pad_val})
2509+
def test_space_to_batchnd_non_const_7d(self):
2510+
x_type, y_type, z_type = np.int64, np.int64, np.int64
2511+
# test 3D upto 7D input tensors
2512+
for x_shape in [[2, 4, 4], [1, 4, 8, 3], [1, 4, 8, 3, 2], [1, 4, 8, 3, 2, 3], [1, 4, 8, 3, 2, 1, 3]]:
2513+
# test 1D upto 2D block shapes
2514+
for block_shape in [[2], [2, 2]]:
2515+
tf.reset_default_graph()
2516+
# pad 1 layer at begin and end of each dim
2517+
pads = [[1, 1] for dim in block_shape]
2518+
y_val = np.array(block_shape).astype(y_type)
2519+
x_val = np.array([x + 1 for x in range(0, np.prod(x_shape))], dtype=x_type).reshape(x_shape)
2520+
z_val = np.array(pads).astype(z_type)
2521+
# x and z can be dynamic.
2522+
# y = block_shape cannot be dynamic without change to Transpose op spec
2523+
x = tf.placeholder(dtype=x_type, shape=x_val.shape, name=_TFINPUT)
2524+
y = tf.constant(dtype=y_type, value=y_val, shape=y_val.shape, name=_TFINPUT1)
2525+
z = tf.placeholder(dtype=z_type, shape=z_val.shape, name=_TFINPUT2)
2526+
_ = tf.space_to_batch_nd(x, y, z, name=_TFOUTPUT)
2527+
self._run_test_case([_OUTPUT], {_INPUT: x_val, _INPUT2: z_val})
25082528

25092529
@check_opset_min_version(11, "CropAndResize")
25102530
def test_crop_and_resize_linear(self):

0 commit comments

Comments
 (0)