Skip to content

Commit 82864f7

Browse files
committed
Refactor and fix
1 parent 10c4ba1 commit 82864f7

File tree

2 files changed

+44
-66
lines changed

2 files changed

+44
-66
lines changed

pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,8 @@ def estimate_embedded_error_serial(self, L):
9393
dtype_u: The embedded error estimate
9494
"""
9595
if self.params.sweeper_type == "RK":
96-
if L.f[1] is None:
97-
return -1
98-
else:
99-
L.sweep.compute_end_point()
100-
return abs(L.uend - L.sweep.u_secondary)
96+
L.sweep.compute_end_point()
97+
return abs(L.uend - L.sweep.u_secondary)
10198
elif self.params.sweeper_type == "SDC":
10299
# order rises by one between sweeps
103100
return abs(L.uold[-1] - L.u[-1])

pySDC/implementations/sweeper_classes/Runge_Kutta.py

Lines changed: 42 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -17,27 +17,7 @@ def __init__(self, weights, nodes, matrix):
1717
nodes (numpy.ndarray): Butcher tableau nodes
1818
matrix (numpy.ndarray): Butcher tableau entries
1919
"""
20-
# check if the arguments have the correct form
21-
if type(matrix) != np.ndarray:
22-
raise ParameterError('Runge-Kutta matrix needs to be supplied as a numpy array!')
23-
elif len(np.unique(matrix.shape)) != 1 or len(matrix.shape) != 2:
24-
raise ParameterError('Runge-Kutta matrix needs to be a square 2D numpy array!')
25-
26-
if type(weights) != np.ndarray:
27-
raise ParameterError('Weights need to be supplied as a numpy array!')
28-
elif len(weights.shape) != 1:
29-
raise ParameterError(f'Incompatible dimension of weights! Need 1, got {len(weights.shape)}')
30-
elif len(weights) != matrix.shape[0]:
31-
raise ParameterError(f'Incompatible number of weights! Need {matrix.shape[0]}, got {len(weights)}')
32-
33-
if type(nodes) != np.ndarray:
34-
raise ParameterError('Nodes need to be supplied as a numpy array!')
35-
elif len(nodes.shape) != 1:
36-
raise ParameterError(f'Incompatible dimension of nodes! Need 1, got {len(nodes.shape)}')
37-
elif len(nodes) != matrix.shape[0]:
38-
raise ParameterError(f'Incompatible number of nodes! Need {matrix.shape[0]}, got {len(nodes)}')
39-
40-
self.globally_stiffly_accurate = np.allclose(matrix[-1], weights)
20+
self.check_method(weights, nodes, matrix)
4121

4222
self.tleft = 0.0
4323
self.tright = 1.0
@@ -61,63 +41,56 @@ def __init__(self, weights, nodes, matrix):
6141
# check if the RK scheme is implicit
6242
self.implicit = any(matrix[i, i] != 0 for i in range(self.num_nodes))
6343

64-
65-
class ButcherTableauEmbedded(object):
66-
def __init__(self, weights, nodes, matrix):
44+
def check_method(self, weights, nodes, matrix):
6745
"""
68-
Initialization routine to get a quadrature matrix out of a Butcher tableau for embedded RK methods.
69-
70-
Be aware that the method that generates the final solution should be in the first row of the weights matrix.
71-
72-
Args:
73-
weights (numpy.ndarray): Butcher tableau weights
74-
nodes (numpy.ndarray): Butcher tableau nodes
75-
matrix (numpy.ndarray): Butcher tableau entries
46+
Check that the method is entered in the correct format
7647
"""
77-
# check if the arguments have the correct form
7848
if type(matrix) != np.ndarray:
7949
raise ParameterError('Runge-Kutta matrix needs to be supplied as a numpy array!')
8050
elif len(np.unique(matrix.shape)) != 1 or len(matrix.shape) != 2:
8151
raise ParameterError('Runge-Kutta matrix needs to be a square 2D numpy array!')
8252

83-
if type(weights) != np.ndarray:
84-
raise ParameterError('Weights need to be supplied as a numpy array!')
85-
elif len(weights.shape) != 2:
86-
raise ParameterError(f'Incompatible dimension of weights! Need 2, got {len(weights.shape)}')
87-
elif len(weights[0]) != matrix.shape[0]:
88-
raise ParameterError(f'Incompatible number of weights! Need {matrix.shape[0]}, got {len(weights[0])}')
89-
9053
if type(nodes) != np.ndarray:
9154
raise ParameterError('Nodes need to be supplied as a numpy array!')
9255
elif len(nodes.shape) != 1:
9356
raise ParameterError(f'Incompatible dimension of nodes! Need 1, got {len(nodes.shape)}')
9457
elif len(nodes) != matrix.shape[0]:
9558
raise ParameterError(f'Incompatible number of nodes! Need {matrix.shape[0]}, got {len(nodes)}')
9659

