@@ -89,6 +89,7 @@ def debugprint(
8989 | Sequence [Variable | Apply | Function | FunctionGraph ],
9090 depth : int = - 1 ,
9191 print_type : bool = False ,
92+ print_shape : bool = False ,
9293 file : Literal ["str" ] | TextIO | None = None ,
9394 id_type : IDTypesType = "CHAR" ,
9495 stop_on_name : bool = False ,
@@ -123,6 +124,8 @@ def debugprint(
123124 Print graph to this depth (``-1`` for unlimited).
124125 print_type
125126 If ``True``, print the `Type`\s of each `Variable` in the graph.
127+ print_shape
128+ If ``True``, print the shape of each `Variable` in the graph.
126129 file
127130 When `file` extends `TextIO`, print to it; when `file` is
128131 equal to ``"str"``, return a string; when `file` is ``None``, print to
@@ -265,6 +268,7 @@ def debugprint(
265268 depth = depth ,
266269 done = done ,
267270 print_type = print_type ,
271+ print_shape = print_shape ,
268272 file = _file ,
269273 id_type = id_type ,
270274 inner_graph_ops = inner_graph_vars ,
@@ -295,6 +299,7 @@ def debugprint(
295299 depth = depth ,
296300 done = done ,
297301 print_type = print_type ,
302+ print_shape = print_shape ,
298303 file = _file ,
299304 topo_order = topo_order ,
300305 id_type = id_type ,
@@ -365,6 +370,7 @@ def debugprint(
365370 depth = depth ,
366371 done = done ,
367372 print_type = print_type ,
373+ print_shape = print_shape ,
368374 file = _file ,
369375 id_type = id_type ,
370376 inner_graph_ops = inner_graph_vars ,
@@ -387,6 +393,7 @@ def debugprint(
387393 depth = depth ,
388394 done = done ,
389395 print_type = print_type ,
396+ print_shape = print_shape ,
390397 file = _file ,
391398 id_type = id_type ,
392399 stop_on_name = stop_on_name ,
@@ -421,6 +428,7 @@ def debugprint(
421428 depth = depth ,
422429 done = done ,
423430 print_type = print_type ,
431+ print_shape = print_shape ,
424432 file = _file ,
425433 id_type = id_type ,
426434 stop_on_name = stop_on_name ,
@@ -452,6 +460,7 @@ def _debugprint(
452460 depth : int = - 1 ,
453461 done : dict [Literal ["output" ] | Variable | Apply , str ] | None = None ,
454462 print_type : bool = False ,
463+ print_shape : bool = False ,
455464 file : TextIO = sys .stdout ,
456465 print_destroy_map : bool = False ,
457466 print_view_map : bool = False ,
@@ -484,6 +493,8 @@ def _debugprint(
484493 See `debugprint`.
485494 print_type
486495 See `debugprint`.
496+ print_shape
497+ See `debugprint`.
487498 file
488499 File-like object to which to print.
489500 print_destroy_map
@@ -532,6 +543,11 @@ def _debugprint(
532543 else :
533544 type_str = ""
534545
546+ if print_shape :
547+ shape_str = f" shape={ str (var .type .shape ).replace ("None" , "?" )} "
548+ else :
549+ shape_str = ""
550+
535551 if prefix_child is None :
536552 prefix_child = prefix
537553
@@ -612,7 +628,7 @@ def get_id_str(
612628 if is_inner_graph_header :
613629 var_output = f"{ prefix } { node .op } { id_str } { destroy_map_str } { view_map_str } { o } "
614630 else :
615- var_output = f"{ prefix } { node .op } { output_idx } { id_str } { type_str } { var_name } { destroy_map_str } { view_map_str } { o } { data } "
631+ var_output = f"{ prefix } { node .op } { output_idx } { id_str } { type_str } { shape_str } { var_name } { destroy_map_str } { view_map_str } { o } { data } "
616632
617633 if print_op_info and node not in op_information :
618634 op_information .update (op_debug_information (node .op , node ))
@@ -662,6 +678,7 @@ def get_id_str(
662678 depth = depth - 1 ,
663679 done = _done ,
664680 print_type = print_type ,
681+ print_shape = print_shape ,
665682 file = file ,
666683 topo_order = topo_order ,
667684 id_type = id_type ,
@@ -692,7 +709,7 @@ def get_id_str(
692709 else :
693710 data = ""
694711
695- var_output = f"{ prefix } { var } { id_str } { type_str } { data } "
712+ var_output = f"{ prefix } { var } { id_str } { type_str } { shape_str } { data } "
696713
697714 if print_op_info and var .owner and var .owner not in op_information :
698715 op_information .update (op_debug_information (var .owner .op , var .owner ))
0 commit comments