Skip to content

Commit 460d69f

Browse files
committed
New hook, used in boussinesq_2d_imex example to set the GMRES tolerance
1 parent d2bc3d6 commit 460d69f

File tree

9 files changed

+64
-31
lines changed

9 files changed

+64
-31
lines changed

examples/boussinesq_2d_imex/HookClass.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,25 @@ def __init__(self):
2424
#self.counter = 0
2525

2626
def dump_sweep(self,status):
27-
pass
27+
"""
28+
Set new GMRES tolerance depending on the previous SDC residual
29+
30+
Args:
31+
status: status object per step
32+
"""
33+
super(plot_solution,self).dump_sweep(status)
34+
self.level.prob.gmres_tol = max(self.level.status.residual*self.level.prob.gmres_tol_factor,self.level.prob.gmres_tol_limit)
35+
36+
def dump_pre_iteration(self,status):
37+
"""
38+
Set new GMRES tolerance depending on the initial SDC residual
39+
40+
Args:
41+
status: status object per step
42+
"""
43+
super(plot_solution,self).dump_pre_iteration(status)
44+
self.level.sweep.compute_residual()
45+
self.level.prob.gmres_tol = max(self.level.status.residual*self.level.prob.gmres_tol_factor,self.level.prob.gmres_tol_limit)
2846

2947
def dump_step(self,status):
3048
"""

examples/boussinesq_2d_imex/ProblemClass.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(self, cparams, dtype_u, dtype_f):
6666
assert 'order' in cparams
6767
assert 'gmres_maxiter' in cparams
6868
assert 'gmres_restart' in cparams
69-
assert 'gmres_tol' in cparams
69+
assert 'gmres_tol_limit' in cparams
7070

7171
# add parameters as attributes for further reference
7272
for k,v in cparams.items():
@@ -86,6 +86,7 @@ def __init__(self, cparams, dtype_u, dtype_f):
8686
self.D_upwind = getBoussinesq2DUpwindMatrix( self.N, self.h[0], self.u_adv , self.order_upw)
8787

8888
self.logger = logging()
89+
self.gmres_tol = None
8990

9091
def solve_system(self,rhs,factor,u0,t):
9192
"""
@@ -103,6 +104,8 @@ def solve_system(self,rhs,factor,u0,t):
103104

104105
b = rhs.values.flatten()
105106
cb = Callback()
107+
print(self.gmres_tol)
108+
106109
sol, info = LA.gmres( self.Id - factor*self.M, b, x0=u0.values.flatten(), tol=self.gmres_tol, restart=self.gmres_restart, maxiter=self.gmres_maxiter, callback=cb)
107110
# If this is a dummy call with factor==0.0, do not log because it should not be counted as a solver call
108111
if factor!=0.0:

examples/boussinesq_2d_imex/plotgmrescounter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
uimex = np.load('rkimex.npy')
1616
uref = np.load('uref.npy')
1717

18-
print "Estimated discretisation error of DIRK: %5.3e" % ( np.linalg.norm(udirk.flatten() - uref.flatten(), np.inf)/np.linalg.norm(uref.flatten(),np.inf) )
19-
print "Estimated discretisation error of SDC: %5.3e" % ( np.linalg.norm(uend.flatten() - uref.flatten(), np.inf)/np.linalg.norm(uref.flatten(),np.inf) )
20-
print "Estimated discretisation error of RK-IMEX: %5.3e" % ( np.linalg.norm(uimex.flatten() - uref.flatten(), np.inf)/np.linalg.norm(uref.flatten(),np.inf) )
18+
print("Estimated discretisation error of DIRK: %5.3e" % ( np.linalg.norm(udirk.flatten() - uref.flatten(), np.inf)/np.linalg.norm(uref.flatten(),np.inf) ))
19+
print("Estimated discretisation error of SDC: %5.3e" % ( np.linalg.norm(uend.flatten() - uref.flatten(), np.inf)/np.linalg.norm(uref.flatten(),np.inf) ))
20+
print("Estimated discretisation error of RK-IMEX: %5.3e" % ( np.linalg.norm(uimex.flatten() - uref.flatten(), np.inf)/np.linalg.norm(uref.flatten(),np.inf) ))
2121

2222
fs = 8
2323
rcParams['figure.figsize'] = 5.0, 2.5

examples/boussinesq_2d_imex/rungmrescounter.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
swparams = {}
3838
swparams['collocation_class'] = collclass.CollGaussLegendre
3939
swparams['num_nodes'] = 3
40-
swparams['do_LU'] = False
40+
swparams['do_LU'] = True
4141

4242
sparams = {}
4343
sparams['maxiter'] = 4
@@ -47,7 +47,7 @@
4747
# setup parameters "in time"
4848
t0 = 0
4949
Tend = 3000
50-
Nsteps = 500
50+
Nsteps = 100
5151
dt = Tend/float(Nsteps)
5252

