Skip to content

Commit 38f3e2b

Browse files
Fixed some issues with transfer classes (#422)
* Fixes for transfer classes * Cleanup of `mesh_to_mesh` * Bug fixes * Fixed `fft_to_fft`
1 parent 19af3b0 commit 38f3e2b

File tree

11 files changed

+282
-175
lines changed

11 files changed

+282
-175
lines changed

pySDC/helpers/transfer_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def restriction_matrix_1d(fine_grid, coarse_grid, k=2, periodic=False, pad=1):
138138

139139
def interpolation_matrix_1d(fine_grid, coarse_grid, k=2, periodic=False, pad=1, equidist_nested=True):
140140
"""
141-
Function to contruct the restriction matrix in 1d using barycentric interpolation
141+
Function to construct the restriction matrix in 1d using barycentric interpolation
142142
143143
Args:
144144
fine_grid (np.ndarray): a one dimensional 1d array containing the nodes of the fine grid

pySDC/implementations/convergence_controller_classes/check_iteration_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
class CheckIterationEstimatorNonMPI(ConvergenceController):
77
def __init__(self, controller, params, description, **kwargs):
88
"""
9-
Initalization routine
9+
Initialization routine
1010
1111
Args:
1212
controller (pySDC.Controller): The controller

pySDC/implementations/datatype_classes/mesh.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def __new__(cls, init, *args, **kwargs):
199199
def __getattr__(self, name):
200200
if name in self.components:
201201
if self.shape[0] == len(self.components):
202-
return self[self.components.index(name)]
202+
return self[self.components.index(name)].view(mesh)
203203
else:
204204
raise AttributeError(f'Cannot access {name!r} in {type(self)!r} because the shape is unexpected.')
205205
else:

pySDC/implementations/problem_classes/Brusselator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def __init__(self, alpha=0.1, **kwargs):
3030
shape = (2,) + (self.init[0])
3131
self.iU = 0
3232
self.iV = 1
33+
self.ncomp = 2 # needed for transfer class
3334
self.init = (shape, self.comm, np.dtype('float'))
3435

3536
def _eval_explicit_part(self, u, t, f_expl):

pySDC/implementations/problem_classes/GrayScott_MPIFFT.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def __init__(self, Du=1.0, Dv=0.01, A=0.09, B=0.086, **kwargs):
7979
shape = (2,) + (self.init[0])
8080
self.iU = 0
8181
self.iV = 1
82+
self.ncomp = 2 # needed for transfer class
8283
self.init = (shape, self.comm, self.xp.dtype('float'))
8384

8485
self._makeAttributeAndRegister('Du', 'Dv', 'A', 'B', localVars=locals(), readOnly=True)

pySDC/implementations/sweeper_classes/imex_1st_order.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,14 @@ def integrate(self):
4242
list of dtype_u: containing the integral as values
4343
"""
4444

45-
# get current level and problem description
4645
L = self.level
46+
P = L.prob
4747

4848
me = []
49-
5049
# integrate RHS over all collocation nodes
5150
for m in range(1, self.coll.num_nodes + 1):
52-
me.append(L.dt * self.coll.Qmat[m, 1] * (L.f[1].impl + L.f[1].expl))
53-
# new instance of dtype_u, initialize values with 0
54-
for j in range(2, self.coll.num_nodes + 1):
51+
me.append(P.dtype_u(P.init, val=0.0))
52+
for j in range(1, self.coll.num_nodes + 1):
5553
me[m - 1] += L.dt * self.coll.Qmat[m, j] * (L.f[j].impl + L.f[j].expl)
5654

5755
return me

pySDC/implementations/transfer_classes/TransferMesh.py

Lines changed: 46 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,14 @@
44
import pySDC.helpers.transfer_helper as th
55
from pySDC.core.Errors import TransferError
66
from pySDC.core.SpaceTransfer import space_transfer
7-
from pySDC.implementations.datatype_classes.mesh import mesh, imex_mesh, comp2_mesh
87

98

109
class mesh_to_mesh(space_transfer):
1110
"""
12-
Custon base_transfer class, implements Transfer.py
11+
Custom base_transfer class, implements Transfer.py
1312
1413
This implementation can restrict and prolong between nd meshes with dirichlet-0 or periodic boundaries
15-
via matrix-vector products
14+
via matrix-vector products.
1615
1716
Attributes:
1817
Rspace: spatial restriction matrix, dim. Nf x Nc
@@ -30,7 +29,7 @@ def __init__(self, fine_prob, coarse_prob, params):
3029
"""
3130

3231
# invoke super initialization
33-
super(mesh_to_mesh, self).__init__(fine_prob, coarse_prob, params)
32+
super().__init__(fine_prob, coarse_prob, params)
3433

3534
if self.params.rorder % 2 != 0:
3635
raise TransferError('Need even order for restriction')
@@ -153,51 +152,31 @@ def restrict(self, F):
153152
Args:
154153
F: the fine level data (easier to access than via the fine attribute)
155154
"""
156-
if isinstance(F, mesh):
157-
G = self.coarse_prob.dtype_u(self.coarse_prob.init)
158-
if hasattr(self.fine_prob, 'ncomp'):
159-
for i in range(self.fine_prob.ncomp):
160-
tmpF = F[..., i].flatten()
161-
tmpG = self.Rspace.dot(tmpF)
162-
G[..., i] = tmpG.reshape(self.coarse_prob.nvars)
163-
else:
164-
tmpF = F.flatten()
165-
tmpG = self.Rspace.dot(tmpF)
166-
G[:] = tmpG.reshape(self.coarse_prob.nvars)
167-
elif isinstance(F, imex_mesh):
168-
G = self.coarse_prob.dtype_f(self.coarse_prob.init)
169-
if hasattr(self.fine_prob, 'ncomp'):
170-
for i in range(self.fine_prob.ncomp):
171-
tmpF = F.impl[..., i].flatten()
172-
tmpG = self.Rspace.dot(tmpF)
173-
G.impl[..., i] = tmpG.reshape(self.coarse_prob.nvars)
174-
tmpF = F.expl[..., i].flatten()
175-
tmpG = self.Rspace.dot(tmpF)
176-
G.expl[..., i] = tmpG.reshape(self.coarse_prob.nvars)
177-
else:
178-
tmpF = F.impl.flatten()
179-
tmpG = self.Rspace.dot(tmpF)
180-
G.impl[:] = tmpG.reshape(self.coarse_prob.nvars)
181-
tmpF = F.expl.flatten()
182-
tmpG = self.Rspace.dot(tmpF)
183-
G.expl[:] = tmpG.reshape(self.coarse_prob.nvars)
184-
elif isinstance(F, comp2_mesh):
185-
G = self.coarse_prob.dtype_f(self.coarse_prob.init)
155+
G = type(F)(self.coarse_prob.init)
156+
157+
def _restrict(fine, coarse):
186158
if hasattr(self.fine_prob, 'ncomp'):
187159
for i in range(self.fine_prob.ncomp):
188-
tmpF = F.comp1[..., i].flatten()
189-
tmpG = self.Rspace.dot(tmpF)
190-
G.comp1[..., i] = tmpG.reshape(self.coarse_prob.nvars)
191-
tmpF = F.comp2[..., i].flatten()
192-
tmpG = self.Rspace.dot(tmpF)
193-
G.comp2[..., i] = tmpG.reshape(self.coarse_prob.nvars)
160+
if fine.shape[-1] == self.fine_prob.ncomp:
161+
tmpF = fine[..., i].flatten()
162+
tmpG = self.Rspace.dot(tmpF)
163+
coarse[..., i] = tmpG.reshape(self.coarse_prob.nvars)
164+
elif fine.shape[0] == self.fine_prob.ncomp:
165+
tmpF = fine[i, ...].flatten()
166+
tmpG = self.Rspace.dot(tmpF)
167+
coarse[i, ...] = tmpG.reshape(self.coarse_prob.nvars)
168+
else:
169+
raise TransferError('Don\'t know how to restrict for this problem with multiple components')
194170
else:
195-
tmpF = F.comp1.flatten()
196-
tmpG = self.Rspace.dot(tmpF)
197-
G.comp1[:] = tmpG.reshape(self.coarse_prob.nvars)
198-
tmpF = F.comp2.flatten()
171+
tmpF = fine.flatten()
199172
tmpG = self.Rspace.dot(tmpF)
200-
G.comp2[:] = tmpG.reshape(self.coarse_prob.nvars)
173+
coarse[:] = tmpG.reshape(self.coarse_prob.nvars)
174+
175+
if hasattr(type(F), 'components'):
176+
for comp in F.components:
177+
_restrict(F.__getattr__(comp), G.__getattr__(comp))
178+
elif type(F).__name__ == 'mesh':
179+
_restrict(F, G)
201180
else:
202181
raise TransferError('Wrong data type for restriction, got %s' % type(F))
203182
return G
@@ -208,51 +187,32 @@ def prolong(self, G):
208187
Args:
209188
G: the coarse level data (easier to access than via the coarse attribute)
210189
"""
211-
if isinstance(G, mesh):
212-
F = self.fine_prob.dtype_u(self.fine_prob.init)
213-
if hasattr(self.fine_prob, 'ncomp'):
214-
for i in range(self.fine_prob.ncomp):
215-
tmpG = G[..., i].flatten()
216-
tmpF = self.Pspace.dot(tmpG)
217-
F[..., i] = tmpF.reshape(self.fine_prob.nvars)
218-
else:
219-
tmpG = G.flatten()
220-
tmpF = self.Pspace.dot(tmpG)
221-
F[:] = tmpF.reshape(self.fine_prob.nvars)
222-
elif isinstance(G, imex_mesh):
223-
F = self.fine_prob.dtype_f(self.fine_prob.init)
224-
if hasattr(self.fine_prob, 'ncomp'):
225-
for i in range(self.fine_prob.ncomp):
226-
tmpG = G.impl[..., i].flatten()
227-
tmpF = self.Pspace.dot(tmpG)
228-
F.impl[..., i] = tmpF.reshape(self.fine_prob.nvars)
229-
tmpG = G.expl[..., i].flatten()
230-
tmpF = self.Rspace.dot(tmpG)
231-
F.expl[..., i] = tmpF.reshape(self.fine_prob.nvars)
232-
else:
233-
tmpG = G.impl.flatten()
234-
tmpF = self.Pspace.dot(tmpG)
235-
F.impl[:] = tmpF.reshape(self.fine_prob.nvars)
236-
tmpG = G.expl.flatten()
237-
tmpF = self.Pspace.dot(tmpG)
238-
F.expl[:] = tmpF.reshape(self.fine_prob.nvars)
239-
elif isinstance(G, comp2_mesh):
240-
F = self.fine_prob.dtype_f(self.fine_prob.init)
190+
F = type(G)(self.fine_prob.init)
191+
192+
def _prolong(coarse, fine):
241193
if hasattr(self.fine_prob, 'ncomp'):
242194
for i in range(self.fine_prob.ncomp):
243-
tmpG = G.comp1[..., i].flatten()
244-
tmpF = self.Pspace.dot(tmpG)
245-
F.comp1[..., i] = tmpF.reshape(self.fine_prob.nvars)
246-
tmpG = G.comp2[..., i].flatten()
247-
tmpF = self.Rspace.dot(tmpG)
248-
F.comp2[..., i] = tmpF.reshape(self.fine_prob.nvars)
195+
if coarse.shape[-1] == self.fine_prob.ncomp:
196+
tmpG = coarse[..., i].flatten()
197+
tmpF = self.Pspace.dot(tmpG)
198+
fine[..., i] = tmpF.reshape(self.fine_prob.nvars)
199+
elif coarse.shape[0] == self.fine_prob.ncomp:
200+
tmpG = coarse[i, ...].flatten()
201+
tmpF = self.Pspace.dot(tmpG)
202+
fine[i, ...] = tmpF.reshape(self.fine_prob.nvars)
203+
else:
204+
raise TransferError('Don\'t know how to prolong for this problem with multiple components')
249205
else:
250-
tmpG = G.comp1.flatten()
251-
tmpF = self.Pspace.dot(tmpG)
252-
F.comp1[:] = tmpF.reshape(self.fine_prob.nvars)
253-
tmpG = G.comp2.flatten()
206+
tmpG = coarse.flatten()
254207
tmpF = self.Pspace.dot(tmpG)
255-
F.comp2[:] = tmpF.reshape(self.fine_prob.nvars)
208+
fine[:] = tmpF.reshape(self.fine_prob.nvars)
209+
return fine
210+
211+
if hasattr(type(F), 'components'):
212+
for comp in G.components:
213+
_prolong(G.__getattr__(comp), F.__getattr__(comp))
214+
elif type(G).__name__ == 'mesh':
215+
F[:] = _prolong(G, F)
256216
else:
257217
raise TransferError('Wrong data type for prolongation, got %s' % type(G))
258218
return F

pySDC/implementations/transfer_classes/TransferMesh_FFT.py

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from pySDC.core.Errors import TransferError
44
from pySDC.core.SpaceTransfer import space_transfer
5-
from pySDC.implementations.datatype_classes.mesh import mesh, imex_mesh
65

76

87
class mesh_to_mesh_fft(space_transfer):
@@ -26,9 +25,9 @@ def __init__(self, fine_prob, coarse_prob, params):
2625
params: parameters for the transfer operators
2726
"""
2827
# invoke super initialization
29-
super(mesh_to_mesh_fft, self).__init__(fine_prob, coarse_prob, params)
28+
super().__init__(fine_prob, coarse_prob, params)
3029

31-
self.ratio = int(self.fine_prob.params.nvars / self.coarse_prob.params.nvars)
30+
self.ratio = int(self.fine_prob.nvars / self.coarse_prob.nvars)
3231

3332
def restrict(self, F):
3433
"""
@@ -37,11 +36,11 @@ def restrict(self, F):
3736
Args:
3837
F: the fine level data (easier to access than via the fine attribute)
3938
"""
40-
if isinstance(F, mesh):
41-
G = mesh(self.coarse_prob.init, val=0.0)
39+
G = type(F)(self.coarse_prob.init, val=0.0)
40+
41+
if type(F).__name__ == 'mesh':
4242
G[:] = F[:: self.ratio]
43-
elif isinstance(F, imex_mesh):
44-
G = imex_mesh(self.coarse_prob.init, val=0.0)
43+
elif type(F).__name__ == 'imex_mesh':
4544
G.impl[:] = F.impl[:: self.ratio]
4645
G.expl[:] = F.expl[:: self.ratio]
4746
else:
@@ -55,28 +54,21 @@ def prolong(self, G):
5554
Args:
5655
G: the coarse level data (easier to access than via the coarse attribute)
5756
"""
58-
if isinstance(G, mesh):
59-
F = mesh(self.fine_prob.init, val=0.0)
60-
tmpG = np.fft.rfft(G)
61-
tmpF = np.zeros(self.fine_prob.init[0] // 2 + 1, dtype=np.complex128)
62-
halfG = int(self.coarse_prob.init[0] / 2)
63-
tmpF[0:halfG] = tmpG[0:halfG]
64-
tmpF[-1] = tmpG[-1]
65-
F[:] = np.fft.irfft(tmpF) * self.ratio
66-
elif isinstance(G, imex_mesh):
67-
F = imex_mesh(G)
68-
tmpG_impl = np.fft.rfft(G.impl)
69-
tmpF_impl = np.zeros(self.fine_prob.init[0] // 2 + 1, dtype=np.complex128)
70-
halfG = int(self.coarse_prob.init[0] / 2)
71-
tmpF_impl[0:halfG] = tmpG_impl[0:halfG]
72-
tmpF_impl[-1] = tmpG_impl[-1]
73-
F.impl[:] = np.fft.irfft(tmpF_impl) * self.ratio
74-
tmpG_expl = np.fft.rfft(G.expl)
75-
tmpF_expl = np.zeros(self.fine_prob.init[0] // 2 + 1, dtype=np.complex128)
76-
halfG = int(self.coarse_prob.init[0] / 2)
77-
tmpF_expl[0:halfG] = tmpG_expl[0:halfG]
78-
tmpF_expl[-1] = tmpG_expl[-1]
79-
F.expl[:] = np.fft.irfft(tmpF_expl) * self.ratio
57+
F = type(G)(self.fine_prob.init, val=0.0)
58+
59+
def _prolong(coarse):
60+
coarse_hat = np.fft.rfft(coarse)
61+
fine_hat = np.zeros(self.fine_prob.init[0] // 2 + 1, dtype=np.complex128)
62+
half_idx = self.coarse_prob.init[0] // 2
63+
fine_hat[0:half_idx] = coarse_hat[0:half_idx]
64+
fine_hat[-1] = coarse_hat[-1]
65+
return np.fft.irfft(fine_hat) * self.ratio
66+
67+
if type(G).__name__ == 'mesh':
68+
F[:] = _prolong(G)
69+
elif type(G).__name__ == 'imex_mesh':
70+
F.impl[:] = _prolong(G.impl)
71+
F.expl[:] = _prolong(G.expl)
8072
else:
8173
raise TransferError('Unknown data type, got %s' % type(G))
8274
return F

0 commit comments

Comments
 (0)