Skip to content

Commit 3d59549

Browse files
Increased performance of tau methods (#491)
1 parent 04853bb commit 3d59549

File tree

1 file changed

+39
-29
lines changed

1 file changed

+39
-29
lines changed

pySDC/helpers/spectral_helper.py

Lines changed: 39 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1178,30 +1178,36 @@ def put_BCs_in_matrix(self, A):
11781178
def put_BCs_in_rhs_hat(self, rhs_hat):
11791179
"""
11801180
Put the BCs in the right hand side in spectral space for solving.
1181-
This function needs no transforms.
1181+
This function needs no transforms and caches a mask for faster subsequent use.
11821182
11831183
Args:
11841184
rhs_hat: Right hand side in spectral space
11851185
11861186
Returns:
11871187
rhs in spectral space with BCs
11881188
"""
1189-
ndim = self.ndim
1190-
1191-
for axis in range(ndim):
1192-
for bc in self.full_BCs:
1193-
slices = (
1194-
[slice(0, self.init[0][i + 1]) for i in range(axis)]
1195-
+ [bc['line']]
1196-
+ [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
1197-
)
1198-
if axis == bc['axis']:
1199-
_slice = [self.index(bc['equation'])] + slices
1200-
N = self.axes[axis].N
1201-
if (N + bc['line']) % N in self.xp.arange(N)[self.local_slice[axis]]:
1202-
_slice[axis + 1] -= self.local_slice[axis].start
1203-
rhs_hat[(*_slice,)] = 0
1204-
1189+
if not hasattr(self, '_rhs_hat_zero_mask'):
1190+
"""
1191+
Generate a mask where we need to set values in the rhs in spectral space to zero, such that can replace them
1192+
by the boundary conditions. The mask is then cached.
1193+
"""
1194+
self._rhs_hat_zero_mask = self.xp.zeros(shape=rhs_hat.shape, dtype=bool)
1195+
1196+
for axis in range(self.ndim):
1197+
for bc in self.full_BCs:
1198+
slices = (
1199+
[slice(0, self.init[0][i + 1]) for i in range(axis)]
1200+
+ [bc['line']]
1201+
+ [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
1202+
)
1203+
if axis == bc['axis']:
1204+
_slice = [self.index(bc['equation'])] + slices
1205+
N = self.axes[axis].N
1206+
if (N + bc['line']) % N in self.xp.arange(N)[self.local_slice[axis]]:
1207+
_slice[axis + 1] -= self.local_slice[axis].start
1208+
self._rhs_hat_zero_mask[(*_slice,)] = True
1209+
1210+
rhs_hat[self._rhs_hat_zero_mask] = 0
12051211
return rhs_hat + self.rhs_BCs_hat
12061212

12071213
def put_BCs_in_rhs(self, rhs):
@@ -1347,18 +1353,22 @@ def get_fft(self, axes=None, direction='object', padding=None, shape=None):
13471353
elif direction == 'object':
13481354
self.fft_cache[key] = None
13491355
else:
1350-
from mpi4py_fft import PFFT
1351-
1352-
_fft = PFFT(
1353-
comm=self.comm,
1354-
shape=shape,
1355-
axes=sorted(axes),
1356-
dtype='D',
1357-
collapse=False,
1358-
backend=self.fft_backend,
1359-
comm_backend=self.fft_comm_backend,
1360-
padding=padding,
1361-
)
1356+
if direction == 'object':
1357+
from mpi4py_fft import PFFT
1358+
1359+
_fft = PFFT(
1360+
comm=self.comm,
1361+
shape=shape,
1362+
axes=sorted(axes),
1363+
dtype='D',
1364+
collapse=False,
1365+
backend=self.fft_backend,
1366+
comm_backend=self.fft_comm_backend,
1367+
padding=padding,
1368+
)
1369+
else:
1370+
_fft = self.get_fft(axes=axes, direction='object', padding=padding, shape=shape)
1371+
13621372
if direction == 'forward':
13631373
self.fft_cache[key] = _fft.forward
13641374
elif direction == 'backward':

0 commit comments

Comments
 (0)