Skip to content

Commit 58958f2

Browse files
committed
Enforce scalar inputs for Range function
1 parent aa412b6 commit 58958f2

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

tf2onnx/onnx_opset/tensor.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1859,6 +1859,7 @@ def mkconsts(values, dtype=np.int64):
18591859
const_zero_float, const_neg_one_float = mkconsts([[0], [-1]], np.float32)
18601860
const_zero, const_one, const_neg_one, const_neg_two, const_pad_vals, const_t = \
18611861
mkconsts([[0], [1], [-1], [-2], pads, [-1, 1]])
1862+
const_zero_scalar, const_one_scalar, const_neg_one_scalar = mkconsts([0, 1, -1])
18621863

18631864
m_shape = ctx.make_node('Shape', [node.input[0]]).output[0]
18641865
xlen = ctx.make_node('Gather', [m_shape, const_neg_one]).output[0]
@@ -1882,24 +1883,24 @@ def mkconsts(values, dtype=np.int64):
18821883
input1 = ctx.make_node('Cast', [node.input[1]], attr={'to': TensorProto.INT64})
18831884
k0 = ctx.make_node('ReduceMin', [input1.output[0]]).output[0]
18841885
k1 = ctx.make_node('ReduceMax', [input1.output[0]]).output[0]
1886+
k0_scalar = ctx.make_node('Squeeze', [k0]).output[0]
18851887
k1_scalar = ctx.make_node('Squeeze', [k1]).output[0]
18861888
m_padded = ctx.make_node('Pad', [m, const_pad_vals, node.input[2]])
18871889

18881890
# starting indexes for super diagonals
1889-
xstart_0 = ctx.make_node('Cast', [k0], attr={'to': TensorProto.FLOAT})
1891+
xstart_0 = ctx.make_node('Cast', [k0_scalar], attr={'to': TensorProto.FLOAT})
18901892
xstart_1 = ctx.make_node('Max', [const_zero_float, xstart_0.output[0]])
18911893
xstart_2 = ctx.make_node('Cast', [xstart_1.output[0]], attr={'to': TensorProto.INT64})
1892-
xstart_3 = ctx.make_node('Add', [xstart_2.output[0], const_neg_one])
1893-
xstart_4 = ctx.make_node('Range', [k1_scalar, xstart_3.output[0], const_neg_one])
1894+
xstart_3 = ctx.make_node('Add', [xstart_2.output[0], const_neg_one_scalar])
1895+
xstart_4 = ctx.make_node('Range', [k1_scalar, xstart_3.output[0], const_neg_one_scalar])
18941896
xstart = ctx.make_node('Reshape', [xstart_4.output[0], const_t])
18951897

18961898
# starting indexes for sub diagonals
1897-
ystart_0 = ctx.make_node('Cast', [k1], attr={'to': TensorProto.FLOAT})
1899+
ystart_0 = ctx.make_node('Cast', [k1_scalar], attr={'to': TensorProto.FLOAT})
18981900
ystart_1 = ctx.make_node('Min', [const_neg_one_float, ystart_0.output[0]])
18991901
ystart_2 = ctx.make_node('Cast', [ystart_1.output[0]], attr={'to': TensorProto.INT64})
1900-
ystart_2_scalar = ctx.make_node('Squeeze', [ystart_2.output[0]])
1901-
ystart_3 = ctx.make_node('Add', [k0, const_neg_one])
1902-
ystart_4 = ctx.make_node('Range', [ystart_2_scalar.output[0], ystart_3.output[0], const_neg_one])
1902+
ystart_3 = ctx.make_node('Add', [k0_scalar, const_neg_one_scalar])
1903+
ystart_4 = ctx.make_node('Range', [ystart_2.output[0], ystart_3.output[0], const_neg_one_scalar])
19031904
ystart = ctx.make_node('Reshape', [ystart_4.output[0], const_t])
19041905

19051906
xmax_0 = ctx.make_node('Mul', [xstart.output[0], xlenp])
@@ -1920,7 +1921,7 @@ def mkconsts(values, dtype=np.int64):
19201921
maxsize_0 = ctx.make_node('Reshape', [maxsize.output[0], const_neg_one])
19211922
maxsize_scalar = ctx.make_node('Squeeze', [maxsize.output[0]])
19221923

1923-
diagdistances_0 = ctx.make_node('Range', [const_zero, maxsize_scalar.output[0], const_one])
1924+
diagdistances_0 = ctx.make_node('Range', [const_zero_scalar, maxsize_scalar.output[0], const_one_scalar])
19241925
diagdistances = ctx.make_node('Mul', [diagdistances_0.output[0], stride])
19251926

19261927
def right_align(sizes, indices, starts, maxval):
@@ -1976,7 +1977,7 @@ def compute_out_shape(k0_k1_same=False):
19761977
if_node.set_body_graph_as_attr('then_branch', compute_out_shape(True))
19771978
if_node.set_body_graph_as_attr('else_branch', compute_out_shape(False))
19781979

1979-
shapes = [-1] * m_rank
1980+
shapes = ctx.get_shape(node.output[0])
19801981
dtypes = node.output_dtypes
19811982
ctx.remove_node(node.name)
19821983
ctx.make_node('Reshape', [diags.output[0], if_node.output[0]], name=node.name, outputs=node.output,

0 commit comments

Comments
 (0)