Skip to content

Commit 5ca8a02

Browse files
committed
Merge pull request #58 from danielru/feature/switch_for_coll_update
Feature/switch for coll update
2 parents f693675 + 9bc47ca commit 5ca8a02

File tree

4 files changed

+78
-20
lines changed

4 files changed

+78
-20
lines changed

pySDC/Sweeper.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ def __init__(self,params):
3131

3232
defaults = dict()
3333
defaults['do_LU'] = False
34-
34+
defaults['do_coll_update'] = True
35+
3536
for k,v in defaults.items():
3637
setattr(self,k,v)
3738

@@ -43,6 +44,8 @@ def __init__(self,params):
4344

4445
coll = params['collocation_class'](params['num_nodes'],0,1)
4546
assert isinstance(coll, CollBase)
47+
if not coll.right_is_node:
48+
assert self.params.do_coll_update, "For nodes where the right end point is not a node, do_coll_update has to be set to True"
4649

4750
# This will be set as soon as the sweeper is instantiated at the level
4851
self.__level = None
@@ -151,4 +154,4 @@ def update_nodes(self):
151154
"""
152155
Abstract interface to node update
153156
"""
154-
return None
157+
return None

pySDC/sweeper_classes/generic_LU.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -135,12 +135,17 @@ def compute_end_point(self):
135135
L = self.level
136136
P = L.prob
137137

138-
# start with u0 and add integral over the full interval (using coll.weights)
139-
L.uend = P.dtype_u(L.u[0])
140-
for m in range(self.coll.num_nodes):
141-
L.uend += L.dt*self.coll.weights[m]*L.f[m+1]
142-
# add up tau correction of the full interval (last entry)
143-
if L.tau is not None:
144-
L.uend += L.tau[-1]
145-
146-
return None
138+
# check if Mth node is equal to right point and do_coll_update is false, perform a simple copy
139+
if (self.coll.right_is_node and not self.params.do_coll_update):
140+
# a copy is sufficient
141+
L.uend = P.dtype_u(L.u[-1])
142+
else:
143+
# start with u0 and add integral over the full interval (using coll.weights)
144+
L.uend = P.dtype_u(L.u[0])
145+
for m in range(self.coll.num_nodes):
146+
L.uend += L.dt*self.coll.weights[m]*L.f[m+1]
147+
# add up tau correction of the full interval (last entry)
148+
if L.tau is not None:
149+
L.uend += L.tau[-1]
150+
151+
return None

pySDC/sweeper_classes/imex_1st_order.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -134,19 +134,24 @@ def compute_end_point(self):
134134
"""
135135
Compute u at the right point of the interval
136136
137-
The value uend computed here is a full evaluation of the Picard formulation (always!)
137+
The value uend computed here is a full evaluation of the Picard formulation unless do_full_update==False
138138
"""
139139

140140
# get current level and problem description
141141
L = self.level
142142
P = L.prob
143143

144-
# start with u0 and add integral over the full interval (using coll.weights)
145-
L.uend = P.dtype_u(L.u[0])
146-
for m in range(self.coll.num_nodes):
147-
L.uend += L.dt*self.coll.weights[m]*(L.f[m+1].impl + L.f[m+1].expl)
148-
# add up tau correction of the full interval (last entry)
149-
if L.tau is not None:
150-
L.uend += L.tau[-1]
144+
# check if Mth node is equal to right point and do_coll_update is false, perform a simple copy
145+
if (self.coll.right_is_node and not self.params.do_coll_update):
146+
# a copy is sufficient
147+
L.uend = P.dtype_u(L.u[-1])
148+
else:
149+
# start with u0 and add integral over the full interval (using coll.weights)
150+
L.uend = P.dtype_u(L.u[0])
151+
for m in range(self.coll.num_nodes):
152+
L.uend += L.dt*self.coll.weights[m]*(L.f[m+1].impl + L.f[m+1].expl)
153+
# add up tau correction of the full interval (last entry)
154+
if L.tau is not None:
155+
L.uend += L.tau[-1]
151156

152157
return None

tests/test_imexsweeper.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def test_manysweepsequalmatrix(self):
204204
#
205205
# Make sure that update function for K sweeps computed from K-sweep matrix gives same result as K sweeps in node-to-node form plus compute_end_point
206206
#
207-
def test_maysweepupdate(self):
207+
def test_manysweepupdate(self):
208208

209209
step, level, problem, nnodes = self.setupLevelStepProblem()
210210
step.levels[0].sweep.predict()
@@ -229,3 +229,48 @@ def test_maysweepupdate(self):
229229
# Multiply u0 by value of update function to get end value directly
230230
uend_matrix = update*self.pparams['u0']
231231
assert abs(uend_matrix - uend_sweep)<1e-14, "Node-to-node sweep plus update yields different result than update function computed through K-sweep matrix"
232+
233+
#
234+
# Make sure that creating a sweeper object with a collocation object with right_is_node=False and do_coll_update=False throws an exception
235+
#
236+
def test_norightnode_collupdate_fails(self):
237+
self.swparams['collocation_class'] = collclass.CollGaussLegendre
238+
self.swparams['do_coll_update'] = False
239+
# Has to throw an exception
240+
with self.assertRaises(AssertionError):
241+
step, level, problem, nnodes = self.setupLevelStepProblem()
242+
243+
#
244+
# Make sure the update with do_coll_update=False reproduces last stage
245+
#
246+
def test_update_nocollupdate_laststage(self):
247+
self.swparams['do_coll_update'] = False
248+
step, level, problem, nnodes = self.setupLevelStepProblem()
249+
level.sweep.predict()
250+
ulaststage = np.random.rand()
251+
level.u[nnodes].values = ulaststage
252+
level.sweep.compute_end_point()
253+
uend = level.uend.values
254+
assert abs(uend-ulaststage)<1e-14, "compute_end_point with do_coll_update=False did not reproduce last stage value"
255+
256+
#
257+
# Make sure that update with do_coll_update=False is identical to update formula with q=(0,...,0,1)
258+
#
259+
def test_updateformula_no_coll_update(self):
260+
self.swparams['do_coll_update'] = False
261+
step, level, problem, nnodes = self.setupLevelStepProblem()
262+
level.sweep.predict()
263+
u0full = np.array([ level.u[l].values.flatten() for l in range(1,nnodes+1) ])
264+
265+
# Perform update step in sweeper
266+
level.sweep.update_nodes()
267+
ustages = np.array([ level.u[l].values.flatten() for l in range(1,nnodes+1) ])
268+
269+
# Compute end value through provided function
270+
level.sweep.compute_end_point()
271+
uend_sweep = level.uend.values
272+
# Compute end value from matrix formulation
273+
q = np.zeros(nnodes)
274+
q[nnodes-1] = 1.0
275+
uend_mat = q.dot(ustages)
276+
assert np.linalg.norm(uend_sweep - uend_mat, np.infty)<1e-14, "For do_coll_update=False, update formula in sweeper gives different result than matrix update formula with q=(0,..,0,1)"

0 commit comments

Comments
 (0)