Skip to content

Commit 9f00efc

Browse files
author
Daniel Ruprecht
committed
beautification of imex_sweeper test
1 parent 5741997 commit 9f00efc

File tree

1 file changed

+95
-93
lines changed

1 file changed

+95
-93
lines changed

tests/test_imexsweeper.py

Lines changed: 95 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,38 @@
1414
class TestImexSweeper(unittest.TestCase):
1515

1616
#
17+
# Some auxiliary functions which are not tests themselves
1718
#
19+
def setupLevelStepProblem(self):
20+
step = stepclass.step(params={})
21+
L = lvl.level(problem_class=swfw_scalar, problem_params=self.pparams, dtype_u=mesh, dtype_f=rhs_imex_mesh, sweeper_class=imex, sweeper_params=self.swparams, level_params={}, hook_class=hookclass.hooks, id="imextest")
22+
step.register_level(L)
23+
step.status.dt = 1.0
24+
step.status.time = 0.0
25+
u0 = step.levels[0].prob.u_exact(step.status.time)
26+
step.init_step(u0)
27+
nnodes = step.levels[0].sweep.coll.num_nodes
28+
level = step.levels[0]
29+
problem = level.prob
30+
return step, level, problem, nnodes
31+
32+
def setupQMatrices(self, level):
33+
QE = level.sweep.QE[1:,1:]
34+
QI = level.sweep.QI[1:,1:]
35+
Q = level.sweep.coll.Qmat[1:,1:]
36+
return QE, QI, Q
37+
38+
def setupSweeperMatrices(self, step, level, problem):
39+
nnodes = step.levels[0].sweep.coll.num_nodes
40+
# Build SDC sweep matrix
41+
QE, QI, Q = self.setupQMatrices(level)
42+
dt = step.status.dt
43+
LHS = np.eye(nnodes) - step.status.dt*( problem.lambda_f[0]*QI + problem.lambda_s[0]*QE )
44+
RHS = step.status.dt*( (problem.lambda_f[0]+problem.lambda_s[0])*Q - (problem.lambda_f[0]*QI + problem.lambda_s[0]*QE) )
45+
return LHS, RHS
46+
47+
#
48+
# General setUp function used by all tests
1849
#
1950
def setUp(self):
2051
self.pparams = {}
@@ -25,160 +56,129 @@ def setUp(self):
2556
self.swparams['collocation_class'] = collclass.CollGaussLobatto
2657
self.swparams['num_nodes'] = 2
2758

59+
# ***************
60+
# **** TESTS ****
61+
# ***************
62+
63+
2864
#
29-
#
65+
# Check that a level object can be instantiated
3066
#
3167
def test_caninstantiate(self):
3268
L = lvl.level(problem_class=swfw_scalar, problem_params=self.pparams, dtype_u=mesh, dtype_f=rhs_imex_mesh, sweeper_class=imex, sweeper_params=self.swparams, level_params={}, hook_class=hookclass.hooks, id="imextest")
69+
assert isinstance(L.sweep, imex), "sweeper in generated level is not an object of type imex"
3370

3471
#
35-
#
72+
# Check that a level object can be registered in a step object (needed as prerequiste to execute update_nodes
3673
#
3774
def test_canregisterlevel(self):
3875
step = stepclass.step(params={})
3976
L = lvl.level(problem_class=swfw_scalar, problem_params=self.pparams, dtype_u=mesh, dtype_f=rhs_imex_mesh, sweeper_class=imex, sweeper_params=self.swparams, level_params={}, hook_class=hookclass.hooks, id="imextest")
4077
step.register_level(L)
78+
# At this point, it should not be possible to actually execute functions of the sweeper because the parameters set in setupLevelStepProblem are not yet initialised
79+
with self.assertRaises(Exception):
80+
step.sweep.predict()
81+
with self.assertRaises(Exception):
82+
step.sweep.update_nodes()
83+
with self.assertRaises(Exception):
84+
step.sweep.compute_end_point()
4185

