@@ -89,6 +89,7 @@ def debugprint(
8989 | Sequence [Variable | Apply | Function | FunctionGraph ],
9090 depth : int = - 1 ,
9191 print_type : bool = False ,
92+ print_shape : bool = False ,
9293 file : Literal ["str" ] | TextIO | None = None ,
9394 id_type : IDTypesType = "CHAR" ,
9495 stop_on_name : bool = False ,
@@ -98,6 +99,7 @@ def debugprint(
9899 print_op_info : bool = False ,
99100 print_destroy_map : bool = False ,
100101 print_view_map : bool = False ,
102+ print_memory_map : bool = False ,
101103 print_fgraph_inputs : bool = False ,
102104) -> str | TextIO :
103105 r"""Print a graph as text.
@@ -123,6 +125,8 @@ def debugprint(
123125 Print graph to this depth (``-1`` for unlimited).
124126 print_type
125127 If ``True``, print the `Type`\s of each `Variable` in the graph.
128+ print_shape
129+ If ``True``, print the shape of each `Variable` in the graph.
126130 file
127131 When `file` extends `TextIO`, print to it; when `file` is
128132 equal to ``"str"``, return a string; when `file` is ``None``, print to
@@ -153,6 +157,8 @@ def debugprint(
153157 Whether to print the `destroy_map`\s of printed objects
154158 print_view_map
155159 Whether to print the `view_map`\s of printed objects
160+ print_memory_map
161+ Whether to set both `print_destroy_map` and `print_view_map` to ``True``.
156162 print_fgraph_inputs
157163 Print the inputs of `FunctionGraph`\s.
158164
@@ -177,6 +183,10 @@ def debugprint(
177183 if used_ids is None :
178184 used_ids = dict ()
179185
186+ if print_memory_map :
187+ print_destroy_map = True
188+ print_view_map = True
189+
180190 inputs_to_print = []
181191 outputs_to_print = []
182192 profile_list : list [Any | None ] = []
@@ -265,6 +275,7 @@ def debugprint(
265275 depth = depth ,
266276 done = done ,
267277 print_type = print_type ,
278+ print_shape = print_shape ,
268279 file = _file ,
269280 id_type = id_type ,
270281 inner_graph_ops = inner_graph_vars ,
@@ -295,6 +306,7 @@ def debugprint(
295306 depth = depth ,
296307 done = done ,
297308 print_type = print_type ,
309+ print_shape = print_shape ,
298310 file = _file ,
299311 topo_order = topo_order ,
300312 id_type = id_type ,
@@ -365,6 +377,7 @@ def debugprint(
365377 depth = depth ,
366378 done = done ,
367379 print_type = print_type ,
380+ print_shape = print_shape ,
368381 file = _file ,
369382 id_type = id_type ,
370383 inner_graph_ops = inner_graph_vars ,
@@ -387,6 +400,7 @@ def debugprint(
387400 depth = depth ,
388401 done = done ,
389402 print_type = print_type ,
403+ print_shape = print_shape ,
390404 file = _file ,
391405 id_type = id_type ,
392406 stop_on_name = stop_on_name ,
@@ -421,6 +435,7 @@ def debugprint(
421435 depth = depth ,
422436 done = done ,
423437 print_type = print_type ,
438+ print_shape = print_shape ,
424439 file = _file ,
425440 id_type = id_type ,
426441 stop_on_name = stop_on_name ,
@@ -452,6 +467,7 @@ def _debugprint(
452467 depth : int = - 1 ,
453468 done : dict [Literal ["output" ] | Variable | Apply , str ] | None = None ,
454469 print_type : bool = False ,
470+ print_shape : bool = False ,
455471 file : TextIO = sys .stdout ,
456472 print_destroy_map : bool = False ,
457473 print_view_map : bool = False ,
@@ -484,6 +500,8 @@ def _debugprint(
484500 See `debugprint`.
485501 print_type
486502 See `debugprint`.
503+ print_shape
504+ See `debugprint`.
487505 file
488506 File-like object to which to print.
489507 print_destroy_map
@@ -532,6 +550,11 @@ def _debugprint(
532550 else :
533551 type_str = ""
534552
553+ if print_shape and hasattr (var .type , "shape" ):
554+ shape_str = f" shape={ str (var .type .shape ).replace ('None' , '?' )} "
555+ else :
556+ shape_str = ""
557+
535558 if prefix_child is None :
536559 prefix_child = prefix
537560
@@ -612,7 +635,7 @@ def get_id_str(
612635 if is_inner_graph_header :
613636 var_output = f"{ prefix } { node .op } { id_str } { destroy_map_str } { view_map_str } { o } "
614637 else :
615- var_output = f"{ prefix } { node .op } { output_idx } { id_str } { type_str } { var_name } { destroy_map_str } { view_map_str } { o } { data } "
638+ var_output = f"{ prefix } { node .op } { output_idx } { id_str } { type_str } { shape_str } { var_name } { destroy_map_str } { view_map_str } { o } { data } "
616639
617640 if print_op_info and node not in op_information :
618641 op_information .update (op_debug_information (node .op , node ))
@@ -662,6 +685,7 @@ def get_id_str(
662685 depth = depth - 1 ,
663686 done = _done ,
664687 print_type = print_type ,
688+ print_shape = print_shape ,
665689 file = file ,
666690 topo_order = topo_order ,
667691 id_type = id_type ,
@@ -692,7 +716,7 @@ def get_id_str(
692716 else :
693717 data = ""
694718
695- var_output = f"{ prefix } { var } { id_str } { type_str } { data } "
719+ var_output = f"{ prefix } { var } { id_str } { type_str } { shape_str } { data } "
696720
697721 if print_op_info and var .owner and var .owner not in op_information :
698722 op_information .update (op_debug_information (var .owner .op , var .owner ))
0 commit comments