Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 39 additions & 29 deletions pySDC/helpers/spectral_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1178,30 +1178,36 @@ def put_BCs_in_matrix(self, A):
def put_BCs_in_rhs_hat(self, rhs_hat):
"""
Put the BCs in the right hand side in spectral space for solving.
This function needs no transforms.
This function needs no transforms and caches a mask for faster subsequent use.
Args:
rhs_hat: Right hand side in spectral space
Returns:
rhs in spectral space with BCs
"""
ndim = self.ndim

for axis in range(ndim):
for bc in self.full_BCs:
slices = (
[slice(0, self.init[0][i + 1]) for i in range(axis)]
+ [bc['line']]
+ [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
)
if axis == bc['axis']:
_slice = [self.index(bc['equation'])] + slices
N = self.axes[axis].N
if (N + bc['line']) % N in self.xp.arange(N)[self.local_slice[axis]]:
_slice[axis + 1] -= self.local_slice[axis].start
rhs_hat[(*_slice,)] = 0

if not hasattr(self, '_rhs_hat_zero_mask'):
"""
Generate a mask where we need to set values in the rhs in spectral space to zero, such that can replace them
by the boundary conditions. The mask is then cached.
"""
self._rhs_hat_zero_mask = self.xp.zeros(shape=rhs_hat.shape, dtype=bool)

for axis in range(self.ndim):
for bc in self.full_BCs:
slices = (
[slice(0, self.init[0][i + 1]) for i in range(axis)]
+ [bc['line']]
+ [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
)
if axis == bc['axis']:
_slice = [self.index(bc['equation'])] + slices
N = self.axes[axis].N
if (N + bc['line']) % N in self.xp.arange(N)[self.local_slice[axis]]:
_slice[axis + 1] -= self.local_slice[axis].start
self._rhs_hat_zero_mask[(*_slice,)] = True

rhs_hat[self._rhs_hat_zero_mask] = 0
return rhs_hat + self.rhs_BCs_hat

def put_BCs_in_rhs(self, rhs):
Expand Down Expand Up @@ -1347,18 +1353,22 @@ def get_fft(self, axes=None, direction='object', padding=None, shape=None):
elif direction == 'object':
self.fft_cache[key] = None
else:
from mpi4py_fft import PFFT

_fft = PFFT(
comm=self.comm,
shape=shape,
axes=sorted(axes),
dtype='D',
collapse=False,
backend=self.fft_backend,
comm_backend=self.fft_comm_backend,
padding=padding,
)
if direction == 'object':
from mpi4py_fft import PFFT

_fft = PFFT(
comm=self.comm,
shape=shape,
axes=sorted(axes),
dtype='D',
collapse=False,
backend=self.fft_backend,
comm_backend=self.fft_comm_backend,
padding=padding,
)
else:
_fft = self.get_fft(axes=axes, direction='object', padding=padding, shape=shape)

if direction == 'forward':
self.fft_cache[key] = _fft.forward
elif direction == 'backward':
Expand Down