@@ -456,6 +456,8 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
456
456
self ._nodes = []
457
457
self ._nodes_by_name = {}
458
458
self ._output_to_node_name = {}
459
+ self ._output_to_consumers = {}
460
+ self ._input_to_graph = {}
459
461
self .shapes = {}
460
462
self .graph_name = graph_name or "tf2onnx"
461
463
self ._is_subgraph = is_subgraph
@@ -502,7 +504,7 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
502
504
body_graph .parent_graph = self
503
505
new_node .set_body_graph_as_attr (attr_name , body_graph )
504
506
505
- self .replace_all_inputs (self .get_nodes (), o , new_output_name )
507
+ self .replace_all_inputs (o , new_output_name , ops = self .get_nodes ())
506
508
self .make_node ("Identity" , [new_output_name ], outputs = [o ], op_name_scope = n .name + "_" + "graph_outputs" )
507
509
self .copy_shape (new_output_name , o )
508
510
self .copy_dtype (new_output_name , o )
@@ -607,6 +609,9 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk
607
609
608
610
onnx_node = helper .make_node (op_type , inputs , outputs , name = name , domain = domain , ** raw_attr )
609
611
612
+ for name2 in onnx_node .input :
613
+ self ._register_input_name (name2 , onnx_node )
614
+
610
615
if op_type in ["If" , "Loop" , "Scan" ]:
611
616
# we force the op containing inner graphs not skipped during conversion.
612
617
skip_conversion = False
@@ -635,6 +640,7 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk
635
640
return node
636
641
637
642
def append_node (self , node ):
643
+ "Add a node to the graph."
638
644
output_shapes = node .output_shapes
639
645
output_dtypes = node .output_dtypes
640
646
node .graph = self
@@ -644,6 +650,8 @@ def append_node(self, node):
644
650
self ._output_to_node_name [name ] = node .name
645
651
self .set_dtype (name , output_dtypes [i ])
646
652
self .set_shape (name , output_shapes [i ])
653
+ for name in node .input :
654
+ self ._register_input_name (name , node )
647
655
648
656
def remove_node (self , node_name ):
649
657
"""Remove node in current graph."""
@@ -664,6 +672,12 @@ def remove_node(self, node_name):
664
672
if op_output in self ._dtypes :
665
673
del self ._dtypes [op_output ]
666
674
675
+ for op_input in node .input :
676
+ utils .make_sure (
677
+ op_input in self ._output_to_consumers ,
678
+ "Input %r of node %r not found." , op_input , node_name )
679
+ self ._unregister_input_name (op_input , node )
680
+
667
681
self ._nodes .remove (node )
668
682
node .graph = None
669
683
@@ -687,9 +701,13 @@ def reset_nodes(self, ops):
687
701
self .contained_graphs = remained_sub_graphs
688
702
self ._nodes_by_name = {op .name : op for op in ops }
689
703
self ._output_to_node_name = {}
704
+ self ._output_to_consumers = {}
690
705
for op in ops :
691
706
for op_output in op .output :
692
707
self ._output_to_node_name [op_output ] = op .name
708
+ inps = op .input
709
+ for op_input in inps :
710
+ self ._register_input_name (op_input , op )
693
711
694
712
for n in self ._order_sensitive_inputs :
695
713
if n not in ops :
@@ -823,6 +841,8 @@ def set_node_by_name(self, node):
823
841
self ._nodes_by_name [node .name ] = node
824
842
for op_output in node .output :
825
843
self ._output_to_node_name [op_output ] = node .name
844
+ for name in node .input :
845
+ self ._register_input_name (name , node )
826
846
827
847
def change_node_name (self , node , new_name ):
828
848
"""Remove node in current graph."""
@@ -838,7 +858,7 @@ def change_node_name(self, node, new_name):
838
858
if k == old_output :
839
859
self .outputs [j ] = new_output
840
860
break
841
- self .replace_all_inputs (self .get_nodes (), old_output , new_output )
861
+ self .replace_all_inputs (old_output , new_output , ops = self .get_nodes ())
842
862
return new_node
843
863
844
864
def add_graph_input (self , name , dtype = None , shape = None ):
@@ -1164,13 +1184,12 @@ def dump_node_statistics(self):
1164
1184
op_cnt [n .type ] += 1
1165
1185
body_graphs = n .get_body_graphs ()
1166
1186
if body_graphs :
1167
- for _ , b_g in body_graphs .items ():
1187
+ for b_g in body_graphs .values ():
1168
1188
op_cnt += b_g .dump_node_statistics ()
1169
1189
1170
1190
return op_cnt
1171
1191
1172
- @staticmethod
1173
- def remove_input (node , to_be_removed , input_index = None ):
1192
+ def remove_input (self , node , to_be_removed , input_index = None ):
1174
1193
"""Remove input from Node.
1175
1194
Args:
1176
1195
node: the node we expect the input on
@@ -1182,15 +1201,24 @@ def remove_input(node, to_be_removed, input_index=None):
1182
1201
assert isinstance (node , Node ) and isinstance (to_be_removed , six .text_type )
1183
1202
if input_index is not None :
1184
1203
assert node .input [input_index ] == to_be_removed
1204
+ if node .input [input_index ] in self ._output_to_consumers :
1205
+ to_ops = self ._output_to_consumers [node .input [input_index ]]
1206
+ if node .name in to_ops :
1207
+ to_ops .remove (node .name )
1185
1208
del node .input [input_index ]
1186
- return True
1209
+ return
1187
1210
1188
1211
for i , name in enumerate (node .input ):
1189
1212
if name == to_be_removed :
1213
+ utils .make_sure (
1214
+ node .input .count (node .input [i ]) <= 1 ,
1215
+ "Node %r takes multiple times the same input %r. This case is not handled." ,
1216
+ node .name , node .input [i ])
1217
+ self ._unregister_input_name (node .input [i ], node )
1190
1218
del node .input [i ]
1191
1219
break
1220
+
1192
1221
# don't remove output from parent since others might depend on it
1193
- return True
1194
1222
1195
1223
def insert_new_node_on_input (self , node , op_type , input_name , name = None , domain = None , ** kwargs ):
1196
1224
"""Create and insert a new node into the graph.
@@ -1238,43 +1266,93 @@ def insert_new_node_on_output(self, op_type, output_name, name, domain=None, **k
1238
1266
new_output = port_name (name )
1239
1267
new_node = self .make_node (op_type , [output_name ], attr = kwargs , outputs = [new_output ], name = name , domain = domain )
1240
1268
1241
- to_replace = [n for n in self .get_nodes () if n != new_node ]
1242
- self .replace_all_inputs (to_replace , output_name , new_output )
1269
+ to_replace = [self .get_node_by_name (n ) for n in self ._output_to_consumers [output_name ]]
1270
+ to_replace = [n for n in to_replace if n != new_node ]
1271
+ self .replace_all_inputs (output_name , new_output , ops = to_replace )
1243
1272
return new_node
1244
1273
1245
1274
def find_output_consumers (self , output_name ):
1246
1275
"""Find all nodes consuming a given output."""
1276
+ if output_name in self ._output_to_consumers :
1277
+ ops = self ._output_to_consumers [output_name ]
1278
+ ops = [self .get_node_by_name (n ) for n in ops ]
1279
+ else :
1280
+ ops = [] # self.get_nodes()
1247
1281
nodes = []
1248
- for node in self .get_nodes ():
1282
+ for node in ops :
1283
+ if node is None :
1284
+ continue
1249
1285
if output_name in node .input :
1250
1286
nodes .append (node )
1251
1287
1252
- # find consumers in sub graphs
1253
- body_graphs = node .get_body_graphs ()
1254
- if body_graphs :
1255
- for g in body_graphs .values ():
1256
- nodes .extend (g .find_output_consumers (output_name ))
1288
+ # find consumers in sub graphs
1289
+ if output_name in self ._input_to_graph :
1290
+ for g in self ._input_to_graph [output_name ].values ():
1291
+ nodes .extend (g .find_output_consumers (output_name ))
1257
1292
return nodes
1258
1293
1259
- @staticmethod
1260
- def replace_all_inputs (ops , old_input , new_input ):
1261
- """Replace all inputs pointing to old_input with new_input."""
1294
+ def _register_input_name (self , input_name , node , only_graph = False ):
1295
+ "Register node taking a specific input."
1296
+ if not only_graph :
1297
+ if input_name not in self ._output_to_consumers :
1298
+ self ._output_to_consumers [input_name ] = set ()
1299
+ self ._output_to_consumers [input_name ].add (node .name )
1300
+ if self .parent_graph is not None :
1301
+ if input_name not in self .parent_graph ._input_to_graph :
1302
+ self .parent_graph ._input_to_graph [input_name ] = {}
1303
+ self .parent_graph ._input_to_graph [input_name ][id (self )] = self
1304
+ self .parent_graph ._register_input_name (input_name , node , only_graph = True )
1305
+
1306
+ def _unregister_input_name (self , input_name , node , only_graph = False ):
1307
+ "Unregister node taking a specific input."
1308
+ node_name = node .name
1309
+ if not only_graph :
1310
+ if input_name in self ._output_to_consumers [input_name ]:
1311
+ if node_name in self ._output_to_consumers [input_name ]:
1312
+ self ._output_to_consumers [input_name ].remove (node_name )
1313
+ if (self .parent_graph is not None and
1314
+ input_name in self .parent_graph ._input_to_graph and
1315
+ id (self ) in self .parent_graph ._input_to_graph [input_name ]):
1316
+ del self .parent_graph ._input_to_graph [input_name ][id (self )]
1317
+ self .parent_graph ._unregister_input_name (input_name , node , only_graph = True )
1318
+
1319
+ def replace_all_inputs (self , old_input , new_input , ops = None ):
1320
+ """
1321
+ Replace all inputs pointing to old_input with new_input.
1322
+ *ops* is used if defined, otherwise `_output_to_consumers`
1323
+ is used to determine the impacted nodes.
1324
+ """
1262
1325
if old_input == new_input :
1263
1326
return
1327
+ if new_input not in self ._output_to_consumers :
1328
+ self ._output_to_consumers [new_input ] = set ()
1329
+
1330
+ if ops is not None :
1331
+ keep_ops = True
1332
+ elif old_input in self ._output_to_consumers :
1333
+ ops = list (
1334
+ filter (lambda a : a is not None ,
1335
+ map (self .get_node_by_name , self ._output_to_consumers [old_input ])))
1336
+ keep_ops = False
1337
+ else :
1338
+ ops = []
1339
+ keep_ops = False
1264
1340
1265
1341
for node in ops :
1342
+ assert node is not None
1266
1343
if old_input in node .input and new_input in node .output :
1267
1344
raise RuntimeError ("creating a circle in the graph is not allowed: " + node .name )
1345
+ self ._register_input_name (new_input , node )
1268
1346
1269
1347
for i , input_name in enumerate (node .input ):
1270
1348
if input_name == old_input :
1271
- node .input [i ] = new_input
1349
+ self . replace_input ( node , node .input [i ], new_input , i )
1272
1350
1273
- # modify references in sub graphs
1274
- body_graphs = node . get_body_graphs ()
1275
- if body_graphs :
1276
- for g in body_graphs . values ():
1277
- g . replace_all_inputs ( g . get_nodes (), old_input , new_input )
1351
+ # modify references in sub graphs
1352
+ if old_input in self . _input_to_graph :
1353
+ for g in self . _input_to_graph [ old_input ]. values () :
1354
+ g . replace_all_inputs ( old_input , new_input ,
1355
+ ops = g . get_nodes () if keep_ops else None )
1278
1356
1279
1357
def replace_input (self , node , old_input , new_input , input_index = None ):
1280
1358
"""
@@ -1294,11 +1372,31 @@ def replace_input(self, node, old_input, new_input, input_index=None):
1294
1372
is_replaced = True
1295
1373
else :
1296
1374
raise RuntimeError ("Unable to replace input %r into %r for node %r." % (old_input , new_input , node .name ))
1375
+
1376
+ to_ops = self ._output_to_consumers .get (old_input , None )
1377
+ if to_ops is not None :
1378
+ if node .name in to_ops :
1379
+ # A node may take twice the same entry.
1380
+ to_ops .remove (node .name )
1381
+
1382
+ self ._register_input_name (new_input , node )
1297
1383
return is_replaced
1298
1384
1299
1385
def replace_inputs (self , node , new_inputs ):
1300
1386
"""Replace node inputs."""
1301
1387
assert isinstance (node , Node ) and isinstance (new_inputs , list )
1388
+
1389
+ for old_input in node .input :
1390
+ to_ops = self ._output_to_consumers .get (old_input , None )
1391
+ if to_ops is not None and old_input in to_ops :
1392
+ # To avoid issues when a node
1393
+ # takes twice the same entry.
1394
+ to_ops .remove (old_input )
1395
+
1396
+ for input_name in new_inputs :
1397
+ assert isinstance (input_name , six .text_type )
1398
+ self ._register_input_name (input_name , node )
1399
+
1302
1400
node .input = new_inputs
1303
1401
return True
1304
1402
@@ -1374,7 +1472,7 @@ def delete_unused_nodes(self, outputs_name):
1374
1472
for node in related_nodes :
1375
1473
attr_body_graphs = node .get_body_graphs ()
1376
1474
if attr_body_graphs :
1377
- for _ , body_graph in attr_body_graphs .items ():
1475
+ for body_graph in attr_body_graphs .values ():
1378
1476
body_graph .delete_unused_nodes (body_graph .outputs )
1379
1477
self .reset_nodes (related_nodes )
1380
1478
0 commit comments