4286
#
43-
#
87+
# Check that the sweeper functions update_nodes and compute_end_point can be executed
4488
#
4589
def test_canrunsweep(self):
4690

47-
step = stepclass.step(params={})
48-
L = lvl.level(problem_class=swfw_scalar, problem_params=self.pparams, dtype_u=mesh, dtype_f=rhs_imex_mesh, sweeper_class=imex, sweeper_params=self.swparams, level_params={}, hook_class=hookclass.hooks, id="imextest")
49-
step.register_level(L)
50-
step.status.dt = 1.0
51-
step.status.time = 0.0
52-
nnodes = step.levels[0].sweep.coll.num_nodes
53-
u0 = step.levels[0].prob.u_exact(step.status.time)
54-
step.init_step(u0)
55-
step.levels[0].sweep.predict()
56-
u0full = np.array([ step.levels[0].u[l].values.flatten() for l in range(1,nnodes+1) ])
57-
58-
step.levels[0].sweep.update_nodes()
59-
assert step.levels[0].uend is None, "uend should be None previous to running compute_end_point"
60-
step.levels[0].sweep.compute_end_point()
61-
#print "Sweep: %s" % step.levels[0].uend.values
91+
# After running setupLevelStepProblem, the functions predict, update_nodes and compute_end_point should run
92+
step, level, problem, nnodes = self.setupLevelStepProblem()
93+
assert level.u[0] is not None, "After init_step, level.u[0] should no longer be of type None"
94+
assert level.u[1] is None, "Before predict, level.u[1] and following should be of type None"
95+
level.sweep.predict()
96+
# Should now be able to run update nodes
97+
level.sweep.update_nodes()
98+
assert level.uend is None, "uend should be None previous to running compute_end_point"
99+
level.sweep.compute_end_point()
100+
assert level.uend is not None, "uend still None after running compute_end_point"
62101

63102
#
64103
# Make sure a sweep in matrix form is equal to a sweep in node-to-node form
65104
#
66105
def test_sweepequalmatrix(self):
67106

68-
step = stepclass.step(params={})
69-
L = lvl.level(problem_class=swfw_scalar, problem_params=self.pparams, dtype_u=mesh, dtype_f=rhs_imex_mesh, sweeper_class=imex, sweeper_params=self.swparams, level_params={}, hook_class=hookclass.hooks, id="imextest")
70-
step.register_level(L)
71-
step.status.dt = 1.0
72-
step.status.time = 0.0
73-
nnodes = step.levels[0].sweep.coll.num_nodes
74-
u0 = step.levels[0].prob.u_exact(step.status.time)
75-
step.init_step(u0)
107+
step, level, problem, nnodes = self.setupLevelStepProblem()
76108
step.levels[0].sweep.predict()
77-
u0full = np.array([ step.levels[0].u[l].values.flatten() for l in range(1,nnodes+1) ])
109+
u0full = np.array([ level.u[l].values.flatten() for l in range(1,nnodes+1) ])
78110

79111
# Perform node-to-node SDC sweep
80-
step.levels[0].sweep.update_nodes()
112+
level.sweep.update_nodes()
113+
114+
LHS, RHS = self.setupSweeperMatrices(step, level, problem)
81115

82-
# Build SDC sweep matrix
83-
level = step.levels[0]
84-
problem = level.prob
85-
QE = level.sweep.QE[1:,1:]
86-
QI = level.sweep.QI[1:,1:]
87-
Q = level.sweep.coll.Qmat[1:,1:]
88-
dt = step.status.dt
89-
LHS = np.eye(nnodes) - step.status.dt*( problem.lambda_f[0]*QI + problem.lambda_s[0]*QE )
90-
RHS = step.status.dt*( (problem.lambda_f[0]+problem.lambda_s[0])*Q - (problem.lambda_f[0]*QI + problem.lambda_s[0]*QE) )
91116
unew = np.linalg.inv(LHS).dot( u0full + RHS.dot(u0full) )
92117
usweep = np.array([ level.u[l].values.flatten() for l in range(1,nnodes+1) ])
93118
assert np.linalg.norm(unew - usweep, np.infty)<1e-14, "Single SDC sweeps in matrix and node-to-node formulation yield different results"
94119

