Skip to content

Commit bf33b6a

Browse files
committed
support domain when make node
1 parent 4354a03 commit bf33b6a

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

tf2onnx/graph.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,8 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
378378

379379
new_outputs = [o if o != output else new_output_name for output in n.output]
380380
new_node = self.make_node(n.type, n.input, outputs=new_outputs, attr=n.attr, name=n.name,
381-
skip_conversion=n._skip_conversion, dtypes=n_dtypes, shapes=n_shapes)
381+
skip_conversion=n._skip_conversion, dtypes=n_dtypes, shapes=n_shapes,
382+
domain=n.domain)
382383

383384
if body_graphs:
384385
for attr_name, body_graph in body_graphs.items():
@@ -423,7 +424,7 @@ def make_const(self, name, np_val, skip_conversion=False, raw=True):
423424
return node
424425

425426
def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, skip_conversion=True,
426-
op_name_scope=None, name=None, shapes=None, dtypes=None):
427+
op_name_scope=None, name=None, shapes=None, dtypes=None, domain=None):
427428
"""Make a new onnx node in the graph"""
428429
if attr is None:
429430
attr = {}
@@ -456,7 +457,7 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk
456457
n = self.get_node_by_output_in_current_graph(o)
457458
utils.make_sure(n is None, "output tensor named %s already exists in node: \n%s", o, n)
458459

459-
onnx_node = helper.make_node(op_type, inputs, outputs, name=name, **raw_attr)
460+
onnx_node = helper.make_node(op_type, inputs, outputs, name=name, domain=domain, **raw_attr)
460461

461462
if op_type in ["If", "Loop", "Scan"]:
462463
# we force the op containing inner graphs not skipped during conversion.
@@ -883,7 +884,7 @@ def remove_input(node, to_be_removed):
883884
# don't remove output from parent since others might depend on it
884885
return True
885886

886-
def insert_new_node_on_input(self, node, op_type, input_name, name=None, **kwargs):
887+
def insert_new_node_on_input(self, node, op_type, input_name, name=None, domain=None, **kwargs):
887888
"""Create and insert a new node into the graph.
888889
Args:
889890
node: we want to replace the input for this node
@@ -898,14 +899,14 @@ def insert_new_node_on_input(self, node, op_type, input_name, name=None, **kwarg
898899
if name is None:
899900
name = utils.make_name(node.name)
900901
new_output = port_name(name)
901-
new_node = self.make_node(op_type, [input_name], attr=kwargs, outputs=[new_output], name=name)
902+
new_node = self.make_node(op_type, [input_name], attr=kwargs, outputs=[new_output], name=name, domain=domain)
902903
for i, n in enumerate(node.input):
903904
if n == input_name:
904905
node.input[i] = new_output
905906
break
906907
return new_node
907908

908-
def insert_new_node_on_output(self, op_type, output_name, name=None, **kwargs):
909+
def insert_new_node_on_output(self, op_type, output_name, name=None, domain=None, **kwargs):
909910
"""Create and insert a new node into the graph.
910911
Args:
911912
op_type: type for new operation
@@ -922,7 +923,7 @@ def insert_new_node_on_output(self, op_type, output_name, name=None, **kwargs):
922923
type(op_type))
923924

924925
new_output = port_name(name)
925-
new_node = self.make_node(op_type, [output_name], attr=kwargs, outputs=[new_output], name=name)
926+
new_node = self.make_node(op_type, [output_name], attr=kwargs, outputs=[new_output], name=name, domain=domain)
926927

927928
to_replace = [n for n in self.get_nodes() if n != new_node]
928929
self.replace_all_inputs(to_replace, output_name, new_output)

0 commit comments

Comments
 (0)