Skip to content

Commit 3b49d5f

Browse files
committed
Merge branch 'GS_WP' into neuralpint
2 parents 9d3ef1d + 28d6715 commit 3b49d5f

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

57 files changed

+4387
-924
lines changed

pySDC/core/controller.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pySDC.helpers.pysdc_helper import FrozenClass
88
from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence
99
from pySDC.implementations.hooks.default_hook import DefaultHooks
10+
from pySDC.implementations.hooks.log_timings import CPUTimings
1011

1112

1213
# short helper class to add params as attributes
@@ -43,7 +44,7 @@ def __init__(self, controller_params, description, useMPI=None):
4344

4445
# check if we have a hook on this list. If not, use default class.
4546
self.__hooks = []
46-
hook_classes = [DefaultHooks]
47+
hook_classes = [DefaultHooks, CPUTimings]
4748
user_hooks = controller_params.get('hook_class', [])
4849
hook_classes += user_hooks if type(user_hooks) == list else [user_hooks]
4950
[self.add_hook(hook) for hook in hook_classes]

pySDC/core/convergence_controller.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,19 @@ def post_step_processing(self, controller, S, **kwargs):
292292
"""
293293
pass
294294

295+
def post_run_processing(self, controller, S, **kwargs):
296+
"""
297+
Do whatever you want to after the run here.
298+
299+
Args:
300+
controller (pySDC.Controller): The controller
301+
S (pySDC.Step): The current step
302+
303+
Returns:
304+
None
305+
"""
306+
pass
307+
295308
def prepare_next_block(self, controller, S, size, time, Tend, **kwargs):
296309
"""
297310
Prepare stuff like spreading step sizes or whatever.