95120
#
96-
#
121+
# Make sure the implemented update formula matches the matrix update formula
97122
#
98123
@unittest.skip("Needs fix of isse #52 before passing")
99124
def test_updateformula(self):
100125

101-
step = stepclass.step(params={})
102-
L = lvl.level(problem_class=swfw_scalar, problem_params=self.pparams, dtype_u=mesh, dtype_f=rhs_imex_mesh, sweeper_class=imex, sweeper_params=self.swparams, level_params={}, hook_class=hookclass.hooks, id="imextest")
103-
step.register_level(L)
104-
step.status.dt = 1.0
105-
step.status.time = 0.0
106-
nnodes = step.levels[0].sweep.coll.num_nodes
107-
u0 = step.levels[0].prob.u_exact(step.status.time)
108-
step.init_step(u0)
109-
step.levels[0].sweep.predict()
110-
u0full = np.array([ step.levels[0].u[l].values.flatten() for l in range(1,nnodes+1) ])
111-
112-
# Build SDC sweep matrix
113-
level = step.levels[0]
114-
problem = level.prob
115-
QE = level.sweep.QE[1:,1:]
116-
QI = level.sweep.QI[1:,1:]
117-
Q = level.sweep.coll.Qmat[1:,1:]
126+
step, level, problem, nnodes = self.setupLevelStepProblem()
127+
level.sweep.predict()
128+
u0full = np.array([ level.u[l].values.flatten() for l in range(1,nnodes+1) ])
118129

119130
# Perform update step in sweeper
120-
step.levels[0].sweep.update_nodes()
121-
ustages = np.array([ step.levels[0].u[l].values.flatten() for l in range(1,nnodes+1) ])
122-
123-
step.levels[0].sweep.compute_end_point()
124-
uend_sweep = step.levels[0].uend.values
125-
uend_mat = u0.values + step.status.dt*step.levels[0].sweep.coll.weights.dot(ustages*(problem.lambda_s[0] + problem.lambda_f[0]))
131+
level.sweep.update_nodes()
132+
ustages = np.array([ level.u[l].values.flatten() for l in range(1,nnodes+1) ])
133+
134+
# Compute end value through provided function
135+
level.sweep.compute_end_point()
136+
uend_sweep = level.uend.values
137+
# Compute end value from matrix formulation
138+
uend_mat = u0.values + step.status.dt*level.sweep.coll.weights.dot(ustages*(problem.lambda_s[0] + problem.lambda_f[0]))
126139
assert np.linalg.norm(uend_sweep - uend_mat, np.infty)<1e-14, "Update formula in sweeper gives different result than matrix update formula"
127140

128141
#
129142
# Compute the exact collocation solution by matrix inversion and make sure it is a fixed point
130143
#
131144
def test_collocationinvariant(self):
132145

133-
step = stepclass.step(params={})
134-
L = lvl.level(problem_class=swfw_scalar, problem_params=self.pparams, dtype_u=mesh, dtype_f=rhs_imex_mesh, sweeper_class=imex, sweeper_params=self.swparams, level_params={}, hook_class=hookclass.hooks, id="imextest")
135-
step.register_level(L)
136-
step.status.dt = 1.0
137-
step.status.time = 0.0
138-
nnodes = step.levels[0].sweep.coll.num_nodes
139-
u0 = step.levels[0].prob.u_exact(step.status.time)
140-
step.init_step(u0)
141-
step.levels[0].sweep.predict()
142-
u0full = np.array([ step.levels[0].u[l].values.flatten() for l in range(1,nnodes+1) ])
143-
144-
# Build SDC sweep matrix
145-
level = step.levels[0]
146-
problem = level.prob
147-
QE = level.sweep.QE[1:,1:]
148-
QI = level.sweep.QI[1:,1:]
149-
Q = level.sweep.coll.Qmat[1:,1:]
146+
step, level, problem, nnodes = self.setupLevelStepProblem()
147+
level.sweep.predict()
148+
u0full = np.array([ level.u[l].values.flatten() for l in range(1,nnodes+1) ])
150149

