Skip to content

Commit 4a94591

Browse files
waysgpytorchmergebot
authored andcommitted
filter out alloc-free pairs from trace plot (pytorch#165752)
Summary: When dealing with a large memory trace, the resulting plot can be challenging to interpret and analyze. This commit introduces a feature that enables filtering of allocations that have already been freed, providing a more focused view. The remaining events in the plot often warrant closer examination, as they may be indicative of potential out-of-memory (OOM) issues. Pull Request resolved: pytorch#165752 Approved by: https://github.com/zdevito
1 parent 5e7272b commit 4a94591

File tree

2 files changed

+57
-2
lines changed

2 files changed

+57
-2
lines changed

test/test_cuda.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4222,6 +4222,7 @@ def run():
42224222
ss = torch.cuda.memory._snapshot()
42234223

42244224
trace_plot(ss)
4225+
trace_plot(ss, filter_freed=True)
42254226
segment_plot(ss)
42264227
text = json.dumps(ss)
42274228

torch/cuda/_memory_viz.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -446,18 +446,59 @@ def _format_viz(data, viz_kind, device):
446446
)
447447

448448

449-
def trace_plot(data, device=None, plot_segments=False):
449+
def filter_alloc_free_pairs(data):
450+
for dev_id in range(len(data["device_traces"])):
451+
# set of indexes of trace events for alloc-free pairs
452+
filterSet = set()
453+
# map from addr to index of alloc event
454+
allocMap = {}
455+
# set of addrs from free_requested events
456+
freeRequested = set()
457+
for idx, event in enumerate(data["device_traces"][dev_id]):
458+
if event["action"] == "alloc":
459+
allocMap[event["addr"]] = idx
460+
elif event["action"] == "free_requested":
461+
freeRequested.add(event["addr"])
462+
if allocMap.get(event["addr"]) is not None:
463+
filterSet.add(idx)
464+
filterSet.add(allocMap[event["addr"]])
465+
allocMap.pop(event["addr"])
466+
elif event["action"] == "free_completed":
467+
if event["addr"] in freeRequested:
468+
freeRequested.remove(event["addr"])
469+
filterSet.add(idx)
470+
else:
471+
print(f"free_completed without free_requested: {event}")
472+
473+
# Remove events whose index is in filterSet
474+
if filterSet:
475+
# Create a new list excluding events with indices in filterSet
476+
data["device_traces"][dev_id] = [
477+
event
478+
for idx, event in enumerate(data["device_traces"][dev_id])
479+
if idx not in filterSet
480+
]
481+
482+
return data
483+
484+
485+
def trace_plot(data, device=None, plot_segments=False, filter_freed=False):
450486
"""Generate a visualization over time of the memory usage recorded by the trace as an html file.
451487
452488
Args:
453489
data: Memory snapshot as generated from torch.cuda.memory._snapshot()
454490
device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations.
455491
plot_segments (bool, optional): Plots memory returned from cudaMalloc, rather than individual allocations.
456492
Defaults to False.
493+
filter_freed (bool, optional): Filter out alloc-free paired events to only plot allocations that are not freed yet.
494+
Defaults to False to plot all trace events.
457495
458496
Returns:
459497
str: HTML of visualization
460498
"""
499+
if filter_freed:
500+
data = filter_alloc_free_pairs(data)
501+
461502
return _format_viz(
462503
data,
463504
"Active Memory Timeline"
@@ -698,6 +739,14 @@ def _output(p):
698739
"-s", "--segments", action="store_true", help=help
699740
)
700741

742+
help = (
743+
"filter out allocation-free pairs to only visualize the allocations that are not freed yet;"
744+
"useful to reduce the number of events for large traces for debugging OOM"
745+
)
746+
trace_plot_a.add_argument(
747+
"-f", "--filter_freed", action="store_true", help=help
748+
)
749+
701750
args = parser.parse_args()
702751

703752
def _read(name):
@@ -734,7 +783,12 @@ def _write(name, data):
734783
data = _read(args.input)
735784
_write(
736785
args.output,
737-
trace_plot(data, device=args.device, plot_segments=args.segments),
786+
trace_plot(
787+
data,
788+
device=args.device,
789+
plot_segments=args.segments,
790+
filter_freed=args.filter_freed,
791+
),
738792
)
739793
elif args.action == "segment_plot":
740794
data = _read(args.input)

0 commit comments

Comments
 (0)