@@ -446,18 +446,59 @@ def _format_viz(data, viz_kind, device):
446446 )
447447
448448
449- def trace_plot (data , device = None , plot_segments = False ):
449+ def filter_alloc_free_pairs (data ):
450+ for dev_id in range (len (data ["device_traces" ])):
451+ # set of indexes of trace events for alloc-free pairs
452+ filterSet = set ()
453+ # map from addr to index of alloc event
454+ allocMap = {}
455+ # set of addrs from free_requested events
456+ freeRequested = set ()
457+ for idx , event in enumerate (data ["device_traces" ][dev_id ]):
458+ if event ["action" ] == "alloc" :
459+ allocMap [event ["addr" ]] = idx
460+ elif event ["action" ] == "free_requested" :
461+ freeRequested .add (event ["addr" ])
462+ if allocMap .get (event ["addr" ]) is not None :
463+ filterSet .add (idx )
464+ filterSet .add (allocMap [event ["addr" ]])
465+ allocMap .pop (event ["addr" ])
466+ elif event ["action" ] == "free_completed" :
467+ if event ["addr" ] in freeRequested :
468+ freeRequested .remove (event ["addr" ])
469+ filterSet .add (idx )
470+ else :
471+ print (f"free_completed without free_requested: { event } " )
472+
473+ # Remove events whose index is in filterSet
474+ if filterSet :
475+ # Create a new list excluding events with indices in filterSet
476+ data ["device_traces" ][dev_id ] = [
477+ event
478+ for idx , event in enumerate (data ["device_traces" ][dev_id ])
479+ if idx not in filterSet
480+ ]
481+
482+ return data
483+
484+
485+ def trace_plot (data , device = None , plot_segments = False , filter_freed = False ):
450486 """Generate a visualization over time of the memory usage recorded by the trace as an html file.
451487
452488 Args:
453489 data: Memory snapshot as generated from torch.cuda.memory._snapshot()
454490 device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations.
455491 plot_segments (bool, optional): Plots memory returned from cudaMalloc, rather than individual allocations.
456492 Defaults to False.
493+ filter_freed (bool, optional): Filter out alloc-free paired events to only plot allocations that are not freed yet.
494+ Defaults to False to plot all trace events.
457495
458496 Returns:
459497 str: HTML of visualization
460498 """
499+ if filter_freed :
500+ data = filter_alloc_free_pairs (data )
501+
461502 return _format_viz (
462503 data ,
463504 "Active Memory Timeline"
@@ -698,6 +739,14 @@ def _output(p):
698739 "-s" , "--segments" , action = "store_true" , help = help
699740 )
700741
742+ help = (
743+ "filter out allocation-free pairs to only visualize the allocations that are not freed yet;"
744+ "useful to reduce the number of events for large traces for debugging OOM"
745+ )
746+ trace_plot_a .add_argument (
747+ "-f" , "--filter_freed" , action = "store_true" , help = help
748+ )
749+
701750 args = parser .parse_args ()
702751
703752 def _read (name ):
@@ -734,7 +783,12 @@ def _write(name, data):
734783 data = _read (args .input )
735784 _write (
736785 args .output ,
737- trace_plot (data , device = args .device , plot_segments = args .segments ),
786+ trace_plot (
787+ data ,
788+ device = args .device ,
789+ plot_segments = args .segments ,
790+ filter_freed = args .filter_freed ,
791+ ),
738792 )
739793 elif args .action == "segment_plot" :
740794 data = _read (args .input )
0 commit comments