Skip to content

Commit e99ee9e

Browse files
committed
Implemented caching wrapper for spectral helper
1 parent 1b5b2dc commit e99ee9e

File tree

1 file changed

+47
-5
lines changed

1 file changed

+47
-5
lines changed

pySDC/helpers/spectral_helper.py

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,52 @@
22
import scipy
33
from pySDC.implementations.datatype_classes.mesh import mesh
44
from scipy.special import factorial
5+
from functools import wraps
6+
7+
8+
def cache(func):
9+
"""
10+
Decorator for caching return values of functions.
11+
12+
Example:
13+
14+
.. code-block:: python
15+
16+
num_calls = 0
17+
18+
@cache
19+
def increment(x):
20+
num_calls += 1
21+
return x + 1
22+
23+
increment(0) # returns 1, num_calls = 1
24+
increment(1) # returns 2, num_calls = 2
25+
increment(0) # returns 1, num_calls = 2
26+
27+
28+
Args:
29+
func (function): The function you want to cache the return value of
30+
31+
Returns:
32+
return value of func
33+
"""
34+
attr_cache = f"_{func.__name__}_cache"
35+
36+
@wraps(func)
37+
def wrapper(self, *args, **kwargs):
38+
if not hasattr(self, attr_cache):
39+
setattr(self, attr_cache, {})
40+
41+
cache = getattr(self, attr_cache)
42+
43+
key = (args, frozenset(kwargs.items()))
44+
if key in cache:
45+
return cache[key]
46+
result = func(self, *args, **kwargs)
47+
cache[key] = result
48+
return result
49+
50+
return wrapper
551

652

753
class SpectralHelper1D:
@@ -203,7 +249,6 @@ def __init__(self, *args, transform_type='fft', x0=-1, x1=1, **kwargs):
203249
if self.transform_type == 'fft':
204250
self.get_fft_utils()
205251

206-
self.cache = {}
207252
self.norm = self.get_norm()
208253

209254
def get_1dgrid(self):
@@ -221,6 +266,7 @@ def get_wavenumbers(self):
221266
"""Get the domain in spectral space"""
222267
return self.xp.arange(self.N)
223268

269+
@cache
224270
def get_conv(self, name, N=None):
225271
'''
226272
Get conversion matrix between different kinds of polynomials. The supported kinds are
@@ -238,9 +284,6 @@ def get_conv(self, name, N=None):
238284
Returns:
239285
scipy.sparse: Sparse conversion matrix
240286
'''
241-
if name in self.cache.keys() and not N:
242-
return self.cache[name]
243-
244287
N = N if N else self.N
245288
sp = self.sparse_lib
246289
xp = self.xp
@@ -271,7 +314,6 @@ def get_forward_conv(name):
271314
except NotImplementedError:
272315
raise NotImplementedError from E
273316

274-
self.cache[name] = mat
275317
return mat
276318

277319
def get_basis_change_matrix(self, conv='T2T', **kwargs):

0 commit comments

Comments
 (0)