diff --git a/pySDC/helpers/spectral_helper.py b/pySDC/helpers/spectral_helper.py index a801503a7c..25f177daad 100644 --- a/pySDC/helpers/spectral_helper.py +++ b/pySDC/helpers/spectral_helper.py @@ -2,6 +2,54 @@ import scipy from pySDC.implementations.datatype_classes.mesh import mesh from scipy.special import factorial +from functools import wraps + + +def cache(func): + """ + Decorator for caching return values of functions. + This is very similar to `functools.cache`, but without the memory leaks (see + https://docs.astral.sh/ruff/rules/cached-instance-method/). + + Example: + + .. code-block:: python + + num_calls = 0 + + @cache + def increment(x): + num_calls += 1 + return x + 1 + + increment(0) # returns 1, num_calls = 1 + increment(1) # returns 2, num_calls = 2 + increment(0) # returns 1, num_calls = 2 + + + Args: + func (function): The function you want to cache the return value of + + Returns: + return value of func + """ + attr_cache = f"_{func.__name__}_cache" + + @wraps(func) + def wrapper(self, *args, **kwargs): + if not hasattr(self, attr_cache): + setattr(self, attr_cache, {}) + + cache = getattr(self, attr_cache) + + key = (args, frozenset(kwargs.items())) + if key in cache: + return cache[key] + result = func(self, *args, **kwargs) + cache[key] = result + return result + + return wrapper class SpectralHelper1D: @@ -203,7 +251,6 @@ def __init__(self, *args, transform_type='fft', x0=-1, x1=1, **kwargs): if self.transform_type == 'fft': self.get_fft_utils() - self.cache = {} self.norm = self.get_norm() def get_1dgrid(self): @@ -221,6 +268,7 @@ def get_wavenumbers(self): """Get the domain in spectral space""" return self.xp.arange(self.N) + @cache def get_conv(self, name, N=None): ''' Get conversion matrix between different kinds of polynomials. The supported kinds are @@ -238,9 +286,6 @@ def get_conv(self, name, N=None): Returns: scipy.sparse: Sparse conversion matrix ''' - if name in self.cache.keys() and not N: - return self.cache[name] - N = N if N else self.N sp = self.sparse_lib xp = self.xp @@ -271,7 +316,6 @@ def get_forward_conv(name): except NotImplementedError: raise NotImplementedError from E - self.cache[name] = mat return mat def get_basis_change_matrix(self, conv='T2T', **kwargs): diff --git a/pySDC/tests/test_helpers/test_spectral_helper.py b/pySDC/tests/test_helpers/test_spectral_helper.py index 3a2f9d2136..0e27874822 100644 --- a/pySDC/tests/test_helpers/test_spectral_helper.py +++ b/pySDC/tests/test_helpers/test_spectral_helper.py @@ -551,6 +551,30 @@ def test_dealias_MPI(num_procs, axis, bx, bz, nx=32, nz=64, **kwargs): run_MPI_test(num_procs=num_procs, axis=axis, nx=nx, nz=nz, bx=bx, bz=bz, test='dealias') +@pytest.mark.base +def test_cache_decorator(): + from pySDC.helpers.spectral_helper import cache + import numpy as np + + class Dummy: + num_calls = 0 + + @cache + def increment(self, x): + self.num_calls += 1 + return x + 1 + + dummy = Dummy() + values = [0, 1, 1, 0, 3, 1, 2] + unique_vals = np.unique(values) + + for x in values: + assert dummy.increment(x) == x + 1 + + assert dummy.num_calls < len(values) + assert dummy.num_calls == len(unique_vals) + + if __name__ == '__main__': str_to_bool = lambda me: False if me == 'False' else True str_to_tuple = lambda arg: tuple(int(me) for me in arg.split(','))