Skip to content

Commit 8f38df5

Browse files
committed
refactor
1 parent 92c24fb commit 8f38df5

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

tf2onnx/graph.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,13 @@ def make_consts(self, values, np_type=np.int64, skip_conversion=False, raw=True)
490490
consts = []
491491
for value in values:
492492
np_val = np.array(value).astype(np_type)
493-
consts.append(self.make_const(utils.make_name("const"), np_val, skip_conversion, raw).output[0])
493+
key = str(np_val) + "_" + str(np_val.dtype)
494+
if key in self._consts:
495+
consts.append(self._consts[key])
496+
else:
497+
const_node = self.make_const(utils.make_name("const"), np_val, skip_conversion, raw)
498+
self._consts[key] = const_node.output[0]
499+
consts.append(const_node.output[0])
494500
return consts
495501

496502
def make_const(self, name, np_val, skip_conversion=False, raw=True):
@@ -501,11 +507,6 @@ def make_const(self, name, np_val, skip_conversion=False, raw=True):
501507
skip_conversion: bool, indicate whether this created node would be mapped during conversion.
502508
raw: whether to store data at field of raw_data or the specific field according to its dtype
503509
"""
504-
505-
key = str(np_val) + "_" + str(np_val.dtype)
506-
if key in self._consts:
507-
return self._consts[key]
508-
509510
if raw:
510511
onnx_tensor = numpy_helper.from_array(np_val, name)
511512
else:
@@ -514,8 +515,6 @@ def make_const(self, name, np_val, skip_conversion=False, raw=True):
514515
dtype = onnx_tensor.data_type
515516
node = self.make_node("Const", [], outputs=[name], name=name, attr={"value": onnx_tensor},
516517
skip_conversion=skip_conversion, dtypes=[dtype], infer_shape_dtype=False)
517-
518-
self._consts[key] = node
519518
self.set_shape(name, np_val.shape)
520519
self.set_dtype(name, utils.map_numpy_to_onnx_dtype(np_val.dtype))
521520
return node

0 commit comments

Comments
 (0)