97-
# Set number of nodes, left and right interval boundaries
98-
self.num_nodes = matrix.shape[0]
99-
self.tleft = 0.0
100-
self.tright = 1.0
60+
self.check_weights(weights, nodes, matrix)
10161

102-
self.nodes = np.append(np.append([0], nodes), [1, 1])
103-
self.weights = weights
104-
self.Qmat = np.zeros([self.num_nodes + 1, self.num_nodes + 1])
105-
self.Qmat[1:, 1:] = matrix
62+
def check_weights(self, weights, nodes, matrix):
63+
"""
64+
Check that the weights of the method are entered in the correct format
65+
"""
66+
if type(weights) != np.ndarray:
67+
raise ParameterError('Weights need to be supplied as a numpy array!')
68+
elif len(weights.shape) != 1:
69+
raise ParameterError(f'Incompatible dimension of weights! Need 1, got {len(weights.shape)}')
70+
elif len(weights) != matrix.shape[0]:
71+
raise ParameterError(f'Incompatible number of weights! Need {matrix.shape[0]}, got {len(weights)}')
10672

107-
self.left_is_node = True
108-
self.right_is_node = self.nodes[-1] == self.tright
73+
@property
74+
def globally_stiffly_accurate(self):
75+
return np.allclose(self.Qmat[-1, 1:], self.weights)
10976

110-
self.globally_stiffly_accurate = np.allclose(matrix[-1], weights[0])
11177

112-
# compute distances between the nodes
113-
if self.num_nodes > 1:
114-
self.delta_m = self.nodes[1:] - self.nodes[:-1]
115-
else:
116-
self.delta_m = np.zeros(1)
117-
self.delta_m[0] = self.nodes[0] - self.tleft
78+
class ButcherTableauEmbedded(ButcherTableau):
11879

119-
# check if the RK scheme is implicit
120-
self.implicit = any(matrix[i, i] != 0 for i in range(self.num_nodes))
80+
def check_weights(self, weights, nodes, matrix):
81+
"""
82+
Check that the weights of the method are entered in the correct format
83+
"""
84+
if type(weights) != np.ndarray:
85+
raise ParameterError('Weights need to be supplied as a numpy array!')
86+
elif len(weights.shape) != 2:
87+
raise ParameterError(f'Incompatible dimension of weights! Need 2, got {len(weights.shape)}')
88+
elif len(weights[0]) != matrix.shape[0]:
89+
raise ParameterError(f'Incompatible number of weights! Need {matrix.shape[0]}, got {len(weights[0])}')
90+
91+
@property
92+
def globally_stiffly_accurate(self):
93+
return np.allclose(self.Qmat[-1, 1:], self.weights[0])
12194

12295

12396
class RungeKutta(Sweeper):
@@ -294,7 +267,11 @@ def compute_end_point(self):
294267
"""
295268
In this Runge-Kutta implementation, the solution to the step is always stored in the last node
296269
"""
297-
if self.coll.globally_stiffly_accurate:
270+
if self.level.f[1] is None:
271+
self.level.uend = self.level.u[0]
272+
if type(self.coll) == ButcherTableauEmbedded:
273+
self.u_secondary = self.level.u[0].copy()
274+
elif self.coll.globally_stiffly_accurate:
298275
self.level.uend = self.level.u[-1]
299276
if type(self.coll) == ButcherTableauEmbedded:
300277
self.u_secondary = self.level.u[0].copy()
@@ -467,7 +444,11 @@ def compute_end_point(self):
467444
"""
468445
In this Runge-Kutta implementation, the solution to the step is always stored in the last node
469446
"""
470-
if self.coll.globally_stiffly_accurate and self.coll_explicit.globally_stiffly_accurate:
447+
if self.level.f[1] is None:
448+
self.level.uend = self.level.u[0]
449+
if type(self.coll) == ButcherTableauEmbedded:
450+
self.u_secondary = self.level.u[0].copy()
451+
elif self.coll.globally_stiffly_accurate and self.coll_explicit.globally_stiffly_accurate:
471452
self.level.uend = self.level.u[-1]
472453
if type(self.coll) == ButcherTableauEmbedded:
473454
self.u_secondary = self.level.u[0].copy()

0 commit comments

Comments
 (0)