@@ -405,6 +405,7 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
405
405
if target is None :
406
406
target = []
407
407
self ._nodes = []
408
+ self ._consts = {}
408
409
self ._nodes_by_name = {}
409
410
self ._output_to_node_name = {}
410
411
self .shapes = {}
@@ -484,6 +485,14 @@ def inputs(self):
484
485
all_inputs .append (n )
485
486
return all_inputs
486
487
488
+ def make_consts (self , values , np_type = np .int64 , skip_conversion = False , raw = True ):
489
+ """create list of consts of same type"""
490
+ consts = []
491
+ for value in values :
492
+ np_val = np .array (value ).astype (np_type )
493
+ consts .append (self .make_const (utils .make_name ("const" ), np_val , skip_conversion , raw ).output [0 ])
494
+ return consts
495
+
487
496
def make_const (self , name , np_val , skip_conversion = False , raw = True ):
488
497
"""Make a new constant in the graph.
489
498
Args:
@@ -492,6 +501,11 @@ def make_const(self, name, np_val, skip_conversion=False, raw=True):
492
501
skip_conversion: bool, indicate whether this created node would be mapped during conversion.
493
502
raw: whether to store data at field of raw_data or the specific field according to its dtype
494
503
"""
504
+
505
+ key = str (np_val ) + "_" + str (np_val .dtype )
506
+ if key in self ._consts :
507
+ return self ._consts [key ]
508
+
495
509
if raw :
496
510
onnx_tensor = numpy_helper .from_array (np_val , name )
497
511
else :
@@ -500,6 +514,8 @@ def make_const(self, name, np_val, skip_conversion=False, raw=True):
500
514
dtype = onnx_tensor .data_type
501
515
node = self .make_node ("Const" , [], outputs = [name ], name = name , attr = {"value" : onnx_tensor },
502
516
skip_conversion = skip_conversion , dtypes = [dtype ], infer_shape_dtype = False )
517
+
518
+ self ._consts [key ] = node
503
519
self .set_shape (name , np_val .shape )
504
520
self .set_dtype (name , utils .map_numpy_to_onnx_dtype (np_val .dtype ))
505
521
return node
0 commit comments