Skip to content

Commit 5b92e94

Browse files
Implemented relative error estimates for adaptive step size selection (#515)
1 parent 5ac6441 commit 5b92e94

File tree

4 files changed

+200
-21
lines changed

4 files changed

+200
-21
lines changed

pySDC/implementations/convergence_controller_classes/adaptivity.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ def setup(self, controller, params, description, **kwargs):
307307
"""
308308
defaults = {
309309
"embedded_error_flavor": 'standard',
310+
"rel_error": False,
310311
}
311312
return {**defaults, **super().setup(controller, params, description, **kwargs)}
312313

@@ -328,6 +329,9 @@ def dependencies(self, controller, description, **kwargs):
328329
controller.add_convergence_controller(
329330
EstimateEmbeddedError.get_implementation(self.params.embedded_error_flavor, self.params.useMPI),
330331
description=description,
332+
params={
333+
'rel_error': self.params.rel_error,
334+
},
331335
)
332336

333337
# load contraction factor estimator if necessary
@@ -837,6 +841,8 @@ def setup(self, controller, params, description, **kwargs):
837841

838842
defaults = {
839843
'control_order': -50,
844+
'problem_mesh_type': 'numpyesque',
845+
'rel_error': False,
840846
**super().setup(controller, params, description, **kwargs),
841847
**params,
842848
}
@@ -858,16 +864,27 @@ def dependencies(self, controller, description, **kwargs):
858864
Returns:
859865
None
860866
"""
861-
from pySDC.implementations.convergence_controller_classes.estimate_polynomial_error import (
862-
EstimatePolynomialError,
863-
)
867+
if self.params.problem_mesh_type.lower() == 'numpyesque':
868+
from pySDC.implementations.convergence_controller_classes.estimate_polynomial_error import (
869+
EstimatePolynomialError as error_estimation_cls,
870+
)
871+
elif self.params.problem_mesh_type.lower() == 'firedrake':
872+
from pySDC.implementations.convergence_controller_classes.estimate_polynomial_error import (
873+
EstimatePolynomialErrorFiredrake as error_estimation_cls,
874+
)
875+
else:
876+
raise NotImplementedError(
877+
f'Don\'t know what error estimation class to use for problems with mesh type {self.params.problem_mesh_type}'
878+
)
864879

865880
super().dependencies(controller, description, **kwargs)
866881

867882
controller.add_convergence_controller(
868-
EstimatePolynomialError,
883+
error_estimation_cls,
869884
description=description,
870-
params={},
885+
params={
886+
'rel_error': self.params.rel_error,
887+
},
871888
)
872889
return None
873890

pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def setup(self, controller, params, description, **kwargs):
5757
return {
5858
"control_order": -80,
5959
"sweeper_type": sweeper_type,
60+
"rel_error": False,
6061
**super().setup(controller, params, description, **kwargs),
6162
}
6263

@@ -94,13 +95,24 @@ def estimate_embedded_error_serial(self, L):
9495
"""
9596
if self.params.sweeper_type == "RK":
9697
L.sweep.compute_end_point()
97-
return abs(L.uend - L.sweep.u_secondary)
98+
if self.params.rel_error:
99+
return abs(L.uend - L.sweep.u_secondary) / abs(L.uend)
100+
else:
101+
return abs(L.uend - L.sweep.u_secondary)
98102
elif self.params.sweeper_type == "SDC":
99103
# order rises by one between sweeps
100-
return abs(L.uold[-1] - L.u[-1])
104+
if self.params.rel_error:
105+
return abs(L.uold[-1] - L.u[-1]) / abs(L.u[-1])
106+
else:
107+
return abs(L.uold[-1] - L.u[-1])
101108
elif self.params.sweeper_type == 'MPI':
102109
comm = L.sweep.comm
103-
return comm.bcast(abs(L.uold[comm.rank + 1] - L.u[comm.rank + 1]), root=comm.size - 1)
110+
if self.params.rel_error:
111+
return comm.bcast(
112+
abs(L.uold[comm.rank + 1] - L.u[comm.rank + 1]) / abs(L.u[comm.rank + 1]), root=comm.size - 1
113+
)
114+
else:
115+
return comm.bcast(abs(L.uold[comm.rank + 1] - L.u[comm.rank + 1]), root=comm.size - 1)
104116
else:
105117
raise NotImplementedError(
106118
f"Don't know how to estimate embedded error for sweeper type \

pySDC/implementations/convergence_controller_classes/estimate_polynomial_error.py

Lines changed: 84 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def setup(self, controller, params, description, **kwargs):
3737
defaults = {
3838
'control_order': -75,
3939
'estimate_on_node': num_nodes + 1 if quad_type == 'GAUSS' else num_nodes - 1,
40+
'rel_error': False,
4041
**super().setup(controller, params, description, **kwargs),
4142
}
4243
self.comm = description['sweeper_params'].get('comm', None)
@@ -103,6 +104,23 @@ def matmul(self, A, b, xp=np):
103104
else:
104105
return A @ xp.asarray(b)
105106

107+
def get_interpolated_solution(self, L, xp):
108+
"""
109+
Get the interpolated solution for numpy or cupy data types
110+
111+
Args:
112+
u_vec (array): Vector of solutions
113+
prob (pySDC.problem): Problem
114+
"""
115+
coll = L.sweep.coll
116+
117+
u = [
118+
L.u[i].flatten() if L.u[i] is not None else L.u[i]
119+
for i in range(coll.num_nodes + 1)
120+
if i != self.params.estimate_on_node
121+
]
122+
return self.matmul(self.interpolation_matrix, u, xp=xp)[0].reshape(L.prob.init[0])
123+
106124
def post_iteration_processing(self, controller, S, **kwargs):
107125
"""
108126
Estimate the error
@@ -120,20 +138,19 @@ def post_iteration_processing(self, controller, S, **kwargs):
120138
coll = L.sweep.coll
121139
nodes = np.append(np.append(0, coll.nodes), 1.0)
122140
estimate_on_node = self.params.estimate_on_node
123-
xp = L.u[0].xp
141+
142+
if hasattr(L.u[0], 'xp'):
143+
xp = L.u[0].xp
144+
else:
145+
xp = np
124146

125147
if self.interpolation_matrix is None:
126148
interpolator = LagrangeApproximation(
127149
points=[nodes[i] for i in range(coll.num_nodes + 1) if i != estimate_on_node]
128150
)
129151
self.interpolation_matrix = xp.array(interpolator.getInterpolationMatrix([nodes[estimate_on_node]]))
130152

131-
u = [
132-
L.u[i].flatten() if L.u[i] is not None else L.u[i]
133-
for i in range(coll.num_nodes + 1)
134-
if i != estimate_on_node
135-
]
136-
u_inter = self.matmul(self.interpolation_matrix, u, xp=xp)[0].reshape(L.prob.init[0])
153+
u_inter = self.get_interpolated_solution(L, xp)
137154

138155
# compute end point if needed
139156
if estimate_on_node == len(nodes) - 1:
@@ -147,12 +164,14 @@ def post_iteration_processing(self, controller, S, **kwargs):
147164
rank = estimate_on_node - 1
148165
L.status.order_embedded_estimate = coll.num_nodes * 1
149166

167+
rescale = float(abs(u_inter)) if self.params.rel_error else 1
168+
150169
if self.comm:
151-
buf = np.array(abs(u_inter - high_order_sol) if self.comm.rank == rank else 0.0)
170+
buf = np.array(abs(u_inter - high_order_sol) / rescale if self.comm.rank == rank else 0.0)
152171
self.comm.Bcast(buf, root=rank)
153172
L.status.error_embedded_estimate = float(buf)
154173
else:
155-
L.status.error_embedded_estimate = abs(u_inter - high_order_sol)
174+
L.status.error_embedded_estimate = abs(u_inter - high_order_sol) / rescale
156175

157176
self.debug(
158177
f'Obtained error estimate: {L.status.error_embedded_estimate:.2e} of order {L.status.order_embedded_estimate}',
@@ -176,3 +195,59 @@ def check_parameters(self, controller, params, description, **kwargs):
176195
return False, 'Need at least two collocation nodes to interpolate to one!'
177196

178197
return True, ""
198+
199+
200+
class EstimatePolynomialErrorFiredrake(EstimatePolynomialError):
201+
def matmul(self, A, b):
202+
"""
203+
Matrix vector multiplication, possibly MPI parallel.
204+
The parallel implementation performs a reduce operation in every row of the matrix. While communicating the
205+
entire vector once could reduce the number of communications, this way we never need to store the entire vector
206+
on any specific rank.
207+
208+
Args:
209+
A (2d np.ndarray): Matrix
210+
b (list): Vector
211+
212+
Returns:
213+
List: Axb
214+
"""
215+
216+
if self.comm:
217+
res = [A[i, 0] * b[0] if b[i] is not None else None for i in range(A.shape[0])]
218+
buf = 0 * b[0]
219+
for i in range(0, A.shape[0]):
220+
index = self.comm.rank + (1 if self.comm.rank < self.params.estimate_on_node - 1 else 0)
221+
send_buf = (
222+
(A[i, index] * b[index]) if self.comm.rank != self.params.estimate_on_node - 1 else 0 * res[0]
223+
)
224+
self.comm.Allreduce(send_buf, buf, op=self.MPI_SUM)
225+
res[i] += buf
226+
return res
227+
else:
228+
res = []
229+
for i in range(A.shape[0]):
230+
res.append(A[i, 0] * b[0])
231+
for j in range(1, A.shape[1]):
232+
res[-1] += A[i, j] * b[j]
233+
234+
return res
235+
236+
def get_interpolated_solution(self, L):
237+
"""
238+
Get the interpolated solution for Firedrake data types
239+
We are not 100% sure that you don't need to invert the mass matrix here, but should be fine.
240+
241+
Args:
242+
u_vec (array): Vector of solutions
243+
prob (pySDC.problem): Problem
244+
"""
245+
coll = L.sweep.coll
246+
247+
u = [
248+
L.u[i] if L.u[i] is not None else L.u[i]
249+
for i in range(coll.num_nodes + 1)
250+
if i != self.params.estimate_on_node
251+
]
252+
return L.prob.dtype_u(self.matmul(self.interpolation_matrix, u)[0])
253+
# return L.prob.invert_mass_matrix(self.matmul(self.interpolation_matrix, u)[0])

pySDC/tests/test_convergence_controllers/test_polynomial_error.py

Lines changed: 79 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22

33

4-
def get_controller(dt, num_nodes, quad_type, useMPI, useGPU):
4+
def get_controller(dt, num_nodes, quad_type, useMPI, useGPU, rel_error):
55
"""
66
Get a controller prepared for polynomial test equation
77
@@ -64,7 +64,7 @@ def get_controller(dt, num_nodes, quad_type, useMPI, useGPU):
6464
description['sweeper_params'] = sweeper_params
6565
description['level_params'] = level_params
6666
description['step_params'] = step_params
67-
description['convergence_controllers'] = {EstimatePolynomialError: {}}
67+
description['convergence_controllers'] = {EstimatePolynomialError: {'rel_error': rel_error}}
6868

6969
controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description)
7070
return controller
@@ -177,13 +177,15 @@ def check_order(dts, **kwargs):
177177
@pytest.mark.base
178178
@pytest.mark.parametrize('num_nodes', [2, 3, 4, 5])
179179
@pytest.mark.parametrize('quad_type', ['RADAU-RIGHT', 'GAUSS'])
180-
def test_interpolation_error(num_nodes, quad_type):
180+
@pytest.mark.parametrize('rel_error', [True, False])
181+
def test_interpolation_error(num_nodes, quad_type, rel_error):
181182
import numpy as np
182183

183184
kwargs = {
184185
'num_nodes': num_nodes,
185186
'quad_type': quad_type,
186187
'useMPI': False,
188+
'rel_error': rel_error,
187189
}
188190
steps = np.logspace(-1, -4, 20)
189191
check_order(steps, **kwargs)
@@ -200,6 +202,7 @@ def test_interpolation_error_GPU(num_nodes, quad_type):
200202
'quad_type': quad_type,
201203
'useMPI': False,
202204
'useGPU': True,
205+
'rel_error': False,
203206
}
204207
steps = np.logspace(-1, -4, 20)
205208
check_order(steps, **kwargs)
@@ -228,6 +231,77 @@ def test_interpolation_error_MPI(num_nodes, quad_type):
228231
)
229232

230233

234+
@pytest.mark.firedrake
235+
def test_polynomial_error_firedrake(dt=1.0, num_nodes=3, useMPI=False):
236+
from pySDC.implementations.problem_classes.HeatFiredrake import Heat1DForcedFiredrake
237+
from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
238+
from pySDC.implementations.convergence_controller_classes.estimate_polynomial_error import (
239+
EstimatePolynomialErrorFiredrake,
240+
LagrangeApproximation,
241+
)
242+
import numpy as np
243+
244+
if useMPI:
245+
from pySDC.implementations.sweeper_classes.generic_implicit_MPI import generic_implicit_MPI as sweeper_class
246+
from mpi4py import MPI
247+
248+
comm = MPI.COMM_WORLD
249+
else:
250+
from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit as sweeper_class
251+
252+
comm = None
253+
254+
level_params = {}
255+
level_params['dt'] = dt
256+
level_params['restol'] = 1.0
257+
258+
sweeper_params = {}
259+
sweeper_params['quad_type'] = 'RADAU-RIGHT'
260+
sweeper_params['num_nodes'] = num_nodes
261+
sweeper_params['comm'] = comm
262+
263+
problem_params = {'n': 1}
264+
265+
step_params = {}
266+
step_params['maxiter'] = 0
267+
268+
controller_params = {}
269+
controller_params['logger_level'] = 30
270+
controller_params['mssdc_jac'] = False
271+
272+
description = {}
273+
description['problem_class'] = Heat1DForcedFiredrake
274+
description['problem_params'] = problem_params
275+
description['sweeper_class'] = sweeper_class
276+
description['sweeper_params'] = sweeper_params
277+
description['level_params'] = level_params
278+
description['step_params'] = step_params
279+
description['convergence_controllers'] = {EstimatePolynomialErrorFiredrake: {}}
280+
281+
controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description)
282+
283+
L = controller.MS[0].levels[0]
284+
285+
cont = controller.convergence_controllers[
286+
np.arange(len(controller.convergence_controllers))[
287+
[type(me).__name__ == 'EstimatePolynomialErrorFiredrake' for me in controller.convergence_controllers]
288+
][0]
289+
]
290+
291+
nodes = np.append(np.append(0, L.sweep.coll.nodes), 1.0)
292+
estimate_on_node = cont.params.estimate_on_node
293+
interpolator = LagrangeApproximation(points=[nodes[i] for i in range(num_nodes + 1) if i != estimate_on_node])
294+
cont.interpolation_matrix = np.array(interpolator.getInterpolationMatrix([nodes[estimate_on_node]]))
295+
296+
for i in range(num_nodes + 1):
297+
L.u[i] = L.prob.u_init
298+
L.u[i].functionspace.assign(nodes[i])
299+
300+
u_inter = cont.get_interpolated_solution(L)
301+
error = abs(u_inter - L.u[estimate_on_node])
302+
assert np.isclose(error, 0)
303+
304+
231305
if __name__ == "__main__":
232306
import sys
233307
import numpy as np
@@ -238,7 +312,8 @@ def test_interpolation_error_MPI(num_nodes, quad_type):
238312
kwargs = {
239313
'num_nodes': int(sys.argv[1]),
240314
'quad_type': sys.argv[2],
315+
'rel_error': False,
241316
}
242317
check_order(steps, useMPI=True, **kwargs)
243318
else:
244-
check_order(steps, useMPI=False, num_nodes=3, quad_type='RADAU-RIGHT')
319+
check_order(steps, useMPI=False, num_nodes=3, quad_type='RADAU-RIGHT', rel_error=False)

0 commit comments

Comments
 (0)