|
6 | 6 | from pySDC.core.base_transfer import BaseTransfer |
7 | 7 | from pySDC.helpers.pysdc_helper import FrozenClass |
8 | 8 | from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence |
| 9 | +from pySDC.implementations.convergence_controller_classes.store_uold import StoreUOld |
9 | 10 | from pySDC.implementations.hooks.default_hook import DefaultHooks |
10 | 11 | from pySDC.implementations.hooks.log_timings import CPUTimings |
11 | 12 |
|
@@ -41,6 +42,7 @@ def __init__(self, controller_params, description, useMPI=None): |
41 | 42 | controller_params (dict): parameter set for the controller and the steps |
42 | 43 | """ |
43 | 44 | self.useMPI = useMPI |
| 45 | + self.description = description |
44 | 46 |
|
45 | 47 | # check if we have a hook on this list. If not, use default class. |
46 | 48 | self.__hooks = [] |
@@ -341,3 +343,68 @@ def return_stats(self): |
341 | 343 | for hook in self.hooks: |
342 | 344 | stats = {**stats, **hook.return_stats()} |
343 | 345 | return stats |
| 346 | + |
| 347 | + |
| 348 | +class ParaDiagController(Controller): |
| 349 | + |
| 350 | + def __init__(self, controller_params, description, n_steps, useMPI=None): |
| 351 | + """ |
| 352 | + Initialization routine for ParaDiag controllers |
| 353 | +
|
| 354 | + Args: |
| 355 | + num_procs: number of parallel time steps (still serial, though), can be 1 |
| 356 | + controller_params: parameter set for the controller and the steps |
| 357 | + description: all the parameters to set up the rest (levels, problems, transfer, ...) |
| 358 | + n_steps (int): Number of parallel steps |
| 359 | + alpha (float): alpha parameter for ParaDiag |
| 360 | + """ |
| 361 | + # TODO: where should I put alpha? When I want to adapt it, maybe it shouldn't be in the controller? |
| 362 | + from pySDC.implementations.sweeper_classes.ParaDiagSweepers import QDiagonalization |
| 363 | + |
| 364 | + if QDiagonalization in description['sweeper_class'].__mro__: |
| 365 | + description['sweeper_params']['ignore_ic'] = True |
| 366 | + description['sweeper_params']['update_f_evals'] = False |
| 367 | + else: |
| 368 | + logging.getLogger('controller').warning( |
| 369 | + f'Warning: Your sweeper class {description["sweeper_class"]} is not derived from {QDiagonalization}. You probably want to use another sweeper class.' |
| 370 | + ) |
| 371 | + |
| 372 | + if controller_params.get('all_to_done', False): |
| 373 | + raise NotImplementedError('ParaDiag only implemented with option `all_to_done=True`') |
| 374 | + if 'alpha' not in controller_params.keys(): |
| 375 | + from pySDC.core.errors import ParameterError |
| 376 | + |
| 377 | + raise ParameterError('Please supply alpha as a parameter to the ParaDiag controller!') |
| 378 | + controller_params['average_jacobian'] = controller_params.get('average_jacobian', True) |
| 379 | + |
| 380 | + controller_params['all_to_done'] = True |
| 381 | + super().__init__(controller_params=controller_params, description=description, useMPI=useMPI) |
| 382 | + self.base_convergence_controllers += [StoreUOld] |
| 383 | + |
| 384 | + self.ParaDiag_block_u0 = None |
| 385 | + self.n_steps = n_steps |
| 386 | + |
| 387 | + def FFT_in_time(self): |
| 388 | + """ |
| 389 | + Compute weighted forward FFT in time. The weighting is determined by the alpha parameter in ParaDiag |
| 390 | +
|
| 391 | + Note: The implementation via matrix-vector multiplication may be inefficient and less stable compared to an FFT |
| 392 | + with transposes! |
| 393 | + """ |
| 394 | + if not hasattr(self, '__FFT_matrix'): |
| 395 | + from pySDC.helpers.ParaDiagHelper import get_weighted_FFT_matrix |
| 396 | + |
| 397 | + self.__FFT_matrix = get_weighted_FFT_matrix(self.n_steps, self.params.alpha) |
| 398 | + |
| 399 | + self.apply_matrix(self.__FFT_matrix) |
| 400 | + |
| 401 | + def iFFT_in_time(self): |
| 402 | + """ |
| 403 | + Compute weighted backward FFT in time. The weighting is determined by the alpha parameter in ParaDiag |
| 404 | + """ |
| 405 | + if not hasattr(self, '__iFFT_matrix'): |
| 406 | + from pySDC.helpers.ParaDiagHelper import get_weighted_iFFT_matrix |
| 407 | + |
| 408 | + self.__iFFT_matrix = get_weighted_iFFT_matrix(self.n_steps, self.params.alpha) |
| 409 | + |
| 410 | + self.apply_matrix(self.__iFFT_matrix) |
0 commit comments