@@ -152,6 +152,28 @@ def __str__(self):
152
152
def __repr__ (self ):
153
153
return "<onnx op type='%s' name=%s>" % (self .type , self ._op .name )
154
154
155
+ @property
156
+ def summary (self ):
157
+ """Return node summary information."""
158
+ lines = []
159
+ lines .append ("OP={}" .format (self .type ))
160
+ lines .append ("Name={}" .format (self .name ))
161
+
162
+ g = self .graph
163
+ if self .input :
164
+ lines .append ("Inputs:" )
165
+ for name in self .input :
166
+ node = g .get_node_by_output (name )
167
+ op = node .type if node else "N/A"
168
+ lines .append ("\t {}={}, {}, {}" .format (name , op , g .get_shape (name ), g .get_dtype (name )))
169
+
170
+ if self .output :
171
+ for name in self .output :
172
+ lines .append ("Outpus:" )
173
+ lines .append ("\t {}={}, {}" .format (name , g .get_shape (name ), g .get_dtype (name )))
174
+
175
+ return '\n ' .join (lines )
176
+
155
177
def get_attr (self , name , default = None ):
156
178
"""Get raw attribute value."""
157
179
attr = self .attr .get (name , default )
@@ -436,6 +458,8 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk
436
458
if op_name_scope :
437
459
name = "_" .join ([op_name_scope , name ])
438
460
461
+ logger .debug ("Making node: Name=%s, OP=%s" , name , op_type )
462
+
439
463
if outputs is None :
440
464
outputs = [name + ":" + str (i ) for i in range (output_count )]
441
465
@@ -479,6 +503,7 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk
479
503
if (not shapes or not dtypes ) and infer_shape_dtype :
480
504
self .update_node_shape_dtype (node , override = False )
481
505
506
+ logger .debug ("Made node: %s\n %s" , node .name , node .summary )
482
507
self ._nodes .append (node )
483
508
return node
484
509
@@ -904,7 +929,11 @@ def dump_graph(self):
904
929
"""Dump graph with shapes (helpful for debugging)."""
905
930
for node in self .get_nodes ():
906
931
input_names = ["{}{}" .format (n , self .get_shape (n )) for n in node .input ]
907
- print ("{} {} {} {}" .format (node .type , self .get_shape (node .output [0 ]), node .name , ", " .join (input_names )))
932
+ logger .debug ("%s %s %s %s" ,
933
+ node .type ,
934
+ self .get_shape (node .output [0 ]),
935
+ node .name ,
936
+ ", " .join (input_names ))
908
937
909
938
def follow_inputs (self , node , num , space = "" ):
910
939
"""Follow inputs for (helpful for debugging)."""
0 commit comments