Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 49 additions & 5 deletions pySDC/helpers/spectral_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,54 @@
import scipy
from pySDC.implementations.datatype_classes.mesh import mesh
from scipy.special import factorial
from functools import wraps


def cache(func):
"""
Decorator for caching return values of functions.
This is very similar to `functools.cache`, but without the memory leaks (see
https://docs.astral.sh/ruff/rules/cached-instance-method/).

Example:

.. code-block:: python

num_calls = 0

@cache
def increment(x):
num_calls += 1
return x + 1

increment(0) # returns 1, num_calls = 1
increment(1) # returns 2, num_calls = 2
increment(0) # returns 1, num_calls = 2


Args:
func (function): The function you want to cache the return value of

Returns:
return value of func
"""
attr_cache = f"_{func.__name__}_cache"

@wraps(func)
def wrapper(self, *args, **kwargs):
if not hasattr(self, attr_cache):
setattr(self, attr_cache, {})

cache = getattr(self, attr_cache)

key = (args, frozenset(kwargs.items()))
if key in cache:
return cache[key]
result = func(self, *args, **kwargs)
cache[key] = result
return result

return wrapper


class SpectralHelper1D:
Expand Down Expand Up @@ -203,7 +251,6 @@ def __init__(self, *args, transform_type='fft', x0=-1, x1=1, **kwargs):
if self.transform_type == 'fft':
self.get_fft_utils()

self.cache = {}
self.norm = self.get_norm()

def get_1dgrid(self):
Expand All @@ -221,6 +268,7 @@ def get_wavenumbers(self):
"""Get the domain in spectral space"""
return self.xp.arange(self.N)

@cache
def get_conv(self, name, N=None):
'''
Get conversion matrix between different kinds of polynomials. The supported kinds are
Expand All @@ -238,9 +286,6 @@ def get_conv(self, name, N=None):
Returns:
scipy.sparse: Sparse conversion matrix
'''
if name in self.cache.keys() and not N:
return self.cache[name]

N = N if N else self.N
sp = self.sparse_lib
xp = self.xp
Expand Down Expand Up @@ -271,7 +316,6 @@ def get_forward_conv(name):
except NotImplementedError:
raise NotImplementedError from E

self.cache[name] = mat
return mat

def get_basis_change_matrix(self, conv='T2T', **kwargs):
Expand Down
24 changes: 24 additions & 0 deletions pySDC/tests/test_helpers/test_spectral_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,30 @@ def test_dealias_MPI(num_procs, axis, bx, bz, nx=32, nz=64, **kwargs):
run_MPI_test(num_procs=num_procs, axis=axis, nx=nx, nz=nz, bx=bx, bz=bz, test='dealias')


@pytest.mark.base
def test_cache_decorator():
from pySDC.helpers.spectral_helper import cache
import numpy as np

class Dummy:
num_calls = 0

@cache
def increment(self, x):
self.num_calls += 1
return x + 1

dummy = Dummy()
values = [0, 1, 1, 0, 3, 1, 2]
unique_vals = np.unique(values)

for x in values:
assert dummy.increment(x) == x + 1

assert dummy.num_calls < len(values)
assert dummy.num_calls == len(unique_vals)


if __name__ == '__main__':
str_to_bool = lambda me: False if me == 'False' else True
str_to_tuple = lambda arg: tuple(int(me) for me in arg.split(','))
Expand Down