150+
QE, QI, Q = self.setupQMatrices(level)
151+
152+
# Build collocation matrix
151153
Mcoll = np.eye(nnodes) - step.status.dt*Q*(problem.lambda_s[0] + problem.lambda_f[0])
154+
155+
# Solve collocation problem directly
152156
ucoll = np.linalg.inv(Mcoll).dot(u0full)
157+
158+
# Put stages of collocation solution into level
153159
for l in range(0,nnodes):
154-
step.levels[0].u[l+1].values = ucoll[l]
155-
step.levels[0].f[l+1].impl.values = problem.lambda_f[0]*ucoll[l]
156-
step.levels[0].f[l+1].expl.values = problem.lambda_s[0]*ucoll[l]
160+
level.u[l+1].values = ucoll[l]
161+
level.f[l+1].impl.values = problem.lambda_f[0]*ucoll[l]
162+
level.f[l+1].expl.values = problem.lambda_s[0]*ucoll[l]
157163

158164
# Perform node-to-node SDC sweep
159-
step.levels[0].sweep.update_nodes()
165+
level.sweep.update_nodes()
160166

167+
# Build matrices for matrix formulation of sweep
161168
LHS = np.eye(nnodes) - step.status.dt*( problem.lambda_f[0]*QI + problem.lambda_s[0]*QE )
162169
RHS = step.status.dt*( (problem.lambda_f[0]+problem.lambda_s[0])*Q - (problem.lambda_f[0]*QI + problem.lambda_s[0]*QE) )
170+
# Make sure both matrix and node-to-node sweep leave collocation unaltered
163171
unew = np.linalg.inv(LHS).dot( u0full + RHS.dot(ucoll) )
164172
assert np.linalg.norm( unew - ucoll, np.infty )<1e-14, "Collocation solution not invariant under matrix SDC sweep"
165-
unew_sweep = np.array([ step.levels[0].u[l].values.flatten() for l in range(1,nnodes+1) ])
173+
unew_sweep = np.array([ level.u[l].values.flatten() for l in range(1,nnodes+1) ])
166174
assert np.linalg.norm( unew_sweep - ucoll, np.infty )<1e-14, "Collocation solution not invariant under node-to-node sweep"
167175

168176
#
169177
#
170178
#
171179
def test_canrunmatrixsweep(self):
172-
step = stepclass.step(params={})
173-
L = lvl.level(problem_class=swfw_scalar, problem_params=self.pparams, dtype_u=mesh, dtype_f=rhs_imex_mesh, sweeper_class=imex, sweeper_params=self.swparams, level_params={}, hook_class=hookclass.hooks, id="imextest")
174-
step.register_level(L)
175-
step.status.dt = 1.0
176-
step.status.time = 0.0
177-
u0 = step.levels[0].prob.u_exact(step.status.time)
178-
step.init_step(u0)
179-
nnodes = step.levels[0].sweep.coll.num_nodes
180-
level = step.levels[0]
181-
problem = level.prob
180+
step, level, problem, nnodes = self.setupLevelStepProblem()
181+
182182
QE = level.sweep.QE[1:,1:]
183183
QI = level.sweep.QI[1:,1:]
184184
Q = level.sweep.coll.Qmat[1:,1:]
@@ -204,3 +204,5 @@ def test_canrunmatrixsweep(self):
204204
#print ufull
205205
#uend = u0.values + step.status.dt*level.sweep.coll.weights.dot( (problem.lambda_f[0]+problem.lambda_s[0])*ufull )
206206
#print "Matrix: %s" % uend
207+
208+

0 commit comments

Comments
 (0)