Skip to content

Commit 651c14f

Browse files
committed
Use tracemalloc instead of memory_profiler for measuring CPU memory
1 parent 1791828 commit 651c14f

File tree

2 files changed

+28
-41
lines changed

2 files changed

+28
-41
lines changed

benchmarks/README.md

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
11
# Benchmarks for `s2fft`
22

3-
Scripts for benchmarking `s2fft` with `timeit` (and optionally `memory_profiler`).
3+
Scripts for benchmarking `s2fft` transforms.
44

55
Measures time to compute transforms for grids of parameters settings, optionally
66
outputting the results to a JSON file to allow comparing performance over versions
7-
and/or systems.
8-
If the [`memory_profiler` package](https://github.com/pythonprofilers/memory_profiler)
9-
is installed an estimate of the peak (main) memory usage of the benchmarked functions
10-
will also be recorded.
7+
and/or systems.
118
If the [`py-cpuinfo` package](https://pypi.org/project/py-cpuinfo/)
129
is installed additional information about CPU of system benchmarks are run on will be
1310
recorded in JSON output.

benchmarks/benchmarking.py

Lines changed: 26 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,13 @@ def mean(x, n):
7575
from __future__ import annotations
7676

7777
import argparse
78+
import contextlib
7879
import datetime
7980
import inspect
8081
import json
8182
import platform
8283
import timeit
84+
import tracemalloc
8385
from ast import literal_eval
8486
from functools import partial
8587
from importlib.metadata import PackageNotFoundError, version
@@ -90,13 +92,6 @@ def mean(x, n):
9092
import jax
9193
import numpy as np
9294

93-
try:
94-
import memory_profiler
95-
96-
MEMORY_PROFILER_AVAILABLE = True
97-
except ImportError:
98-
MEMORY_PROFILER_AVAILABLE = False
99-
10095

10196
class SkipBenchmarkException(Exception):
10297
"""Exception to be raised to skip benchmark for some parameter set."""
@@ -230,8 +225,8 @@ def _format_results_entry(results_entry):
230225
+ f"min(run times): {min(results_entry['run_times_in_seconds']):>#7.2g}s, "
231226
+ f"max(run times): {max(results_entry['run_times_in_seconds']):>#7.2g}s"
232227
+ (
233-
f", peak memory: {results_entry['peak_memory_in_bytes']:>#7.2g}B"
234-
if "peak_memory_in_bytes" in results_entry
228+
f", peak memory: {results_entry['traced_memory_peak_in_bytes']:>#7.2g}B"
229+
if "traced_memory_peak_in_bytes" in results_entry
235230
else ""
236231
)
237232
+ (
@@ -339,31 +334,27 @@ def collect_benchmarks(module, benchmark_names):
339334
]
340335

341336

342-
def measure_peak_memory_usage(benchmark_function, interval):
343-
"""Measure peak memory usage in mebibytes (MiB) of a function using memory_profiler.
337+
@contextlib.contextmanager
338+
def trace_memory_allocations(n_frames=1):
339+
"""Context manager for tracing memory allocations in managed with block.
340+
341+
Returns a thunk (zero-argument function) which can be called on exit from with block
342+
to get tuple of current size and peak size of memory blocks traced in bytes.
344343
345344
Args:
346-
benchmark_function: Function to benchmark peak memory usage of.
347-
interval: Interval in seconds at which memory measurements are collected.
345+
n_frames: Limit on depth of frames to trace memory allocations in.
348346
349347
Returns:
350-
Peak memory usage measure in mebibytes (MiB).
348+
A thunk (zero-argument function) which can be called on exit from with block to
349+
get tuple of current size and peak size of memory blocks traced in bytes.
351350
"""
352-
baseline_memory = memory_profiler.memory_usage(
353-
lambda: None,
354-
max_usage=True,
355-
include_children=True,
356-
)
357-
return max(
358-
memory_profiler.memory_usage(
359-
benchmark_function,
360-
interval=interval,
361-
max_usage=True,
362-
include_children=True,
363-
)
364-
- baseline_memory,
365-
0,
366-
)
351+
tracemalloc.start(n_frames)
352+
current_size, peak_size = None, None
353+
try:
354+
yield lambda: (current_size, peak_size)
355+
current_size, peak_size = tracemalloc.get_traced_memory()
356+
finally:
357+
tracemalloc.stop()
367358

368359

369360
def _compile_jax_benchmark_and_analyse(benchmark_function, results_entry):
@@ -437,8 +428,12 @@ def run_benchmarks(
437428
benchmark_function, results_entry
438429
)
439430
# Run benchmark once without timing to record output for potentially
440-
# computing numerical error
441-
output = benchmark_function()
431+
# computing numerical error and trace memory usage
432+
with trace_memory_allocations() as traced_memory:
433+
output = benchmark_function()
434+
current_size, peak_size = traced_memory()
435+
results_entry["traced_memory_final_in_bytes"] = current_size
436+
results_entry["traced_memory_peak_in_bytes"] = peak_size
442437
if reference_output is not None and output is not None:
443438
results_entry["max_abs_error"] = abs(
444439
reference_output - output
@@ -450,11 +445,6 @@ def run_benchmarks(
450445
)
451446
]
452447
results_entry["run_times_in_seconds"] = run_times
453-
if MEMORY_PROFILER_AVAILABLE:
454-
results_entry["peak_memory_in_bytes"] = measure_peak_memory_usage(
455-
benchmark_function,
456-
interval=min(run_times) / 20,
457-
) * (2**20)
458448
results[benchmark.__name__].append(results_entry)
459449
if print_results:
460450
print(_format_results_entry(results_entry))

0 commit comments

Comments
 (0)