Skip to content

Commit d3be284

Browse files
committed
return node
1 parent 17f7d72 commit d3be284

File tree

2 files changed

+22
-24
lines changed

2 files changed

+22
-24
lines changed

tf2onnx/graph.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,6 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
405405
if target is None:
406406
target = []
407407
self._nodes = []
408-
self._consts = {}
409408
self._nodes_by_name = {}
410409
self._output_to_node_name = {}
411410
self.shapes = {}
@@ -490,13 +489,7 @@ def make_consts(self, values, np_type=np.int64, skip_conversion=False, raw=True)
490489
consts = []
491490
for value in values:
492491
np_val = np.array(value).astype(np_type)
493-
key = str(np_val) + "_" + str(np_val.dtype)
494-
if key in self._consts:
495-
consts.append(self._consts[key])
496-
else:
497-
const_node = self.make_const(utils.make_name("const"), np_val, skip_conversion, raw)
498-
self._consts[key] = const_node.output[0]
499-
consts.append(const_node.output[0])
492+
consts.append(self.make_const(utils.make_name("const"), np_val, skip_conversion, raw))
500493
return consts
501494

502495
def make_const(self, name, np_val, skip_conversion=False, raw=True):

tf2onnx/onnx_opset/tensor.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1267,8 +1267,8 @@ def mknode(optype, inputs, attrs=None):
12671267
# const vals
12681268
int_max_const, one_const, minus1_const, blocklen_resize_const, \
12691269
blocklenplus1_const, block_shape_const = \
1270-
ctx.make_consts([[utils.get_max_value(np.int64)], [1], [-1],\
1271-
[-1, blocklen], [blocklen + 1], block_shape])
1270+
[n.output[0] for n in ctx.make_consts([[utils.get_max_value(np.int64)], [1], [-1],\
1271+
[-1, blocklen], [blocklen + 1], block_shape])]
12721272

12731273
x_shape = ctx.insert_new_node_on_input(node, 'Shape', node.input[0])
12741274

@@ -1299,18 +1299,18 @@ def mknode(optype, inputs, attrs=None):
12991299
p[i] = p[i - 2] + 1
13001300

13011301
# reshape to create moving blocks, shuffle, and reshape to target_spatial
1302-
indices = ctx.make_consts([list(g)])[0]
1302+
indices = ctx.make_consts([list(g)])[0].output[0]
13031303
gather = mknode('Gather', [shape1.output[0], indices])
13041304
x2 = mknode('Reshape', [input0, gather.output[0]])
13051305
tr2 = mknode('Transpose', [x2.output[0]], {'perm': np.array(p)})
13061306
shape2 = mknode('Concat', [minus1_const, target_spatial.output[0], depth.output[0]], {'axis': 0})
13071307
x3 = mknode('Reshape', [tr2.output[0], shape2.output[0]])
13081308

13091309
# crop axes
1310-
slice_starts_const1, slice_starts_const2, slice_ends_const1,\
1310+
slice_starts_const1, slice_starts_const2, slice_ends_const1, \
13111311
slice_ends_const2, axes_const = \
1312-
ctx.make_consts([[0, 0], [1, utils.get_max_value(np.int64)], [1, 0],\
1313-
[2, utils.get_max_value(np.int64)], range(1, blocklen + 1)])
1312+
[n.output[0] for n in ctx.make_consts([[0, 0], [1, utils.get_max_value(np.int64)], [1, 0],\
1313+
[2, utils.get_max_value(np.int64)], range(1, blocklen + 1)])]
13141314

13151315
crop = mknode('Cast', [input2], {'to': TensorProto.INT64})
13161316
crop_transposed = mknode('Transpose', [crop.output[0]])
@@ -1388,8 +1388,9 @@ def mknode(optype, inputs, attrs=None):
13881388
# const vals
13891389
int_max_const, zero_const, one_const, minus1_const, blocklen_resize_const, \
13901390
blocklenplus1_const, filltop_const, fillbottom_const, block_shape_const = \
1391-
ctx.make_consts([[utils.get_max_value(np.int64)], [0], [1], [-1], [-1, blocklen], \
1392-
[blocklen + 1], [1, 0, 0, 0], [0, 0, 1, 0], block_shape])
1391+
[n.output[0] for n in ctx.make_consts([[utils.get_max_value(np.int64)], [0], [1],\
1392+
[-1], [-1, blocklen], [blocklen + 1],\
1393+
[1, 0, 0, 0], [0, 0, 1, 0], block_shape])]
13931394