5353
# This comes as read-in for the problem class
@@ -63,7 +63,8 @@
6363
pparams['order_upw'] = [5]
6464
pparams['gmres_maxiter'] = [500]
6565
pparams['gmres_restart'] = [10]
66-
pparams['gmres_tol'] = [1e-9]
66+
pparams['gmres_tol_limit'] = [1e-3]
67+
pparams['gmres_tol_factor'] = [0.1]
6768

6869
# This comes as read-in for the transfer operations
6970
tparams = {}
@@ -90,11 +91,11 @@
9091
cfl_advection = pparams['u_adv']*dt/P.h[0]
9192
cfl_acoustic_hor = pparams['c_s']*dt/P.h[0]
9293
cfl_acoustic_ver = pparams['c_s']*dt/P.h[1]
93-
print "Horizontal resolution: %4.2f" % P.h[0]
94-
print "Vertical resolution: %4.2f" % P.h[1]
95-
print ("CFL number of advection: %4.2f" % cfl_advection)
96-
print ("CFL number of acoustics (horizontal): %4.2f" % cfl_acoustic_hor)
97-
print ("CFL number of acoustics (vertical): %4.2f" % cfl_acoustic_ver)
94+
print("Horizontal resolution: %4.2f" % P.h[0])
95+
print("Vertical resolution: %4.2f" % P.h[1])
96+
print("CFL number of advection: %4.2f" % cfl_advection)
97+
print("CFL number of acoustics (horizontal): %4.2f" % cfl_acoustic_hor)
98+
print("CFL number of acoustics (vertical): %4.2f" % cfl_acoustic_ver)
9899

99100
dirkp = dirk(P, dirk_order)
100101
u0 = uinit.values.flatten()
@@ -129,18 +130,18 @@
129130
np.save('rkimex', uimex)
130131
np.save('uref', uref)
131132

132-
print " #### Logging report for DIRK-%1i #### " % dirkp.order
133-
print "Number of calls to implicit solver: %5i" % dirkp.logger.solver_calls
134-
print "Total number of GMRES iterations: %5i" % dirkp.logger.iterations
135-
print "Average number of iterations per call: %6.3f" % (float(dirkp.logger.iterations)/float(dirkp.logger.solver_calls))
136-
print " "
137-
print " #### Logging report for RK-IMEX-%1i #### " % rkimex.order
138-
print "Number of calls to implicit solver: %5i" % rkimex.logger.solver_calls
139-
print "Total number of GMRES iterations: %5i" % rkimex.logger.iterations
140-
print "Average number of iterations per call: %6.3f" % (float(rkimex.logger.iterations)/float(rkimex.logger.solver_calls))
141-
print " "
142-
print " #### Logging report for SDC-(%1i,%1i) #### " % (swparams['num_nodes'], sparams['maxiter'])
143-
print "Number of calls to implicit solver: %5i" % P.logger.solver_calls
144-
print "Total number of GMRES iterations: %5i" % P.logger.iterations
145-
print "Average number of iterations per call: %6.3f" % (float(P.logger.iterations)/float(P.logger.solver_calls))
133+
print(" #### Logging report for DIRK-%1i #### " % dirkp.order)
134+
print("Number of calls to implicit solver: %5i" % dirkp.logger.solver_calls)
135+
print("Total number of GMRES iterations: %5i" % dirkp.logger.iterations)
136+
print("Average number of iterations per call: %6.3f" % (float(dirkp.logger.iterations)/float(dirkp.logger.solver_calls)))
137+
print(" ")
138+
print(" #### Logging report for RK-IMEX-%1i #### " % rkimex.order)
139+
print("Number of calls to implicit solver: %5i" % rkimex.logger.solver_calls)
140+
print("Total number of GMRES iterations: %5i" % rkimex.logger.iterations)
141+
print("Average number of iterations per call: %6.3f" % (float(rkimex.logger.iterations)/float(rkimex.logger.solver_calls)))
142+
print(" ")
143+
print(" #### Logging report for SDC-(%1i,%1i) #### " % (swparams['num_nodes'], sparams['maxiter']))
144+
print("Number of calls to implicit solver: %5i" % P.logger.solver_calls)
145+
print("Total number of GMRES iterations: %5i" % P.logger.iterations)
146+
print("Average number of iterations per call: %6.3f" % (float(P.logger.iterations)/float(P.logger.solver_calls)))
146147

examples/boussinesq_2d_imex/standard_integrators.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def f_fast(self, u):
9393

