Skip to content

Commit e351091

Browse files
Added caching decorator (#554)
* Implemented caching wrapper for spectral helper * Added test for caching decorator
1 parent c5fbd09 commit e351091

File tree

2 files changed

+73
-5
lines changed

2 files changed

+73
-5
lines changed

pySDC/helpers/spectral_helper.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,54 @@
22
import scipy
33
from pySDC.implementations.datatype_classes.mesh import mesh
44
from scipy.special import factorial
5+
from functools import wraps
6+
7+
8+
def cache(func):
9+
"""
10+
Decorator for caching return values of functions.
11+
This is very similar to `functools.cache`, but without the memory leaks (see
12+
https://docs.astral.sh/ruff/rules/cached-instance-method/).
13+
14+
Example:
15+
16+
.. code-block:: python
17+
18+
num_calls = 0
19+
20+
@cache
21+
def increment(x):
22+
num_calls += 1
23+
return x + 1
24+
25+
increment(0) # returns 1, num_calls = 1
26+
increment(1) # returns 2, num_calls = 2
27+
increment(0) # returns 1, num_calls = 2
28+
29+
30+
Args:
31+
func (function): The function you want to cache the return value of
32+
33+
Returns:
34+
return value of func
35+
"""
36+
attr_cache = f"_{func.__name__}_cache"
37+
38+
@wraps(func)
39+
def wrapper(self, *args, **kwargs):
40+
if not hasattr(self, attr_cache):
41+
setattr(self, attr_cache, {})
42+
43+
cache = getattr(self, attr_cache)
44+
45+
key = (args, frozenset(kwargs.items()))
46+
if key in cache:
47+
return cache[key]
48+
result = func(self, *args, **kwargs)
49+
cache[key] = result
50+
return result
51+
52+
return wrapper
553

654

755
class SpectralHelper1D:
@@ -203,7 +251,6 @@ def __init__(self, *args, transform_type='fft', x0=-1, x1=1, **kwargs):
203251
if self.transform_type == 'fft':
204252
self.get_fft_utils()
205253

206-
self.cache = {}
207254
self.norm = self.get_norm()
208255

209256
def get_1dgrid(self):
@@ -221,6 +268,7 @@ def get_wavenumbers(self):
221268
"""Get the domain in spectral space"""
222269
return self.xp.arange(self.N)
223270

271+
@cache
224272
def get_conv(self, name, N=None):
225273
'''
226274
Get conversion matrix between different kinds of polynomials. The supported kinds are
@@ -238,9 +286,6 @@ def get_conv(self, name, N=None):
238286
Returns:
239287
scipy.sparse: Sparse conversion matrix
240288
'''
241-
if name in self.cache.keys() and not N:
242-
return self.cache[name]
243-
244289
N = N if N else self.N
245290
sp = self.sparse_lib
246291
xp = self.xp
@@ -271,7 +316,6 @@ def get_forward_conv(name):
271316
except NotImplementedError:
272317
raise NotImplementedError from E
273318

274-
self.cache[name] = mat
275319
return mat
276320

277321
def get_basis_change_matrix(self, conv='T2T', **kwargs):

pySDC/tests/test_helpers/test_spectral_helper.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,30 @@ def test_dealias_MPI(num_procs, axis, bx, bz, nx=32, nz=64, **kwargs):
551551
run_MPI_test(num_procs=num_procs, axis=axis, nx=nx, nz=nz, bx=bx, bz=bz, test='dealias')
552552

553553

554+
@pytest.mark.base
555+
def test_cache_decorator():
556+
from pySDC.helpers.spectral_helper import cache
557+
import numpy as np
558+
559+
class Dummy:
560+
num_calls = 0
561+
562+
@cache
563+
def increment(self, x):
564+
self.num_calls += 1
565+
return x + 1
566+
567+
dummy = Dummy()
568+
values = [0, 1, 1, 0, 3, 1, 2]
569+
unique_vals = np.unique(values)
570+
571+
for x in values:
572+
assert dummy.increment(x) == x + 1
573+
574+
assert dummy.num_calls < len(values)
575+
assert dummy.num_calls == len(unique_vals)
576+
577+
554578
if __name__ == '__main__':
555579
str_to_bool = lambda me: False if me == 'False' else True
556580
str_to_tuple = lambda arg: tuple(int(me) for me in arg.split(','))

0 commit comments

Comments
 (0)