13941395
x_shape = ctx.insert_new_node_on_input(node, 'Shape', node.input[0])
13951396
x_rank = mknode('Size', [x_shape.output[0]])
@@ -1768,8 +1769,8 @@ class MatrixDiagPart:
17681769
def version_11(cls, ctx, node, **kwargs):
17691770
# MatrixDiagPart by slice and gather
17701771
minus_two_one, minus_two, minus_one, zeo, zeo_zeo, one, two, two_one = \
1771-
ctx.make_consts([[-2, -1], [-2], [-1], [0], [0, 0], [1], [2], [2, 1]])
1772-
zeo_, one_ = ctx.make_consts([0, 1])
1772+
[n.output[0] for n in ctx.make_consts([[-2, -1], [-2], [-1], [0], [0, 0], [1], [2], [2, 1]])]
1773+
zeo_, one_ = [n.output[0] for n in ctx.make_consts([0, 1])]
17731774

17741775
input_shape = ctx.make_node('Shape', [node.input[0]])
17751776
input_shape_size = ctx.make_node('Shape', [input_shape.output[0]])
@@ -1807,7 +1808,8 @@ class MatrixDiagPartV2V3:
18071808
@classmethod
18081809
def version_11(cls, ctx, node, **kwargs):
18091810
# assemble MatrixDiagPart V2&V3 by looping k diagonals with proper pads
1810-
minus_two, minus_one, zeo, one, two = ctx.make_consts([[-2], [-1], [0], [1], [2]])
1811+
minus_two, minus_one, zeo, one, two = \
1812+
[n.output[0] for n in ctx.make_consts([[-2], [-1], [0], [1], [2]])]
18111813

18121814
def normalize():
18131815
raw_k = ctx.make_node('Cast', [node.input[1]], attr={'to': TensorProto.INT64}).output[0]
@@ -2041,10 +2043,11 @@ def version_12(cls, ctx, node, **kwargs):
20412043
xalign, yalign = align.split('_')
20422044

20432045
# consts
2044-
const_zero_float, const_neg_one_float = ctx.make_consts([0, -1], np.float32)
2046+
const_zero_float, const_neg_one_float = [n.output[0] for n in ctx.make_consts([0, -1], np.float32)]
20452047
const_zero, const_one, const_neg_one, const_neg_two, const_pad_vals, const_t = \
2046-
ctx.make_consts([[0], [1], [-1], [-2], pads, [-1, 1]])
2047-
const_zero_scalar, const_one_scalar, const_neg_one_scalar = ctx.make_consts([0, 1, -1])
2048+
[n.output[0] for n in ctx.make_consts([[0], [1], [-1], [-2], pads, [-1, 1]])]
2049+
const_zero_scalar, const_one_scalar, const_neg_one_scalar = \
2050+
[n.output[0] for n in ctx.make_consts([0, 1, -1])]
20482051

20492052
m_shape = ctx.make_node('Shape', [node.input[0]]).output[0]
20502053
xlen = ctx.make_node('Gather', [m_shape, const_neg_one]).output[0]
@@ -2184,7 +2187,8 @@ def version_12(cls, ctx, node, **kwargs):
21842187
# Assemble MatrixDiagV3 by ReverseSequence
21852188
argc = len(node.input)
21862189

2187-
minus_two, minus_one, zeo, one, two = ctx.make_consts([[-2], [-1], [0], [1], [2]])
2190+
minus_two, minus_one, zeo, one, two = \
2191+
[n.output[0] for n in ctx.make_consts([[-2], [-1], [0], [1], [2]])]
21882192

21892193
def mknode(op, args, **kwargs):
21902194
return ctx.make_node(op, args, **kwargs).output[0]
@@ -2512,7 +2516,8 @@ class MatrixSetDiagV3:
25122516
def version_12(cls, ctx, node, **kwargs):
25132517
# Assemble MatrixSetDiagV3 by MatrixDiagPartV3 and MatrixDiagV3
25142518

2515-
minus_two, minus_one, zeo, one = ctx.make_consts([[-2], [-1], [0], [1]])
2519+
minus_two, minus_one, zeo, one = \
2520+
[n.output[0] for n in ctx.make_consts([[-2], [-1], [0], [1]])]
25162521

25172522
def mknode(op, args, **kwargs):
25182523
return ctx.make_node(op, args, **kwargs).output[0]

0 commit comments

Comments
 (0)