Skip to content

Commit a114a14

Browse files
committed
Ensure scalar values only in matrixdiagpart->Range() function
1 parent ac395f3 commit a114a14

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

tf2onnx/onnx_opset/tensor.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1772,9 +1772,12 @@ class MatrixDiagPart:
17721772
def version_11(cls, ctx, node, **kwargs):
17731773
# MatrixDiagPart by slice and gather
17741774
const_zero = ctx.make_const(utils.make_name(node.name) + 'const_zero', np.array([0]).astype(np.int64))
1775+
const_zero_ = ctx.make_const(utils.make_name(node.name) + 'const_zero_', np.array(0).astype(np.int64))
1776+
17751777
const_zero_zero = ctx.make_const(utils.make_name(node.name) + 'const_zero_zero',
17761778
np.array([0, 0]).astype(np.int64))
17771779
const_one = ctx.make_const(utils.make_name(node.name) + 'const_one', np.array([1]).astype(np.int64))
1780+
const_one_ = ctx.make_const(utils.make_name(node.name) + 'const_one_', np.array(1).astype(np.int64))
17781781
const_two = ctx.make_const(utils.make_name(node.name) + 'const_two', np.array([2]).astype(np.int64))
17791782
const_negative_one = ctx.make_const(utils.make_name(node.name) + 'const_negative_one',
17801783
np.array([-1]).astype(np.int64))
@@ -1802,7 +1805,9 @@ def version_11(cls, ctx, node, **kwargs):
18021805
const_negative_one.output[0]])
18031806
sliced_input_shape_new = ctx.make_node('Concat', [sliced_input_shape_half.output[0], const_one.output[0]],
18041807
attr={'axis': -1})
1805-
matrice_range = ctx.make_node('Range', [const_zero.output[0], min_matrice_dim.output[0], const_one.output[0]])
1808+
min_matrice_dim_ = ctx.make_node('Squeeze', [min_matrice_dim.output[0]], {'axes': [0]})
1809+
matrice_range = ctx.make_node('Range', [const_zero_.output[0], min_matrice_dim_.output[0],
1810+
const_one_.output[0]])
18061811
unsqueezed_matrice_range = ctx.make_node('Unsqueeze', [matrice_range.output[0]], attr={"axes": [-1]})
18071812
expanded_range = ctx.make_node('Expand', [unsqueezed_matrice_range.output[0], sliced_input_shape_new.output[0]])
18081813
gathered_result = ctx.make_node('GatherElements', [sliced_input.output[0], expanded_range.output[0]],
@@ -1893,6 +1898,8 @@ def version_11(cls, ctx, node, **kwargs):
18931898
new_width = body_graph.make_node('Slice', [processed_shape.output[0], const_neg_one.output[0],
18941899
shape_processed_shape.output[0]])
18951900
abs_k = body_graph.make_node('Abs', [current_k.output[0]])
1901+
1902+
18961903
range_k = body_graph.make_node('Range', [abs_k.output[0], new_width.output[0], const_one.output[0]],
18971904
domain="com.microsoft")
18981905
sliced_range = body_graph.make_node('Slice', [range_k.output[0], const_zero.output[0], new_depth.output[0]])

0 commit comments

Comments
 (0)