Skip to content

Commit cbb3538

Browse files
satyajithjwayuanho
authored and committed
Op ReverseV2 (#594)
* RV2 -> RS for axes=[0, 1] and const shapes * Working ReverseV2 for axis=0 * ReverseV2 op for constant axis * ReverseV2 op for vector axis * Support for empty axis for ReverseV2 * ReverseV2 op test with random tensor * Minor fixes for ReverseV2 * ReverseV2 support for negative axis * Computing the permutation for shape reset * Working ReverseV2 with minimal Transpose nodes * ReverseV2 test for 1D tensor with empty axis * Removed unused variables * changes after PR review v1 * minor changes fo ReverseV2 op support * pylint fixes * sort the axis vector
1 parent 93648c4 commit cbb3538

File tree

2 files changed

+213
-0
lines changed

2 files changed

+213
-0
lines changed

tests/test_backend.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2015,6 +2015,65 @@ def test_reverse_sequence_time_major(self):
20152015
_ = tf.identity(x_, name=_TFOUTPUT)
20162016
self._run_test_case([_OUTPUT], {_INPUT: x_val})
20172017

2018+
2019+
@check_opset_min_version(10, "ReverseSequence")
def test_reversev2_constant_axis(self):
    """ReverseV2 with a constant axis list: one axis, then an empty list."""
    # Reverse along the last axis of a rank-4 tensor.
    shape = [1, 2, 3, 4]
    data = np.random.randint(0, 100, shape).astype(np.float32)
    inp = tf.placeholder(tf.float32, shape, name=_TFINPUT)
    flipped = tf.reverse_v2(inp, axis=[3])
    _ = tf.identity(flipped, name=_TFOUTPUT)
    self._run_test_case([_OUTPUT], {_INPUT: data})
    tf.reset_default_graph()

    # An empty axis vector: ReverseV2 must behave like identity.
    shape = [2, 3, 4]
    data = np.random.randint(0, 100, shape).astype(np.float32)
    inp = tf.placeholder(tf.float32, shape, name=_TFINPUT)
    flipped = tf.reverse_v2(inp, axis=[])
    _ = tf.identity(flipped, name=_TFOUTPUT)
    self._run_test_case([_OUTPUT], {_INPUT: data})
2039+
@check_opset_min_version(10, "ReverseSequence")
def test_reversev2_vector_axis(self):
    """ReverseV2 with multi-element axis vectors, including negative axes."""
    x_val_shape = [1, 2, 3, 4]
    x_val = np.random.randint(0, 100, x_val_shape).astype(np.float32)
    x = tf.placeholder(tf.float32, x_val_shape, name=_TFINPUT)
    x_ = tf.reverse_v2(x, axis=[0, -3, 2, 3])
    _ = tf.identity(x_, name=_TFOUTPUT)
    self._run_test_case([_OUTPUT], {_INPUT: x_val})
    tf.reset_default_graph()

    x_val_shape = [2, 3, 4]
    x_val = np.random.randint(0, 100, x_val_shape).astype(np.float32)
    x = tf.placeholder(tf.float32, x_val_shape, name=_TFINPUT)
    x_ = tf.reverse_v2(x, axis=[-3, 1, 2])
    _ = tf.identity(x_, name=_TFOUTPUT)
    self._run_test_case([_OUTPUT], {_INPUT: x_val})
    tf.reset_default_graph()

    x_val_shape = [5, 5, 9, 7, 8, 9]
    x_val = np.random.randint(0, 100, x_val_shape).astype(np.float32)
    # Reuse x_val_shape instead of repeating the literal shape, so the
    # placeholder can never drift out of sync with the generated data.
    x = tf.placeholder(tf.float32, x_val_shape, name=_TFINPUT)
    x_ = tf.reverse_v2(x, axis=[0, 1, -2, 3, 5])
    _ = tf.identity(x_, name=_TFOUTPUT)
    self._run_test_case([_OUTPUT], {_INPUT: x_val})
2065+
@check_opset_min_version(10, "ReverseSequence")
def test_reversev2_1D_tensor(self):
    """ReverseV2 on a rank-1 tensor with no axis to reverse.

    The converter is expected to emit an Identity block for this case.
    """
    shape = [4]
    data = np.random.randint(0, 100, shape).astype(np.float32)
    inp = tf.placeholder(tf.float32, shape, name=_TFINPUT)
    flipped = tf.reverse_v2(inp, axis=[])
    _ = tf.identity(flipped, name=_TFOUTPUT)
    self._run_test_case([_OUTPUT], {_INPUT: data})
2076+
20182077
@check_opset_min_version(8, "where")
20192078
def test_where(self):
20202079
x_val = np.array([1, 2, -3, 4, -5, -6, -7, 8, 9, 0], dtype=np.float32)

tf2onnx/onnx_opset/tensor.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,3 +1291,157 @@ def version_10(cls, ctx, node, **kwargs):
12911291
target_dtype = TensorProto.INT64
12921292
if seq_len_dtype != target_dtype:
12931293
ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=target_dtype)
1294+
1295+
1296+
@tf_op("ReverseV2")
class ReverseV2:
    @classmethod
    def version_10(cls, ctx, node, **kwargs):
        """Rewrite tensorflow ReverseV2(input, axis) using opset-10 primitives.

        ONNX has no direct ReverseV2 equivalent, so the op is implemented
        with one ReverseSequence node per axis to reverse, plus Transpose
        nodes: each target axis is swapped into the time dimension (axis 0)
        before its ReverseSequence, and a final Transpose restores the
        original layout.  The axis vector is sorted up front; tf guarantees
        each axis appears at most once.
        """
        axes_node = node.inputs[1]
        axes = axes_node.get_tensor_value(as_list=False)
        # Current support is only for a 1-D (vector) axis tensor.
        utils.make_sure(len(axes.shape) == 1 \
            , "Currently no support for reverseV2 tensor axis")

        axes = axes.tolist()
        len_axes = len(axes)

        # Store input and output parameters of the ReverseV2 node.
        rv2_in_names = [node.input[0]]

        input_shape = ctx.get_shape(node.input[0])
        # A concrete shape is required to size the seq_lengths constants.
        utils.make_sure(input_shape is not None, "shape of {} is None".format(node.input[0]))

        input_rank = len(input_shape)

        rv2_node_name = node.name
        # ReverseV2 has a single output.
        rv2_output_dtypes = node.output_dtypes
        rv2_output_shapes = node.output_shapes

        const_name_root = rv2_node_name + '_Const'

        # Remove ReverseV2 node from graph; the nodes built below replace it.
        ctx.remove_node(rv2_node_name)

        # Variable to store input names for the next node.
        inputs = rv2_in_names

        new_node = None

        # Empty axis vector.
        if len_axes == 0:
            # Replace ReverseV2 with an identity block (a no-op reverse).
            new_node = ctx.make_node(
                "Identity",
                inputs=inputs,
                outputs=node.output,
                shapes=rv2_output_shapes,
                dtypes=rv2_output_dtypes,
                op_name_scope=rv2_node_name,
            )

        else:
            # For negative indices use the positive counterpart.
            for i, ax in enumerate(axes):
                if ax < 0:
                    axes[i] += input_rank

            axes = sorted(axes)

            orig_perm = list(range(input_rank))
            curr_perm = []

            # Add ReverseSequence nodes for each element of axis.
            for i in range(len_axes):

                axis = axes[i]

                curr_perm = orig_perm.copy()
                # Permutation indices relative to original tensor: swap the
                # target axis into position 0 (ReverseSequence's time axis).
                curr_perm[axis], curr_perm[0] = curr_perm[0], curr_perm[axis]

                # Add a Transpose node if the axis != 0 (finish first due to sort).
                if axis != 0:
                    # Permutation indices for the transpose node relative to IN tensor shape.
                    new_node = ctx.make_node(
                        "Transpose",
                        inputs=inputs,
                        op_name_scope=rv2_node_name,
                        dtypes=rv2_output_dtypes,
                        attr={"perm": curr_perm}
                    )

                    inputs = [new_node.output[0]]

                # Add a Constant node (seq_len) for ReverseSequence.

                # Index 1 for the shape should not return 0
                # since the input must have rank >= 2.
                rs_batch_size = ctx.get_shape(inputs[-1])[1]

                # Both the reversed dimension and the batch size must be
                # static so seq_lengths can be materialized.
                # Use != rather than "is not": identity comparison against a
                # literal is implementation-dependent and is always true for
                # numpy integer shapes, which would silently skip the check.
                utils.make_sure(input_shape[axis] != -1 \
                    , "shape of axis {} is unknown".format(axis))
                utils.make_sure(rs_batch_size != -1 \
                    , "ReverseSequence batch size for axis {} is unknown".format(axis))

                seq_list = [input_shape[axis]] * rs_batch_size
                seq_array = np.asarray(seq_list, dtype=np.int64)  # dtype should be int64

                const_seq_name = utils.make_name(const_name_root)
                new_node = ctx.make_const(name=const_seq_name, np_val=seq_array)
                inputs.append(new_node.output[0])

                # Add a ReverseSequence node.

                # If processing for the final axis and the tensor shape permutation is
                # original then the output is fed to the output of the ReverseV2 node.
                #
                # Else a new output is created which is fed to a Transpose node.
                rs_out_name = node.output if \
                    ((i == len_axes - 1) and (curr_perm == orig_perm)) \
                    else None

                rs_out_shapes = None if rs_out_name is None else rv2_output_shapes

                new_node = ctx.make_node(
                    "ReverseSequence",
                    inputs=inputs,
                    op_name_scope=rv2_node_name,
                    outputs=rs_out_name,
                    shapes=rs_out_shapes,
                    dtypes=rv2_output_dtypes,
                    attr={"batch_axis": 1, "time_axis": 0}
                )

                inputs = [new_node.output[0]]

            # Additional transpose block is required if the current
            # permutation list is not the original one.
            if curr_perm != orig_perm:

                # Compute the required permutation list by undoing the
                # accumulated swaps in reverse order (the last axis is
                # already reflected in curr_perm, so skip it).
                if len_axes != 1:
                    for ax in axes[::-1][1:]:
                        curr_perm[0], curr_perm[ax] = \
                            curr_perm[ax], curr_perm[0]

                # Add a Transpose node to restore shape.
                new_node = ctx.make_node(
                    "Transpose",
                    inputs=inputs,
                    op_name_scope=rv2_node_name,
                    outputs=node.output,
                    shapes=rv2_output_shapes,
                    dtypes=rv2_output_dtypes,
                    attr={"perm": curr_perm}
                )

0 commit comments

Comments
 (0)