@@ -75,11 +75,13 @@ def mean(x, n):
7575from __future__ import annotations
7676
7777import argparse
78+ import contextlib
7879import datetime
7980import inspect
8081import json
8182import platform
8283import timeit
84+ import tracemalloc
8385from ast import literal_eval
8486from functools import partial
8587from importlib .metadata import PackageNotFoundError , version
@@ -90,13 +92,6 @@ def mean(x, n):
9092import jax
9193import numpy as np
9294
93- try :
94- import memory_profiler
95-
96- MEMORY_PROFILER_AVAILABLE = True
97- except ImportError :
98- MEMORY_PROFILER_AVAILABLE = False
99-
10095
10196class SkipBenchmarkException (Exception ):
10297 """Exception to be raised to skip benchmark for some parameter set."""
@@ -230,8 +225,8 @@ def _format_results_entry(results_entry):
230225 + f"min(run times): { min (results_entry ['run_times_in_seconds' ]):>#7.2g} s, "
231226 + f"max(run times): { max (results_entry ['run_times_in_seconds' ]):>#7.2g} s"
232227 + (
233- f", peak memory: { results_entry ['peak_memory_in_bytes ' ]:>#7.2g} B"
234- if "peak_memory_in_bytes " in results_entry
228+ f", peak memory: { results_entry ['traced_memory_peak_in_bytes ' ]:>#7.2g} B"
229+ if "traced_memory_peak_in_bytes " in results_entry
235230 else ""
236231 )
237232 + (
@@ -339,31 +334,27 @@ def collect_benchmarks(module, benchmark_names):
339334 ]
340335
341336
@contextlib.contextmanager
def trace_memory_allocations(n_frames=1):
    """Context manager for tracing memory allocations in the managed with block.

    Yields a thunk (zero-argument function) which can be called on exit from the with
    block to get a tuple of the current size and peak size of memory blocks traced in
    bytes. The thunk returns ``(None, None)`` if it is called before the with block
    has exited, or if the with block exited by raising an exception.

    If ``tracemalloc`` is already tracing when the block is entered (for example when
    uses of this context manager are nested), the existing tracing session is reused
    and left running on exit rather than being stopped. In that case ``n_frames`` is
    not used and the reported sizes also include allocations traced before the block
    was entered.

    Args:
        n_frames: Limit on depth of frames to trace memory allocations in.

    Yields:
        A thunk (zero-argument function) which can be called on exit from with block to
        get tuple of current size and peak size of memory blocks traced in bytes.
    """
    # Only take ownership of the tracing session if nobody else started one, so that
    # exiting this block never tears down tracing an enclosing caller relies on.
    started_here = not tracemalloc.is_tracing()
    if started_here:
        tracemalloc.start(n_frames)
    current_size, peak_size = None, None
    try:
        # The lambda closes over the local variables, so it observes the values
        # assigned after the with block body has completed.
        yield lambda: (current_size, peak_size)
        current_size, peak_size = tracemalloc.get_traced_memory()
    finally:
        if started_here:
            tracemalloc.stop()
367358
368359
369360def _compile_jax_benchmark_and_analyse (benchmark_function , results_entry ):
@@ -437,8 +428,12 @@ def run_benchmarks(
437428 benchmark_function , results_entry
438429 )
439430 # Run benchmark once without timing to record output for potentially
440- # computing numerical error
441- output = benchmark_function ()
431+ # computing numerical error and trace memory usage
432+ with trace_memory_allocations () as traced_memory :
433+ output = benchmark_function ()
434+ current_size , peak_size = traced_memory ()
435+ results_entry ["traced_memory_final_in_bytes" ] = current_size
436+ results_entry ["traced_memory_peak_in_bytes" ] = peak_size
442437 if reference_output is not None and output is not None :
443438 results_entry ["max_abs_error" ] = abs (
444439 reference_output - output
@@ -450,11 +445,6 @@ def run_benchmarks(
450445 )
451446 ]
452447 results_entry ["run_times_in_seconds" ] = run_times
453- if MEMORY_PROFILER_AVAILABLE :
454- results_entry ["peak_memory_in_bytes" ] = measure_peak_memory_usage (
455- benchmark_function ,
456- interval = min (run_times ) / 20 ,
457- ) * (2 ** 20 )
458448 results [benchmark .__name__ ].append (results_entry )
459449 if print_results :
460450 print (_format_results_entry (results_entry ))
0 commit comments