diff --git a/pySDC/implementations/problem_classes/generic_spectral.py b/pySDC/implementations/problem_classes/generic_spectral.py index 0a9d937a7b..a8cff43611 100644 --- a/pySDC/implementations/problem_classes/generic_spectral.py +++ b/pySDC/implementations/problem_classes/generic_spectral.py @@ -212,31 +212,31 @@ def setup_preconditioner(self, Dirichlet_recombination=True, left_preconditioner Dirichlet_recombination (bool): Basis conversion for right preconditioner. Useful for Chebychev and Ultraspherical methods. 10/10 would recommend. left_preconditioner (bool): If True, it will interleave the variables and reverse the Kronecker product """ - sp = self.spectral.sparse_lib N = np.prod(self.init[0][1:]) - Id = sp.eye(N) - Pl_lhs = {comp: {comp: Id} for comp in self.components} - self.Pl = self._setup_operator(Pl_lhs) - if left_preconditioner: # reverse Kronecker product - if self.spectral.useGPU: - R = self.Pl.get().tolil() * 0 + import scipy.sparse as sp else: - R = self.Pl.tolil() * 0 + sp = self.spectral.sparse_lib + + R = sp.lil_matrix((self.ncomponents * N,) * 2, dtype=int) for j in range(self.ncomponents): for i in range(N): - R[i * self.ncomponents + j, j * N + i] = 1.0 + R[i * self.ncomponents + j, j * N + i] = 1 - self.Pl = self.spectral.sparse_lib.csc_matrix(R) + self.Pl = self.spectral.sparse_lib.csc_matrix(R, dtype=complex) + else: + Id = self.spectral.sparse_lib.eye(N) + Pl_lhs = {comp: {comp: Id} for comp in self.components} + self.Pl = self._setup_operator(Pl_lhs) if Dirichlet_recombination and type(self.axes[-1]).__name__ in ['ChebychevHelper', 'UltrasphericalHelper']: _Pr = self.spectral.get_Dirichlet_recombination_matrix(axis=-1) else: - _Pr = Id + _Pr = self.spectral.sparse_lib.eye(N) Pr_lhs = {comp: {comp: _Pr} for comp in self.components} self.Pr = self._setup_operator(Pr_lhs) @ self.Pl.T @@ -393,7 +393,7 @@ def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs) def setUpFieldsIO(self): Rectilinear.setupMPI( - comm=self.comm, + comm=self.comm.commMPI if self.useGPU else self.comm, iLoc=[me.start for me in self.local_slice(False)], nLoc=[me.stop - me.start for me in self.local_slice(False)], )