Skip to content

Commit bb716e5

Browse files
committed
move mkconst to graph
1 parent 9c8e9e1 commit bb716e5

File tree

2 files changed

+18
-9
lines changed

2 files changed

+18
-9
lines changed

tf2onnx/graph.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,7 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
405405
if target is None:
406406
target = []
407407
self._nodes = []
408+
self._consts = {}
408409
self._nodes_by_name = {}
409410
self._output_to_node_name = {}
410411
self.shapes = {}
@@ -484,6 +485,14 @@ def inputs(self):
484485
all_inputs.append(n)
485486
return all_inputs
486487

488+
def make_consts(self, values, np_type=np.int64, skip_conversion=False, raw=True):
489+
"""create list of consts of same type"""
490+
consts = []
491+
for value in values:
492+
np_val = np.array(value).astype(np_type)
493+
consts.append(self.make_const(utils.make_name("const"), np_val, skip_conversion, raw).output[0])
494+
return consts
495+
487496
def make_const(self, name, np_val, skip_conversion=False, raw=True):
488497
"""Make a new constant in the graph.
489498
Args:
@@ -492,6 +501,11 @@ def make_const(self, name, np_val, skip_conversion=False, raw=True):
492501
skip_conversion: bool, indicate whether this created node would be mapped during conversion.
493502
raw: whether to store data at field of raw_data or the specific field according to its dtype
494503
"""
504+
505+
key = str(np_val) + "_" + str(np_val.dtype)
506+
if key in self._consts:
507+
return self._consts[key]
508+
495509
if raw:
496510
onnx_tensor = numpy_helper.from_array(np_val, name)
497511
else:
@@ -500,6 +514,8 @@ def make_const(self, name, np_val, skip_conversion=False, raw=True):
500514
dtype = onnx_tensor.data_type
501515
node = self.make_node("Const", [], outputs=[name], name=name, attr={"value": onnx_tensor},
502516
skip_conversion=skip_conversion, dtypes=[dtype], infer_shape_dtype=False)
517+
518+
self._consts[key] = node
503519
self.set_shape(name, np_val.shape)
504520
self.set_dtype(name, utils.map_numpy_to_onnx_dtype(np_val.dtype))
505521
return node

tf2onnx/onnx_opset/tensor.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2223,11 +2223,7 @@ def version_12(cls, ctx, node, **kwargs):
22232223
# Assemble MatrixDiagV3 by ReverseSequence
22242224
argc = len(node.input)
22252225

2226-
def mkconsts(values):
2227-
return [ctx.make_const(utils.make_name('const'), \
2228-
np.array(value).astype(np.int64)).output[0] for value in values]
2229-
2230-
minus_two, minus_one, zeo, one, two = mkconsts([[-2], [-1], [0], [1], [2]])
2226+
minus_two, minus_one, zeo, one, two = ctx.make_consts([[-2], [-1], [0], [1], [2]])
22312227

22322228
def mknode(op, args, **kwargs):
22332229
return ctx.make_node(op, args, **kwargs).output[0]
@@ -2554,11 +2550,8 @@ class MatrixSetDiagV3:
25542550
@classmethod
25552551
def version_12(cls, ctx, node, **kwargs):
25562552
# Assemble MatrixSetDiagV3 by MatrixDiagPartV3 and MatrixDiagV3
2557-
def mkconsts(values):
2558-
return [ctx.make_const(utils.make_name('const'), \
2559-
np.array(value).astype(np.int64)).output[0] for value in values]
25602553

2561-
minus_two, minus_one, zeo, one = mkconsts([[-2], [-1], [0], [1]])
2554+
minus_two, minus_one, zeo, one = ctx.make_consts([[-2], [-1], [0], [1]])
25622555

25632556
def mknode(op, args, **kwargs):
25642557
return ctx.make_node(op, args, **kwargs).output[0]

0 commit comments

Comments
 (0)