22import scipy
33from pySDC .implementations .datatype_classes .mesh import mesh
44from scipy .special import factorial
5+ from functools import wraps
6+
7+
8+ def cache (func ):
9+ """
10+ Decorator for caching return values of functions.
11+ This is very similar to `functools.cache`, but without the memory leaks (see
12+ https://docs.astral.sh/ruff/rules/cached-instance-method/).
13+
14+ Example:
15+
16+ .. code-block:: python
17+
18+ num_calls = 0
19+
20+ @cache
21+ def increment(x):
22+ num_calls += 1
23+ return x + 1
24+
25+ increment(0) # returns 1, num_calls = 1
26+ increment(1) # returns 2, num_calls = 2
27+ increment(0) # returns 1, num_calls = 2
28+
29+
30+ Args:
31+ func (function): The function you want to cache the return value of
32+
33+ Returns:
34+ return value of func
35+ """
36+ attr_cache = f"_{ func .__name__ } _cache"
37+
38+ @wraps (func )
39+ def wrapper (self , * args , ** kwargs ):
40+ if not hasattr (self , attr_cache ):
41+ setattr (self , attr_cache , {})
42+
43+ cache = getattr (self , attr_cache )
44+
45+ key = (args , frozenset (kwargs .items ()))
46+ if key in cache :
47+ return cache [key ]
48+ result = func (self , * args , ** kwargs )
49+ cache [key ] = result
50+ return result
51+
52+ return wrapper
553
654
755class SpectralHelper1D :
@@ -203,7 +251,6 @@ def __init__(self, *args, transform_type='fft', x0=-1, x1=1, **kwargs):
203251 if self .transform_type == 'fft' :
204252 self .get_fft_utils ()
205253
206- self .cache = {}
207254 self .norm = self .get_norm ()
208255
209256 def get_1dgrid (self ):
@@ -221,6 +268,7 @@ def get_wavenumbers(self):
221268 """Get the domain in spectral space"""
222269 return self .xp .arange (self .N )
223270
271+ @cache
224272 def get_conv (self , name , N = None ):
225273 '''
226274 Get conversion matrix between different kinds of polynomials. The supported kinds are
@@ -238,9 +286,6 @@ def get_conv(self, name, N=None):
238286 Returns:
239287 scipy.sparse: Sparse conversion matrix
240288 '''
241- if name in self .cache .keys () and not N :
242- return self .cache [name ]
243-
244289 N = N if N else self .N
245290 sp = self .sparse_lib
246291 xp = self .xp
@@ -271,7 +316,6 @@ def get_forward_conv(name):
271316 except NotImplementedError :
272317 raise NotImplementedError from E
273318
274- self .cache [name ] = mat
275319 return mat
276320
277321 def get_basis_change_matrix (self , conv = 'T2T' , ** kwargs ):
0 commit comments