Skip to content

Commit 7d622fe

Browse files
committed
normalize k
1 parent 7127c43 commit 7d622fe

File tree

1 file changed

+16
-9
lines changed

1 file changed

+16
-9
lines changed

tf2onnx/onnx_opset/tensor.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1835,8 +1835,16 @@ class MatrixDiagPartV2V3:
18351835
@classmethod
18361836
def version_11(cls, ctx, node, **kwargs):
18371837
# assemble MatrixDiagPart V2&V3 by looping k diagonals with proper pads
1838+
const_zero = ctx.make_const(utils.make_name(node.name) + 'const_zero', np.array([0]).astype(np.int64))
1839+
const_one = ctx.make_const(utils.make_name(node.name) + 'const_one', np.array([1]).astype(np.int64))
1840+
const_two = ctx.make_const(utils.make_name(node.name) + 'const_two', np.array([2]).astype(np.int64))
1841+
const_neg_one = ctx.make_const(utils.make_name(node.name) + 'const_neg_one', np.array([-1]).astype(np.int64))
1842+
const_neg_two = ctx.make_const(utils.make_name(node.name) + 'const_neg_two', np.array([-2]).astype(np.int64))
1843+
def normalize():
1844+
raw_k = ctx.make_node('Cast', [node.input[1]], attr={'to': TensorProto.INT64}).output[0]
1845+
return ctx.make_node('Reshape', [raw_k, const_neg_one.output[0]]).output[0]
18381846
input_tensor = node.input[0]
1839-
k = ctx.make_node('Cast', [node.input[1]], attr={'to': TensorProto.INT64}).output[0]
1847+
k = normalize()
18401848
padding = node.input[2]
18411849
align = 'LEFT_LEFT'
18421850
if node.op.op_type == 'MatrixDiagPartV3':
@@ -1850,12 +1858,7 @@ def version_11(cls, ctx, node, **kwargs):
18501858
for out in ctx.find_output_consumers(node.output[0]):
18511859
if out.op.op_type == 'Identity':
18521860
ctx.set_shape(out.output[0], raw_output_shape)
1853-
# define constants
1854-
const_zero = ctx.make_const(utils.make_name(node.name) + 'const_zero', np.array([0]).astype(np.int64))
1855-
const_one = ctx.make_const(utils.make_name(node.name) + 'const_one', np.array([1]).astype(np.int64))
1856-
const_two = ctx.make_const(utils.make_name(node.name) + 'const_two', np.array([2]).astype(np.int64))
1857-
const_neg_one = ctx.make_const(utils.make_name(node.name) + 'const_neg_one', np.array([-1]).astype(np.int64))
1858-
const_neg_two = ctx.make_const(utils.make_name(node.name) + 'const_neg_two', np.array([-2]).astype(np.int64))
1861+
18591862
# prepare new_shape of input
18601863
input_shape = ctx.make_node('Shape', [input_tensor])
18611864
shape_input_shape = ctx.make_node('Shape', [input_shape.output[0]])
@@ -2075,7 +2078,7 @@ def mkconsts(values, dtype=np.int64):
20752078
xalign, yalign = align.split('_')
20762079

20772080
# consts
2078-
const_zero_float, const_neg_one_float = mkconsts([[0], [-1]], np.float32)
2081+
const_zero_float, const_neg_one_float = mkconsts([0, -1], np.float32)
20792082
const_zero, const_one, const_neg_one, const_neg_two, const_pad_vals, const_t = \
20802083
mkconsts([[0], [1], [-1], [-2], pads, [-1, 1]])
20812084
const_zero_scalar, const_one_scalar, const_neg_one_scalar = mkconsts([0, 1, -1])
@@ -2098,8 +2101,12 @@ def mkconsts(values, dtype=np.int64):
20982101
m2_shape = ctx.make_node('Concat', [partial_shape, const_neg_one], attr={'axis': 0}).output[0]
20992102
gather_shape = ctx.make_node('Concat', [partial_shape, const_one], attr={'axis': 0}).output[0]
21002103

2104+
def normalize():
2105+
raw_input1 = ctx.make_node('Cast', [node.input[1]], attr={'to': TensorProto.INT64}).output[0]
2106+
return ctx.make_node('Reshape', [raw_input1, const_neg_one])
2107+
21012108
# get k0, k1 values. diags to be extracted
2102-
input1 = ctx.make_node('Cast', [node.input[1]], attr={'to': TensorProto.INT64})
2109+
input1 = normalize()
21032110
k0 = ctx.make_node('ReduceMin', [input1.output[0]]).output[0]
21042111
k1 = ctx.make_node('ReduceMax', [input1.output[0]]).output[0]
21052112
k0_scalar = ctx.make_node('Squeeze', [k0]).output[0]

0 commit comments

Comments
 (0)