Skip to content

Commit 7dfda8a

Browse files
committed
clean up some code
1 parent 6e85952 commit 7dfda8a

File tree

1 file changed

+17
-21
lines changed

1 file changed

+17
-21
lines changed

tf2onnx/onnx_opset/tensor.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -847,33 +847,32 @@ def version_10(cls, ctx, node, **kwargs):
847847
if not np.all(new_begin_mask == 1):
848848
if begin.is_const() and strides.is_const():
849849
new_begin_vals = np.copy(begin.get_tensor_value(as_list=False))
850-
for i, v in enumerate(strides.get_tensor_value(as_list=False)):
851-
if v < 0 and new_begin_mask[i] == 0:
852-
new_begin_vals[i] = max_size
850+
strides_vals = strides.get_tensor_value(as_list=False)
851+
idx1 = np.where(new_begin_mask == 0)
852+
idx2 = np.where(strides_vals < 0)
853+
idx3 = np.intersect1d(idx1, idx2)
854+
new_begin_vals[idx3] = max_size
853855
begin = ctx.make_const(utils.make_name("begin_masked"), new_begin_vals)
854856
else:
855857
begin_mask_const = ctx.make_const(utils.make_name("begin_mask"), np.equal(new_begin_mask, 0))
856858
zero_const = ctx.make_const(utils.make_name("zero_const"), np.zeros(1, dtype=np_dtype))
857859
max_const = ctx.make_const(utils.make_name("max_const"), np.array(max_size, dtype=np_dtype))
858-
is_reverse_steps = ctx.make_node("Less", [strides.output[0], zero_const.output[0]],
859-
op_name_scope=node.name)
860-
is_reverse_and_full_range = ctx.make_node("And", [is_reverse_steps.output[0],
861-
begin_mask_const.output[0]], op_name_scope=node.name)
862-
begin = ctx.make_node("Where", [is_reverse_and_full_range.output[0], max_const.output[0],
863-
begin.output[0]], op_name_scope=node.name)
860+
op1 = ctx.make_node("Less", [strides.output[0], zero_const.output[0]], op_name_scope=node.name)
861+
op2 = ctx.make_node("And", [op1.output[0], begin_mask_const.output[0]], op_name_scope=node.name)
862+
begin = ctx.make_node("Where", [op2.output[0], max_const.output[0], begin.output[0]],
863+
op_name_scope=node.name)
864864

865865
# mask end
866866
new_end_mask = np.array(new_end_mask, dtype=np_dtype)
867867
end_output = end.output[0]
868868
if not np.all(new_end_mask == min_size):
869-
if end.is_const() and strides.is_const() and False:
869+
if end.is_const() and strides.is_const():
870870
new_end_mask = np.maximum(end.get_tensor_value(as_list=False), new_end_mask)
871-
for i, v in enumerate(strides.get_tensor_value(as_list=False)):
872-
if new_end_mask[i] == max_size:
873-
new_end_mask[i] *= np.sign(v)
871+
idx = np.where(new_end_mask == max_size)
872+
sign = np.sign(strides.get_tensor_value(as_list=False))[idx]
873+
new_end_mask[idx] = new_end_mask[idx] * sign
874874
end = ctx.make_const(utils.make_name("end_masked"), new_end_mask)
875875
end_output = end.output[0]
876-
877876
else:
878877
# Overlay new_end_mask with specified end values.
879878
# Adjust max_size to min_size if steps are < 0
@@ -884,13 +883,10 @@ def version_10(cls, ctx, node, **kwargs):
884883
outputname = utils.make_name("{}__newendmask".format(node.name))
885884
new_end_mask = math.make_min_or_max_op(ctx, "Max", [end.output[0], end_mask_const.output[0]],
886885
[outputname])
887-
is_reverse_steps = ctx.make_node("Less", [strides.output[0], zero_const.output[0]],
888-
op_name_scope=node.name)
889-
is_full_range = ctx.make_node("Equal", [new_end_mask.output[0], max_const.output[0]],
890-
op_name_scope=node.name)
891-
is_reverse_and_full_range = ctx.make_node("And", [is_full_range.output[0], is_reverse_steps.output[0]],
892-
op_name_scope=node.name)
893-
final_end = ctx.make_node("Where", [is_reverse_and_full_range.output[0], min_const.output[0],
886+
op1 = ctx.make_node("Less", [strides.output[0], zero_const.output[0]], op_name_scope=node.name)
887+
op2 = ctx.make_node("Equal", [new_end_mask.output[0], max_const.output[0]], op_name_scope=node.name)
888+
op3 = ctx.make_node("And", [op2.output[0], op1.output[0]], op_name_scope=node.name)
889+
final_end = ctx.make_node("Where", [op3.output[0], min_const.output[0],
894890
new_end_mask.output[0]], op_name_scope=node.name)
895891
end_output = final_end.output[0]
896892

0 commit comments

Comments
 (0)