@@ -425,6 +425,8 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
425
425
self ._nodes = []
426
426
self ._nodes_by_name = {}
427
427
self ._output_to_node_name = {}
428
+ self ._input_to_node_name = {}
429
+ self ._input_to_graph = {}
428
430
self .shapes = {}
429
431
self .graph_name = graph_name or "tf2onnx"
430
432
self ._is_subgraph = is_subgraph
@@ -576,6 +578,9 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk
576
578
577
579
onnx_node = helper .make_node (op_type , inputs , outputs , name = name , domain = domain , ** raw_attr )
578
580
581
+ for name2 in onnx_node .input :
582
+ self ._register_input_name (name2 , onnx_node )
583
+
579
584
if op_type in ["If" , "Loop" , "Scan" ]:
580
585
# we force the op containing inner graphs not skipped during conversion.
581
586
skip_conversion = False
@@ -613,6 +618,8 @@ def append_node(self, node):
613
618
self ._output_to_node_name [name ] = node .name
614
619
self .set_dtype (name , output_dtypes [i ])
615
620
self .set_shape (name , output_shapes [i ])
621
+ for name in node .input :
622
+ self ._register_input_name (name , node )
616
623
617
624
def remove_node (self , node_name ):
618
625
"""Remove node in current graph."""
@@ -633,6 +640,12 @@ def remove_node(self, node_name):
633
640
if op_output in self ._dtypes :
634
641
del self ._dtypes [op_output ]
635
642
643
+ for op_input in node .input :
644
+ if op_input not in self ._input_to_node_name :
645
+ raise RuntimeError (
646
+ "Input %r of node %r not found." % (op_input , node_name ))
647
+ self ._unregister_input_name (op_input , node )
648
+
636
649
self ._nodes .remove (node )
637
650
node .graph = None
638
651
@@ -656,16 +669,32 @@ def reset_nodes(self, ops):
656
669
self .contained_graphs = remained_sub_graphs
657
670
self ._nodes_by_name = {op .name : op for op in ops }
658
671
self ._output_to_node_name = {}
672
+ self ._input_to_node_name = {}
659
673
for op in ops :
660
674
for op_output in op .output :
661
675
self ._output_to_node_name [op_output ] = op .name
676
+ if op .type == 'Placeholder' :
677
+ inps = [op .name ]
678
+ elif op .type == 'Const' :
679
+ inps = [op .name ]
680
+ else :
681
+ inps = op .input
682
+ for op_input in inps :
683
+ self ._register_input_name (op_input , op )
662
684
663
685
for n in self ._order_sensitive_inputs :
664
686
if n not in ops :
665
687
self ._order_sensitive_inputs .remove (n )
666
688
for o in self .outputs :
667
689
if o not in self ._output_to_node_name :
668
690
raise ValueError ("graph output " + o + " not exist" )
691
+ for i in self .inputs :
692
+ if i .name .startswith ('Placeholder' ):
693
+ continue
694
+ if i .name .startswith ('keras_learning_phase' ):
695
+ continue
696
+ if i .name not in self ._input_to_node_name :
697
+ raise ValueError ("graph input %r not exist in graph." % i .name )
669
698
670
699
self ._dtypes = remained_dtypes
671
700
self ._output_shapes = remained_shapes
@@ -782,6 +811,14 @@ def get_node_by_output_in_current_graph(self, output):
782
811
ret = self ._nodes_by_name .get (name )
783
812
return ret
784
813
814
+ def get_node_by_input_in_current_graph (self , input_name ):
815
+ """Get nodes by node input id."""
816
+ names = self ._output_to_node_name .get (input_name )
817
+ ret = None
818
+ if name :
819
+ ret = [self ._nodes_by_name .get (name ) for name in names ]
820
+ return ret
821
+
785
822
def get_node_by_name (self , name ):
786
823
"""Get node by name."""
787
824
ret = self ._nodes_by_name .get (name )
@@ -792,6 +829,8 @@ def set_node_by_name(self, node):
792
829
self ._nodes_by_name [node .name ] = node
793
830
for op_output in node .output :
794
831
self ._output_to_node_name [op_output ] = node .name
832
+ for name in node .input :
833
+ self ._register_input_name (name , node )
795
834
796
835
def change_node_name (self , node , new_name ):
797
836
"""Remove node in current graph."""
@@ -1138,8 +1177,7 @@ def dump_node_statistics(self):
1138
1177
1139
1178
return op_cnt
1140
1179
1141
- @staticmethod
1142
- def remove_input (node , to_be_removed , i = None ):
1180
+ def remove_input (self , node , to_be_removed , i = None ):
1143
1181
"""Remove input from Node.
1144
1182
Args:
1145
1183
node: the node we expect the input on
@@ -1149,11 +1187,16 @@ def remove_input(node, to_be_removed, i=None):
1149
1187
assert isinstance (node , Node ) and isinstance (to_be_removed , six .text_type )
1150
1188
if i is not None :
1151
1189
assert node .input [i ] == to_be_removed
1190
+ if node .input [i ] in self ._input_to_node_name :
1191
+ to_ops = self ._input_to_node_name [node .input [i ]]
1192
+ if node .name in to_ops :
1193
+ to_ops .remove (node .name )
1152
1194
del node .input [i ]
1153
1195
return True
1154
1196
1155
1197
for i2 , name in enumerate (node .input ):
1156
1198
if name == to_be_removed :
1199
+ self ._unregister_input_name (node .input [i2 ], node )
1157
1200
del node .input [i2 ]
1158
1201
break
1159
1202
# don't remove output from parent since others might depend on it
@@ -1205,43 +1248,90 @@ def insert_new_node_on_output(self, op_type, output_name, name, domain=None, **k
1205
1248
new_output = port_name (name )
1206
1249
new_node = self .make_node (op_type , [output_name ], attr = kwargs , outputs = [new_output ], name = name , domain = domain )
1207
1250
1208
- to_replace = [n for n in self .get_nodes () if n != new_node ]
1251
+ to_replace = [self .get_node_by_name (n ) for n in self ._input_to_node_name [output_name ]]
1252
+ to_replace = [n for n in to_replace if n != new_node ]
1209
1253
self .replace_all_inputs (to_replace , output_name , new_output )
1210
1254
return new_node
1211
1255
1212
1256
def find_output_consumers (self , output_name ):
1213
1257
"""Find all nodes consuming a given output."""
1258
+ if output_name in self ._input_to_node_name :
1259
+ ops = self ._input_to_node_name [output_name ]
1260
+ ops = [self .get_node_by_name (n ) for n in ops ]
1261
+ else :
1262
+ ops = self .get_nodes ()
1214
1263
nodes = []
1215
- for node in self .get_nodes ():
1264
+ for node in ops :
1265
+ if node is None :
1266
+ continue
1216
1267
if output_name in node .input :
1217
1268
nodes .append (node )
1218
1269
1219
- # find consumers in sub graphs
1220
- body_graphs = node .get_body_graphs ()
1221
- if body_graphs :
1222
- for g in body_graphs .values ():
1223
- nodes .extend (g .find_output_consumers (output_name ))
1270
+ # find consumers in sub graphs
1271
+ if output_name in self ._input_to_graph :
1272
+ for _ , g in self ._input_to_graph [output_name ].items ():
1273
+ nodes .extend (g .find_output_consumers (output_name ))
1224
1274
return nodes
1225
1275
1226
- @staticmethod
1227
- def replace_all_inputs (ops , old_input , new_input ):
1228
- """Replace all inputs pointing to old_input with new_input."""
1276
+ def _register_input_name (self , input_name , node , only_graph = False ):
1277
+ "Register node taking a specific input."
1278
+ if not only_graph :
1279
+ if input_name not in self ._input_to_node_name :
1280
+ self ._input_to_node_name [input_name ] = set ()
1281
+ self ._input_to_node_name [input_name ].add (node .name )
1282
+ if self .parent_graph is not None :
1283
+ if input_name not in self .parent_graph ._input_to_graph :
1284
+ self .parent_graph ._input_to_graph [input_name ] = {}
1285
+ self .parent_graph ._input_to_graph [input_name ][id (self )] = self
1286
+ self .parent_graph ._register_input_name (input_name , node , only_graph = True )
1287
+
1288
+ def _unregister_input_name (self , input_name , node , only_graph = False ):
1289
+ "Unregister node taking a specific input."
1290
+ node_name = node .name
1291
+ if not only_graph :
1292
+ if node_name in self ._input_to_node_name [input_name ]:
1293
+ if node_name in self ._input_to_node_name [input_name ]:
1294
+ self ._input_to_node_name [input_name ].remove (node_name )
1295
+ if (self .parent_graph is not None and
1296
+ input_name in self .parent_graph ._input_to_graph and
1297
+ id (self ) in self .parent_graph ._input_to_graph [input_name ]):
1298
+ del self .parent_graph ._input_to_graph [input_name ][id (self )]
1299
+ self .parent_graph ._unregister_input_name (input_name , node , only_graph = True )
1300
+
1301
+ def replace_all_inputs (self , ops , old_input , new_input ):
1302
+ """
1303
+ Replace all inputs pointing to old_input with new_input.
1304
+ *ops* is used if defined, otherwise _input_to_node_name
1305
+ is used to determine the impacted nodes.
1306
+ """
1229
1307
if old_input == new_input :
1230
1308
return
1309
+ if new_input not in self ._input_to_node_name :
1310
+ self ._input_to_node_name [new_input ] = set ()
1311
+
1312
+ if ops is not None :
1313
+ keep_ops = True
1314
+ elif old_input in self ._input_to_node_name :
1315
+ ops = [self .get_node_by_name (n ) for n in self ._input_to_node_name [old_input ]]
1316
+ keep_ops = False
1317
+ else :
1318
+ ops = []
1319
+ keep_ops = False
1231
1320
1232
1321
for node in ops :
1322
+ assert node is not None
1233
1323
if old_input in node .input and new_input in node .output :
1234
1324
raise RuntimeError ("creating a circle in the graph is not allowed: " + node .name )
1325
+ self ._register_input_name (new_input , node )
1235
1326
1236
1327
for i , input_name in enumerate (node .input ):
1237
1328
if input_name == old_input :
1238
- node .input [i ] = new_input
1329
+ self . replace_input ( node , node .input [i ], new_input , i )
1239
1330
1240
- # modify references in sub graphs
1241
- body_graphs = node .get_body_graphs ()
1242
- if body_graphs :
1243
- for g in body_graphs .values ():
1244
- g .replace_all_inputs (g .get_nodes (), old_input , new_input )
1331
+ # modify references in sub graphs
1332
+ if old_input in self ._input_to_graph :
1333
+ for _ , g in self ._input_to_graph [old_input ].items ():
1334
+ g .replace_all_inputs (g .get_nodes () if keep_ops else None , old_input , new_input )
1245
1335
1246
1336
def replace_input (self , node , old_input , new_input , i = None ):
1247
1337
"""Replace one input in a node."""
@@ -1257,11 +1347,31 @@ def replace_input(self, node, old_input, new_input, i=None):
1257
1347
is_replaced = True
1258
1348
else :
1259
1349
raise RuntimeError ("Unable to replace input %r into %r for node %r." % (old_input , new_input , node .name ))
1350
+
1351
+ to_ops = self ._input_to_node_name .get (old_input , None )
1352
+ if to_ops is not None :
1353
+ if node .name in to_ops :
1354
+ # A node may take twice the same entry.
1355
+ to_ops .remove (node .name )
1356
+
1357
+ self ._register_input_name (new_input , node )
1260
1358
return is_replaced
1261
1359
1262
1360
def replace_inputs (self , node , new_inputs ):
1263
1361
"""Replace node inputs."""
1264
1362
assert isinstance (node , Node ) and isinstance (new_inputs , list )
1363
+
1364
+ for old_input in node .input :
1365
+ to_ops = self ._input_to_node_name .get (old_input , None )
1366
+ if to_ops is not None and old_input in to_ops :
1367
+ # To avoid issues when a node
1368
+ # takes twice the same entry.
1369
+ to_ops .remove (old_input )
1370
+
1371
+ for input_name in new_inputs :
1372
+ assert isinstance (input_name , six .text_type )
1373
+ self ._register_input_name (input_name , node )
1374
+
1265
1375
node .input = new_inputs
1266
1376
return True
1267
1377
0 commit comments