pySDC/helpers/NCCL_communicator.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def __getattr__(self, name):
2727
Args:
2828
Name (str): Name of the requested attribute
2929
"""
30-
if name not in ['size', 'rank', 'Get_rank', 'Get_size', 'Split']:
30+
if name not in ['size', 'rank', 'Get_rank', 'Get_size', 'Split', 'Create_cart', 'Is_inter', 'Get_topology']:
3131
cp.cuda.get_current_stream().synchronize()
3232

3333
return getattr(self.commMPI, name)
@@ -71,6 +71,26 @@ def get_op(self, MPI_op):
7171
else:
7272
raise NotImplementedError('Don\'t know what NCCL operation to use to replace this MPI operation!')
7373

74+
def reduce(self, sendobj, op=MPI.SUM, root=0):
75+
sync = False
76+
if hasattr(sendobj, 'data'):
77+
if hasattr(sendobj.data, 'ptr'):
78+
sync = True
79+
if sync:
80+
cp.cuda.Device().synchronize()
81+
82+
return self.commMPI.reduce(sendobj, op=op, root=root)
83+
84+
def allreduce(self, sendobj, op=MPI.SUM):
85+
sync = False
86+
if hasattr(sendobj, 'data'):
87+
if hasattr(sendobj.data, 'ptr'):
88+
sync = True
89+
if sync:
90+
cp.cuda.Device().synchronize()
91+
92+
return self.commMPI.allreduce(sendobj, op=op)
93+
7494
def Reduce(self, sendbuf, recvbuf, op=MPI.SUM, root=0):
7595
if not hasattr(sendbuf.data, 'ptr'):
7696
return self.commMPI.Reduce(sendbuf=sendbuf, recvbuf=recvbuf, op=op, root=root)
@@ -113,3 +133,7 @@ def Bcast(self, buf, root=0):
113133
stream = cp.cuda.get_current_stream()
114134

115135
self.commNCCL.bcast(buff=buf.data.ptr, count=count, datatype=dtype, root=root, stream=stream.ptr)
136+
137+
def Barrier(self):
138+
cp.cuda.get_current_stream().synchronize()
139+
self.commMPI.Barrier()

pySDC/helpers/plot_helper.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def figsize_by_journal(journal, scale, ratio): # pragma: no cover
4242
textwidths = {
4343
'JSC_beamer': 426.79135,
4444
'Springer_Numerical_Algorithms': 338.58778,
45+
'Springer_proceedings': 347.12354,
4546
'JSC_thesis': 434.26027,
4647
'TUHH_thesis': 426.79135,
4748
}
@@ -50,6 +51,7 @@ def figsize_by_journal(journal, scale, ratio): # pragma: no cover
5051
'JSC_beamer': 214.43411,
5152
'JSC_thesis': 635.5,
5253
'TUHH_thesis': 631.65118,
54+
'Springer_proceedings': 549.13828,
5355
}
5456
assert (
5557
journal in textwidths.keys()

pySDC/helpers/spectral_helper.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,7 @@ def __init__(self, comm=None, useGPU=False, debug=False):
882882
self.BCs = None
883883

884884
self.fft_cache = {}
885+
self.fft_dealias_shape_cache = {}
885886

886887
@property
887888
def u_init(self):
@@ -1470,8 +1471,13 @@ def _transform_dct(self, u, axes, padding=None, **kwargs):
14701471

14711472
if padding is not None:
14721473
shape = list(v.shape)
1473-
if self.comm:
1474-
shape[0] = self.comm.allreduce(v.shape[0])
1474+
if ('forward', *padding) in self.fft_dealias_shape_cache.keys():
1475+
shape[0] = self.fft_dealias_shape_cache[('forward', *padding)]
1476+
elif self.comm:
1477+
send_buf = np.array(v.shape[0])
1478+
recv_buf = np.array(v.shape[0])
1479+
self.comm.Allreduce(send_buf, recv_buf)
1480+
shape[0] = int(recv_buf)
14751481
fft = self.get_fft(axes, 'forward', shape=shape)
14761482
else:
14771483
fft = self.get_fft(axes, 'forward', **kwargs)
@@ -1642,8 +1648,13 @@ def _transform_idct(self, u, axes, padding=None, **kwargs):
16421648
if padding is not None:
16431649
if padding[axis] != 1:
16441650
shape = list(v.shape)
1645-
if self.comm:
1646-
shape[0] = self.comm.allreduce(v.shape[0])
1651+
if ('backward', *padding) in self.fft_dealias_shape_cache.keys():
1652+
shape[0] = self.fft_dealias_shape_cache[('backward', *padding)]
1653+
elif self.comm:
1654+
send_buf = np.array(v.shape[0])
1655+
recv_buf = np.array(v.shape[0])
1656+
self.comm.Allreduce(send_buf, recv_buf)
1657+
shape[0] = int(recv_buf)
16471658
ifft = self.get_fft(axes, 'backward', shape=shape)
16481659
else:
16491660
ifft = self.get_fft(axes, 'backward', padding=padding, **kwargs)
@@ -1748,8 +1759,6 @@ def get_aligned(self, u, axis_in, axis_out, fft=None, forward=False, **kwargs):
17481759
if self.comm.size == 1:
17491760
return u.copy()
17501761

1751-
fft = self.get_fft(**kwargs) if fft is None else fft
1752-
17531762
global_fft = self.get_fft(**kwargs)
17541763
axisA = [me.axisA for me in global_fft.transfer]
17551764
axisB = [me.axisB for me in global_fft.transfer]
@@ -1787,6 +1796,8 @@ def get_aligned(self, u, axis_in, axis_out, fft=None, forward=False, **kwargs):
17871796
else: # go the potentially slower route of not reusing transfer classes
17881797
from mpi4py_fft import newDistArray
17891798

1799+
fft = self.get_fft(**kwargs) if fft is None else fft
1800+
17901801
_in = newDistArray(fft, forward).redistribute(axis_in)
17911802
_in[...] = u
17921803

pySDC/implementations/controller_classes/controller_MPI.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,9 @@ def run(self, u0, t0, Tend):
160160
for hook in self.hooks:
161161
hook.post_run(step=self.S, level_number=0)
162162

163+
for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
164+
C.post_run_processing(self, self.S, comm=self.comm)
165+
163166
comm_active.Free()
164167

165168
return uend, self.return_stats()

pySDC/implementations/controller_classes/controller_nonMPI.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,10 @@ def run(self, u0, t0, Tend):
171171
for hook in self.hooks:
172172
hook.post_run(step=S, level_number=0)
173173

174+
for S in self.MS:
175+
for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
176+
C.post_run_processing(self, S, MS=MS_active)
177+
174178
return uend, self.return_stats()
175179

176180
def restart_block(self, active_slots, time, u0):

pySDC/implementations/convergence_controller_classes/adaptivity.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,16 @@ def determine_restart(self, controller, S, **kwargs):
229229
if self.get_convergence(controller, S, **kwargs):
230230
self.res_last_iter = np.inf
231231

232-
if self.params.restart_at_maxiter and S.levels[0].status.residual > S.levels[0].params.restol:
232+
L = S.levels[0]
233+
e_tol_converged = (
234+
L.status.increment < L.params.e_tol if (L.params.get('e_tol') and L.status.get('increment')) else False
235+
)
236+
237+
if (
238+
self.params.restart_at_maxiter
239+
and S.levels[0].status.residual > S.levels[0].params.restol
240+
and not e_tol_converged
241+
):
233242
self.trigger_restart_upon_nonconvergence(S)
234243
elif self.get_local_error_estimate(controller, S, **kwargs) > self.params.e_tol:
235244
S.status.restart = True

pySDC/implementations/convergence_controller_classes/check_convergence.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,7 @@ def check_convergence(S, self=None):
7575
iter_converged = S.status.iter >= S.params.maxiter
7676
res_converged = L.status.residual <= L.params.restol
7777
e_tol_converged = (
78-
L.status.error_embedded_estimate < L.params.e_tol
79-
if (L.params.get('e_tol') and L.status.get('error_embedded_estimate'))
80-
else False
78+
L.status.increment < L.params.e_tol if (L.params.get('e_tol') and L.status.get('increment')) else False
8179
)
8280
converged = (
8381
iter_converged or res_converged or e_tol_converged or S.status.force_done

pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,10 @@ def estimate_embedded_error_serial(self, L):
9393
dtype_u: The embedded error estimate
9494
"""
9595
if self.params.sweeper_type == "RK":
96-
# lower order solution is stored in the second to last entry of L.u
97-
return abs(L.u[-2] - L.u[-1])
96+
L.sweep.compute_end_point()
97+
return abs(L.uend - L.sweep.u_secondary)
9898
elif self.params.sweeper_type == "SDC":
99-
# order rises by one between sweeps, making this so ridiculously easy
99+
# order rises by one between sweeps
100100
return abs(L.uold[-1] - L.u[-1])
101101
elif self.params.sweeper_type == 'MPI':
102102
comm = L.sweep.comm
@@ -109,12 +109,13 @@ def estimate_embedded_error_serial(self, L):
109109

110110
def setup_status_variables(self, controller, **kwargs):
111111
"""
112-
Add the embedded error variable to the error function.
112+
Add the embedded error to the level status
113113
114114
Args:
115115
controller (pySDC.Controller): The controller
116116
"""
117117
self.add_status_variable_to_level('error_embedded_estimate')
118+
self.add_status_variable_to_level('increment')
118119

119120
def post_iteration_processing(self, controller, S, **kwargs):
120121
"""
@@ -134,6 +135,7 @@ def post_iteration_processing(self, controller, S, **kwargs):
134135
if S.status.iter > 0 or self.params.sweeper_type == "RK":
135136
for L in S.levels:
136137
L.status.error_embedded_estimate = max([self.estimate_embedded_error_serial(L), np.finfo(float).eps])
138+
L.status.increment = L.status.error_embedded_estimate * 1
137139
self.debug(f'L.status.error_embedded_estimate={L.status.error_embedded_estimate:.5e}', S)
138140

139141
return None

0 commit comments

Comments (0)