Skip to content

Commit 1ffb148

Browse files
committed
Fix strided slice, reverse stride with unbounded begin and end
1 parent c5e560d commit 1ffb148

File tree

3 files changed

+73
-26
lines changed

3 files changed

+73
-26
lines changed

tests/test_backend.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1787,6 +1787,32 @@ def func(x, y):
17871787
return tf.identity(x_, name=_TFOUTPUT)
17881788
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
17891789

1790+
@check_opset_min_version(10, "Slice")
1791+
def test_strided_slice_reverse_1(self):
1792+
tf.reset_default_graph()
1793+
x_val = np.arange(16 * 32).astype(np.float32).reshape((1, 16, 32, 1))
1794+
def func(x):
1795+
return tf.concat([x[:, :, :10], x[:, :, :21:-1]], axis=0, name=_TFOUTPUT)
1796+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
1797+
1798+
@check_opset_min_version(10, "Slice")
1799+
def test_strided_slice_reverse_2(self):
1800+
tf.reset_default_graph()
1801+
x_val = np.arange(16 * 32).astype(np.float32).reshape((1, 16, 32, 1))
1802+
def func(x):
1803+
return tf.concat([x[:, :, :10], x[:, :, 9::-1]], axis=0, name=_TFOUTPUT)
1804+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
1805+
1806+
@check_opset_min_version(10, "Slice")
1807+
def test_strided_slice_reverse_3(self):
1808+
tf.reset_default_graph()
1809+
x_val = np.zeros((1, 16, 32, 1)).astype(np.float32)
1810+
y_val = np.array(9).astype(np.int32)
1811+
z_val = np.array(-1).astype(np.int32)
1812+
def func(x, y, z):
1813+
return tf.concat([x[:, :, :10], x[:, :, y::z]], axis=0, name=_TFOUTPUT)
1814+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val, _INPUT2: z_val})
1815+
17901816
@check_opset_min_version(10, "Slice")
17911817
def test_new_axis_mask(self):
17921818
def func(x, y):

