From e99ee9e98296e28f08fe1fa7913ce986711140b5 Mon Sep 17 00:00:00 2001
From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com>
Date: Tue, 17 Jun 2025 10:46:12 +0200
Subject: [PATCH 1/3] Implemented caching wrapper for spectral helper

---
 pySDC/helpers/spectral_helper.py | 52 +++++++++++++++++++++++++++++---
 1 file changed, 47 insertions(+), 5 deletions(-)

diff --git a/pySDC/helpers/spectral_helper.py b/pySDC/helpers/spectral_helper.py
index a801503a7c..55a813cd8f 100644
--- a/pySDC/helpers/spectral_helper.py
+++ b/pySDC/helpers/spectral_helper.py
@@ -2,6 +2,52 @@ import scipy
 from pySDC.implementations.datatype_classes.mesh import mesh
 from scipy.special import factorial
+from functools import wraps
+
+
+def cache(func):
+    """
+    Decorator for caching return values of functions.
+
+    Example:
+
+    .. code-block:: python
+
+        num_calls = 0
+
+        @cache
+        def increment(x):
+            num_calls += 1
+            return x + 1
+
+        increment(0)  # returns 1, num_calls = 1
+        increment(1)  # returns 2, num_calls = 2
+        increment(0)  # returns 1, num_calls = 2
+
+
+    Args:
+        func (function): The function you want to cache the return value of
+
+    Returns:
+        return value of func
+    """
+    attr_cache = f"_{func.__name__}_cache"
+
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        if not hasattr(self, attr_cache):
+            setattr(self, attr_cache, {})
+
+        cache = getattr(self, attr_cache)
+
+        key = (args, frozenset(kwargs.items()))
+        if key in cache:
+            return cache[key]
+        result = func(self, *args, **kwargs)
+        cache[key] = result
+        return result
+
+    return wrapper
 
 
 class SpectralHelper1D:
@@ -203,7 +249,6 @@ def __init__(self, *args, transform_type='fft', x0=-1, x1=1, **kwargs):
         if self.transform_type == 'fft':
             self.get_fft_utils()
 
-        self.cache = {}
         self.norm = self.get_norm()
 
     def get_1dgrid(self):
@@ -221,6 +266,7 @@ def get_wavenumbers(self):
         """Get the domain in spectral space"""
         return self.xp.arange(self.N)
 
+    @cache
     def get_conv(self, name, N=None):
         '''
         Get conversion matrix between different kinds of polynomials. The supported kinds are
@@ -238,9 +284,6 @@ def get_conv(self, name, N=None):
         Returns:
             scipy.sparse: Sparse conversion matrix
         '''
-        if name in self.cache.keys() and not N:
-            return self.cache[name]
-
         N = N if N else self.N
         sp = self.sparse_lib
         xp = self.xp
@@ -271,7 +314,6 @@ def get_forward_conv(name):
         except NotImplementedError:
             raise NotImplementedError from E
 
-        self.cache[name] = mat
         return mat
 
     def get_basis_change_matrix(self, conv='T2T', **kwargs):

From 6c5636808b49ff241f940a4eefeae216f3f17a86 Mon Sep 17 00:00:00 2001
From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com>
Date: Tue, 17 Jun 2025 13:21:15 +0200
Subject: [PATCH 2/3] Added test for caching decorator

---
 pySDC/helpers/spectral_helper.py             |  2 ++
 .../test_helpers/test_spectral_helper.py     | 24 +++++++++++++++++++
 2 files changed, 26 insertions(+)

diff --git a/pySDC/helpers/spectral_helper.py b/pySDC/helpers/spectral_helper.py
index 55a813cd8f..25f177daad 100644
--- a/pySDC/helpers/spectral_helper.py
+++ b/pySDC/helpers/spectral_helper.py
@@ -8,6 +8,8 @@ def cache(func):
     """
     Decorator for caching return values of functions.
+    This is very similar to `functools.cache`, but without the memory leaks (see
+    https://docs.astral.sh/ruff/rules/cached-instance-method/).
 
     Example:

diff --git a/pySDC/tests/test_helpers/test_spectral_helper.py b/pySDC/tests/test_helpers/test_spectral_helper.py
index 3a2f9d2136..0e27874822 100644
--- a/pySDC/tests/test_helpers/test_spectral_helper.py
+++ b/pySDC/tests/test_helpers/test_spectral_helper.py
@@ -551,6 +551,30 @@ def test_dealias_MPI(num_procs, axis, bx, bz, nx=32, nz=64, **kwargs):
     run_MPI_test(num_procs=num_procs, axis=axis, nx=nx, nz=nz, bx=bx, bz=bz, test='dealias')
 
 
+@pytest.mark.base
+def test_cache_decorator():
+    from pySDC.helpers.spectral_helper import cache
+    import numpy as np
+
+    class Dummy:
+        num_calls = 0
+
+        @cache
+        def increment(self, x):
+            self.num_calls += 1
+            return x + 1
+
+    dummy = Dummy()
+    values = [0, 1, 1, 0, 3, 1, 2]
+    unique_vals = np.unique(values)
+
+    for x in values:
+        assert dummy.increment(x) == x + 1
+
+    assert dummy.num_calls < len(values)
+    assert dummy.num_calls == len(unique_vals)
+
+
 if __name__ == '__main__':
     str_to_bool = lambda me: False if me == 'False' else True
     str_to_tuple = lambda arg: tuple(int(me) for me in arg.split(','))

From 8bceb916e16938273ac72700b8009b4f19773e22 Mon Sep 17 00:00:00 2001
From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com>
Date: Tue, 17 Jun 2025 15:15:29 +0200
Subject: [PATCH 3/3] Added test for memory leaks with cache decorator

---
 .../test_helpers/test_spectral_helper.py | 33 +++++++++++++++++++
 1 file changed, 33 insertions(+)

diff --git a/pySDC/tests/test_helpers/test_spectral_helper.py b/pySDC/tests/test_helpers/test_spectral_helper.py
index 0e27874822..639a4e71af 100644
--- a/pySDC/tests/test_helpers/test_spectral_helper.py
+++ b/pySDC/tests/test_helpers/test_spectral_helper.py
@@ -575,6 +575,39 @@ def increment(self, x):
     assert dummy.num_calls == len(unique_vals)
 
 
+@pytest.mark.base
+def test_cache_memory_leaks():
+    from pySDC.helpers.spectral_helper import cache
+
+    track = [0, 0]
+
+    class KeepTrack:
+
+        def __init__(self):
+            track[0] += 1
+            track[1] = 0
+
+        @cache
+        def method(self, a, b, c=1, d=2):
+            track[1] += 1
+            return f"{a},{b},c={c},d={d}"
+
+        def __del__(self):
+            track[0] -= 1
+
+    def function():
+        obj = KeepTrack()
+        for _ in range(10):
+            obj.method(1, 2, d=2)
+        assert track[0] == 1
+        assert track[1] == 1
+
+    for _ in range(3):
+        function()
+
+    assert track[0] == 0, "possible memory leak with the @cache"
+
+
 if __name__ == '__main__':
     str_to_bool = lambda me: False if me == 'False' else True
     str_to_tuple = lambda arg: tuple(int(me) for me in arg.split(','))
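
For reference, below is a minimal usage sketch of the new decorator on an instance method. It assumes the patches above are applied so that `cache` can be imported from `pySDC.helpers.spectral_helper`; the `Grid` class and its `stencil` method are made up for illustration and are not part of pySDC.

.. code-block:: python

    from pySDC.helpers.spectral_helper import cache


    class Grid:
        # hypothetical class whose expensive method benefits from per-instance caching
        def __init__(self, N):
            self.N = N
            self.evaluations = 0  # counts how often the method body actually runs

        @cache
        def stencil(self, order, scale=1.0):
            # pretend this is expensive; results are stored in self._stencil_cache
            self.evaluations += 1
            return [scale * (i + order) for i in range(self.N)]


    grid = Grid(4)
    grid.stencil(2)             # computed, evaluations == 1
    grid.stencil(2)             # same (args, kwargs) key -> served from the cache
    grid.stencil(2, scale=0.5)  # different kwargs -> new cache entry
    assert grid.evaluations == 2

    # The cache lives on the instance (grid._stencil_cache), so deleting the
    # instance also frees its cached values, unlike functools.cache on a method.
    del grid

Because the cache key is `(args, frozenset(kwargs.items()))`, all arguments must be hashable; passing an unhashable argument such as a numpy array would raise a `TypeError` rather than silently recomputing.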