22import scipy
33from pySDC .implementations .datatype_classes .mesh import mesh
44from scipy .special import factorial
5+ from functools import wraps
6+
7+
8+ def cache (func ):
9+ """
10+ Decorator for caching return values of functions.
11+
12+ Example:
13+
14+ .. code-block:: python
15+
16+ num_calls = 0
17+
18+ @cache
19+ def increment(x):
20+ num_calls += 1
21+ return x + 1
22+
23+ increment(0) # returns 1, num_calls = 1
24+ increment(1) # returns 2, num_calls = 2
25+ increment(0) # returns 1, num_calls = 2
26+
27+
28+ Args:
29+ func (function): The function you want to cache the return value of
30+
31+ Returns:
32+ return value of func
33+ """
34+ attr_cache = f"_{ func .__name__ } _cache"
35+
36+ @wraps (func )
37+ def wrapper (self , * args , ** kwargs ):
38+ if not hasattr (self , attr_cache ):
39+ setattr (self , attr_cache , {})
40+
41+ cache = getattr (self , attr_cache )
42+
43+ key = (args , frozenset (kwargs .items ()))
44+ if key in cache :
45+ return cache [key ]
46+ result = func (self , * args , ** kwargs )
47+ cache [key ] = result
48+ return result
49+
50+ return wrapper
551
652
753class SpectralHelper1D :
@@ -203,7 +249,6 @@ def __init__(self, *args, transform_type='fft', x0=-1, x1=1, **kwargs):
203249 if self .transform_type == 'fft' :
204250 self .get_fft_utils ()
205251
206- self .cache = {}
207252 self .norm = self .get_norm ()
208253
209254 def get_1dgrid (self ):
@@ -221,6 +266,7 @@ def get_wavenumbers(self):
221266 """Get the domain in spectral space"""
222267 return self .xp .arange (self .N )
223268
269+ @cache
224270 def get_conv (self , name , N = None ):
225271 '''
226272 Get conversion matrix between different kinds of polynomials. The supported kinds are
@@ -238,9 +284,6 @@ def get_conv(self, name, N=None):
238284 Returns:
239285 scipy.sparse: Sparse conversion matrix
240286 '''
241- if name in self .cache .keys () and not N :
242- return self .cache [name ]
243-
244287 N = N if N else self .N
245288 sp = self .sparse_lib
246289 xp = self .xp
@@ -271,7 +314,6 @@ def get_forward_conv(name):
271314 except NotImplementedError :
272315 raise NotImplementedError from E
273316
274- self .cache [name ] = mat
275317 return mat
276318
277319 def get_basis_change_matrix (self , conv = 'T2T' , ** kwargs ):
0 commit comments