tf2onnx/onnx_opset/math.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,8 @@ def make_min_or_max_op(ctx, op_type, inputs, outputs,
9797
ctx.set_dtype(cast_node.output[0], origin_dtype)
9898
ctx.copy_shape(node.output[0], cast_node.output[0])
9999
actual_outputs = cast_node.output
100-
ctx.make_node("Identity", actual_outputs, outputs=outputs,
101-
shapes=output_shapes, dtypes=output_dtypes)
100+
final_node = ctx.make_node("Identity", actual_outputs, outputs=outputs,
101+
shapes=output_shapes, dtypes=output_dtypes)
102102

103103
# tensorflow minimum/maximum does support broadcast, onnx < opset 8 does not.
104104
# handle this by doing something like:
@@ -124,6 +124,7 @@ def make_min_or_max_op(ctx, op_type, inputs, outputs,
124124
add_node = ctx.make_node("Add", [input_node.output[0], sub_node.output[0]],
125125
op_name_scope=input_node.name)
126126
node.input[i] = add_node.output[0]
127+
return final_node
127128

128129

129130
@tf_op("Minimum", onnx_op="Min")

tf2onnx/onnx_opset/tensor.py

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -845,37 +845,57 @@ def version_10(cls, ctx, node, **kwargs):
845845
# mask begin
846846
new_begin_mask = np.array(new_begin_mask, dtype=np_dtype)
847847
if not np.all(new_begin_mask == 1):
848-
if begin.is_const():
849-
begin = ctx.make_const(
850-
utils.make_name("begin_masked"),
851-
begin.get_tensor_value(as_list=False) * new_begin_mask
852-
)
848+
if begin.is_const() and strides.is_const():
849+
begin_vals = begin.get_tensor_value(as_list=False)
850+
strides_vals = strides.get_tensor_value(as_list=False)
851+
new_begin_vals = np.copy(begin_vals)
852+
for i, v in enumerate(strides_vals):
853+
if v < 0 and new_begin_mask[i] == 0:
854+
new_begin_vals[i] = max_size
855+
begin = ctx.make_const(utils.make_name("begin_masked"), new_begin_vals)
853856
else:
854-
begin_mask_const = ctx.make_const(
855-
utils.make_name("begin_mask"),
856-
new_begin_mask
857-
)
858-
begin = ctx.make_node(
859-
"Mul", [begin.output[0], begin_mask_const.output[0]],
860-
op_name_scope=node.name
861-
)
857+
begin_mask_const = ctx.make_const(utils.make_name("begin_mask"), np.equal(new_begin_mask, 0))
858+
zero_const = ctx.make_const(utils.make_name("zero_const"), np.zeros(1, dtype=np_dtype))
859+
max_const = ctx.make_const(utils.make_name("max_const"), np.array(max_size, dtype=np_dtype))
860+
is_reverse_steps = ctx.make_node("Less", [strides.output[0], zero_const.output[0]],
861+
op_name_scope=node.name)
862+
is_reverse_and_full_range = ctx.make_node("And", [is_reverse_steps.output[0],
863+
begin_mask_const.output[0]], op_name_scope=node.name)
864+
begin = ctx.make_node("Where", [is_reverse_and_full_range.output[0], max_const.output[0],
865+
begin.output[0]], op_name_scope=node.name)
866+
862867
# mask end
863868
new_end_mask = np.array(new_end_mask, dtype=np_dtype)
864869
end_output = end.output[0]
865870
if not np.all(new_end_mask == min_size):
866-
if end.is_const():
867-
end = ctx.make_const(
868-
utils.make_name("end_masked"),
869-
np.maximum(end.get_tensor_value(as_list=False), new_end_mask)
870-
)
871+
if end.is_const() and strides.is_const() and False:
872+
new_end_mask = np.maximum(end.get_tensor_value(as_list=False), new_end_mask)
873+
for i, v in enumerate(strides.get_tensor_value(as_list=False)):
874+
if new_end_mask[i] == max_size:
875+
new_end_mask[i] *= np.sign(v)
876+
end = ctx.make_const(utils.make_name("end_masked"), new_end_mask)
871877
end_output = end.output[0]
878+
872879
else:
873-
end_mask_const = ctx.make_const(
874-
utils.make_name("end_mask"),
875-
np.array(new_end_mask, dtype=np_dtype)
876-
)
877-
end_output = utils.make_name("{}__end".format(node.name))
878-
math.make_min_or_max_op(ctx, "Max", [end.output[0], end_mask_const.output[0]], [end_output])
880+
# Overlay new_end_mask with specified end values.
881+
# Adjust max_size to min_size if steps are < 0
882+
max_const = ctx.make_const(utils.make_name("max_const"), np.array(max_size, dtype=np_dtype))
883+
min_const = ctx.make_const(utils.make_name("min_const"), np.array(min_size, dtype=np_dtype))
884+
zero_const = ctx.make_const(utils.make_name("zero_const"), np.zeros(1, dtype=np_dtype))
885+
end_mask_const = ctx.make_const(utils.make_name("end_mask"), np.array(new_end_mask, dtype=np_dtype))
886+
outputname = utils.make_name("{}__newendmask".format(node.name))
887+
new_end_mask = math.make_min_or_max_op(ctx, "Max", [end.output[0], end_mask_const.output[0]],
888+
[outputname])
889+
is_reverse_steps = ctx.make_node("Less", [strides.output[0], zero_const.output[0]],
890+
op_name_scope=node.name)
891+
is_full_range = ctx.make_node("Equal", [new_end_mask.output[0], max_const.output[0]],
892+
op_name_scope=node.name)
893+
is_reverse_and_full_range = ctx.make_node("And", [is_full_range.output[0], is_reverse_steps.output[0]],
894+
op_name_scope=node.name)
895+
final_end = ctx.make_node("Where", [is_reverse_and_full_range.output[0], min_const.output[0],
896+
new_end_mask.output[0]], op_name_scope=node.name)
897+
end_output = final_end.output[0]
898+
879899
# mask strides for shrink
880900
shrink_strided_mask = np.array(shrink_strided_mask, dtype=np_dtype)
881901
strides_output = strides.output[0]

0 commit comments

Comments (0)