9494
def f_fast_solve(self, rhs, alpha, u0):
9595
cb = Callback()
96-
sol, info = LA.gmres( self.problem.Id - alpha*self.problem.M, rhs, x0=u0, tol=self.problem.gmres_tol, restart=self.problem.gmres_restart, maxiter=self.problem.gmres_maxiter, callback=cb)
96+
sol, info = LA.gmres( self.problem.Id - alpha*self.problem.M, rhs, x0=u0, tol=self.problem.gmres_tol_limit, restart=self.problem.gmres_restart, maxiter=self.problem.gmres_maxiter, callback=cb)
9797
if alpha!=0.0:
9898
#print "RK-IMEX-%1i: Number of GMRES iterations: %3i --- Final residual: %6.3e" % ( self.order, cb.getcounter(), cb.getresidual() )
9999
self.logger.add(cb.getcounter())
@@ -129,7 +129,7 @@ def f(self,u):
129129
#
130130
def f_solve(self, b, alpha, u0):
131131
cb = Callback()
132-
sol, info = LA.gmres( self.problem.Id - alpha*(self.problem.D_upwind + self.problem.M), b, x0=u0, tol=self.problem.gmres_tol, restart=self.problem.gmres_restart, maxiter=self.problem.gmres_maxiter, callback=cb)
132+
sol, info = LA.gmres( self.problem.Id - alpha*(self.problem.D_upwind + self.problem.M), b, x0=u0, tol=self.problem.gmres_tol_limit, restart=self.problem.gmres_restart, maxiter=self.problem.gmres_maxiter, callback=cb)
133133
if alpha!=0.0:
134134
#print "BDF-2: Number of GMRES iterations: %3i --- Final residual: %6.3e" % ( cb.getcounter(), cb.getresidual() )
135135
self.logger.add(cb.getcounter())
@@ -166,7 +166,7 @@ def f(self,u):
166166
#
167167
def f_solve(self, b, alpha, u0):
168168
cb = Callback()
169-
sol, info = LA.gmres( self.problem.Id - alpha*(self.problem.D_upwind + self.problem.M), b, x0=u0, tol=self.problem.gmres_tol, restart=self.problem.gmres_restart, maxiter=self.problem.gmres_maxiter, callback=cb)
169+
sol, info = LA.gmres( self.problem.Id - alpha*(self.problem.D_upwind + self.problem.M), b, x0=u0, tol=self.problem.gmres_tol_limit, restart=self.problem.gmres_restart, maxiter=self.problem.gmres_maxiter, callback=cb)
170170
if alpha!=0.0:
171171
#print "BDF-2: Number of GMRES iterations: %3i --- Final residual: %6.3e" % ( cb.getcounter(), cb.getresidual() )
172172
self.logger.add(cb.getcounter())
@@ -283,7 +283,7 @@ def f(self,u):
283283
#
284284
def f_solve(self, b, alpha, u0):
285285
cb = Callback()
286-
sol, info = LA.gmres( self.problem.Id - alpha*(self.problem.D_upwind + self.problem.M), b, x0=u0, tol=self.problem.gmres_tol, restart=self.problem.gmres_restart, maxiter=self.problem.gmres_maxiter, callback=cb)
286+
sol, info = LA.gmres( self.problem.Id - alpha*(self.problem.D_upwind + self.problem.M), b, x0=u0, tol=self.problem.gmres_tol_limit, restart=self.problem.gmres_restart, maxiter=self.problem.gmres_maxiter, callback=cb)
287287
if alpha!=0.0:
288288
#print "DIRK-%1i: Number of GMRES iterations: %3i --- Final residual: %6.3e" % ( self.order, cb.getcounter(), cb.getresidual() )
289289
self.logger.add(cb.getcounter())

pySDC/Hooks.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@ def dump_pre(self,status):
5353
"""
5454
pass
5555

56+
def dump_pre_iteration(self,status):
57+
"""
58+
Default routine called before iteration starts
59+
"""
60+
pass
61+
5662

5763
def dump_sweep(self,status):
5864
"""

pySDC/PFASST_blockwise.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ def pfasst(MS):
251251
if len(S.levels) > 1 and S.params.predict:
252252
S.status.stage = 'PREDICT'
253253
else:
254+
S.levels[0].hooks.dump_pre_iteration(S.status)
254255
S.status.stage = 'IT_FINE'
255256

256257
return MS
@@ -263,6 +264,7 @@ def pfasst(MS):
263264

264265
for S in MS:
265266
# update stage
267+
S.levels[0].hooks.dump_pre_iteration(S.status)
266268
S.status.stage = 'IT_FINE'
267269

268270
return MS

pySDC/PFASST_stepwise.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ def pfasst(S):
192192
if len(S.levels) > 1 and S.params.predict:
193193
S.status.stage = 'PREDICT_RESTRICT'
194194
else:
195+
S.levels[0].hooks.dump_pre_iteration(S.status)
195196
S.status.stage = 'IT_FINE_SWEEP'
196197
return S
197198

@@ -254,6 +255,7 @@ def pfasst(S):
254255
S.transfer(source=S.levels[l],target=S.levels[l-1])
255256

256257
# uodate stage and return
258+
S.levels[0].hooks.dump_pre_iteration(S.status)
257259
S.status.stage = 'IT_FINE_SWEEP'
258260
return S
259261

pySDC/Sweeper.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def predict(self):
9595

9696
# indicate that this level is now ready for sweeps
9797
L.status.unlocked = True
98+
L.status.updated = True
9899

99100
return None
100101

0 commit comments

Comments
 (0)