Skip to content

Commit 1c7d4ce

Browse files
Fix problem with adding more than one tf.newaxis at the same time (#2007)
Signed-off-by: southfreebird <[email protected]> Co-authored-by: iolkhovsky <[email protected]>
1 parent 404e2b7 commit 1c7d4ce

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

tests/test_backend.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5893,5 +5893,23 @@ def func(x):
58935893
x_val = make_xval([3, 4])
58945894
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
58955895

5896+
@check_opset_min_version(10, "Slice")
5897+
def test_addition_two_newaxis_simultaneously(self):
5898+
def func(x):
5899+
op = x[..., tf.newaxis, tf.newaxis]
5900+
return tf.identity(op, name=_TFOUTPUT)
5901+
5902+
x_val = make_xval([2, 3])
5903+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
5904+
5905+
@check_opset_min_version(10, "Slice")
5906+
def test_addition_three_newaxis_simultaneously(self):
5907+
def func(x):
5908+
op = x[..., tf.newaxis, tf.newaxis, tf.newaxis]
5909+
return tf.identity(op, name=_TFOUTPUT)
5910+
5911+
x_val = make_xval([2, 3])
5912+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
5913+
58965914
if __name__ == '__main__':
58975915
unittest_main()

tf2onnx/onnx_opset/tensor.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -974,6 +974,29 @@ def any_version_after10(cls, opset, ctx, node, **kwargs):
974974
begin_mask |= 1 << bit
975975
end_mask |= 1 << bit
976976

977+
if ellipsis_mask:
978+
unqueeze_at = []
979+
ellipsis_gap = 0
980+
num_new = 0
981+
end_mask = node.get_attr("end_mask")
982+
end_mask = end_mask.i if end_mask is not None else 0
983+
begin_mask = node.get_attr("begin_mask")
984+
begin_mask = begin_mask.i if begin_mask is not None else 0
985+
986+
for bit in range(32):
987+
new_axis_flag = (new_axis_mask >> bit) & 1
988+
ellipsis_flag = (ellipsis_mask >> bit) & 1
989+
num_new += not ellipsis_flag and new_axis_flag
990+
991+
for bit in range(32):
992+
if (ellipsis_mask >> bit) & 1:
993+
ellipsis_gap = len(ctx.get_shape(input_x)) - param_rank + num_new + 1
994+
elif (new_axis_mask >> bit) & 1:
995+
effective_bit = bit if not ellipsis_gap else bit + ellipsis_gap - 1
996+
unqueeze_at.append(effective_bit)
997+
begin_mask |= 1 << bit
998+
end_mask |= 1 << bit
999+
9771000
input_x = GraphBuilder(ctx).make_unsqueeze(
9781001
{'data': input_x, 'axes': unqueeze_at})
9791002

0 commit comments

Comments
 (0)