Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions etc/environment-mpi4py.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ dependencies:
- pip
- pip:
- qmat>=0.1.8
- pytest-isolate-mpi
12 changes: 12 additions & 0 deletions pySDC/core/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,18 @@ def f_init(self):
def get_default_sweeper_class(cls):
raise NotImplementedError(f'No default sweeper class implemented for {cls} problem!')

def setUpFieldsIO(self):
"""
Set up FieldsIO for MPI with the space decomposition of this problem
"""
pass

def getOutputFile(self, fileName):
raise NotImplementedError(f'No output implemented file for {type(self).__name__}')

def processSolutionForOutput(self, u):
return u

def eval_f(self, u, t):
"""
Abstract interface to RHS computation of the ODE
Expand Down
7 changes: 4 additions & 3 deletions pySDC/helpers/fieldsIO.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,10 @@ def initialize(self):
assert not self.initialized, "FieldsIO already initialized"

if not self.ALLOW_OVERWRITE:
assert not os.path.isfile(
self.fileName
), f"file {self.fileName!r} already exists, use FieldsIO.ALLOW_OVERWRITE = True to allow overwriting"
if os.path.isfile(self.fileName):
raise FileExistsError(
f"file {self.fileName!r} already exists, use FieldsIO.ALLOW_OVERWRITE = True to allow overwriting"
)

with open(self.fileName, "w+b") as f:
self.hBase.tofile(f)
Expand Down
65 changes: 63 additions & 2 deletions pySDC/implementations/hooks/log_solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import pickle
import os
import numpy as np
from pySDC.helpers.fieldsIO import FieldsIO
from pySDC.core.errors import DataError


class LogSolution(Hooks):
Expand Down Expand Up @@ -68,7 +70,7 @@ def post_iteration(self, step, level_number):
)


class LogToFile(Hooks):
class LogToPickleFile(Hooks):
r"""
Hook for logging the solution to file after the step using pickle.

Expand Down Expand Up @@ -171,7 +173,7 @@ def load(cls, index):
return pickle.load(file)


class LogToFileAfterXs(LogToFile):
class LogToPickleFileAfterXS(LogToPickleFile):
r'''
Log to file after certain amount of time has passed instead of after every step
'''
Expand Down Expand Up @@ -200,3 +202,62 @@ def process_solution(L):
}

self.log_to_file(step, level_number, type(self).logging_condition(L), process_solution=process_solution)


class LogToFile(Hooks):
filename = 'myRun.pySDC'
time_increment = 0
allow_overwriting = False

def __init__(self):
super().__init__()
self.outfile = None
self.t_next_log = 0
FieldsIO.ALLOW_OVERWRITE = self.allow_overwriting

def pre_run(self, step, level_number):
if level_number > 0:
return None
L = step.levels[level_number]

# setup outfile
if os.path.isfile(self.filename) and L.time > 0:
L.prob.setUpFieldsIO()
self.outfile = FieldsIO.fromFile(self.filename)
self.logger.info(
f'Set up file {self.filename!r} for writing output. This file already contains solutions up to t={self.outfile.times[-1]:.4f}.'
)
else:
self.outfile = L.prob.getOutputFile(self.filename)
self.logger.info(f'Set up file {self.filename!r} for writing output.')

# write initial conditions
if L.time not in self.outfile.times:
self.outfile.addField(time=L.time, field=L.prob.processSolutionForOutput(L.u[0]))
self.logger.info(f'Written initial conditions at t={L.time:4f} to file')

def post_step(self, step, level_number):
if level_number > 0:
return None

L = step.levels[level_number]

if self.t_next_log == 0:
self.t_next_log = L.time + self.time_increment

if L.time + L.dt >= self.t_next_log and not step.status.restart:
value_exists = True in [abs(me - (L.time + L.dt)) < np.finfo(float).eps * 1000 for me in self.outfile.times]
if value_exists and not self.allow_overwriting:
raise DataError(f'Already have recorded data for time {L.time + L.dt} in this file!')
self.outfile.addField(time=L.time + L.dt, field=L.prob.processSolutionForOutput(L.uend))
self.logger.info(f'Written solution at t={L.time+L.dt:.4f} to file')
self.t_next_log = max([L.time + L.dt, self.t_next_log]) + self.time_increment

@classmethod
def load(cls, index):
data = {}
file = FieldsIO.fromFile(cls.filename)
file_entry = file.readField(idx=index)
data['u'] = file_entry[1]
data['t'] = file_entry[0]
return data
10 changes: 10 additions & 0 deletions pySDC/implementations/problem_classes/TestEquation_0D.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from pySDC.core.problem import Problem, WorkCounter
from pySDC.implementations.datatype_classes.mesh import mesh, imex_mesh
from pySDC.helpers.fieldsIO import Scalar


class testequation0d(Problem):
Expand Down Expand Up @@ -146,6 +147,15 @@ def u_exact(self, t, u_init=None, t_init=None):
me[:] = u_init * self.xp.exp((t - t_init) * self.lambdas)
return me

def getOutputFile(self, fileName):
fOut = Scalar(np.complex128, fileName=fileName)
fOut.setHeader(self.lambdas.size)
fOut.initialize()
return fOut

def processSolutionForOutput(self, u):
return u.flatten()


class test_equation_IMEX(Problem):
dtype_f = imex_mesh
Expand Down
25 changes: 25 additions & 0 deletions pySDC/implementations/problem_classes/generic_spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pySDC.helpers.spectral_helper import SpectralHelper
import numpy as np
from pySDC.core.errors import ParameterError
from pySDC.helpers.fieldsIO import Rectilinear


class GenericSpectralLinear(Problem):
Expand Down Expand Up @@ -333,6 +334,30 @@ def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs)

return sol

def setUpFieldsIO(self):
Rectilinear.setupMPI(
comm=self.comm,
iLoc=[me.start for me in self.local_slice],
nLoc=[me.stop - me.start for me in self.local_slice],
)

def getOutputFile(self, fileName):
self.setUpFieldsIO()

coords = [me.get_1dgrid() for me in self.spectral.axes]
assert np.allclose([len(me) for me in coords], self.spectral.global_shape[1:])

fOut = Rectilinear(np.float64, fileName=fileName)
fOut.setHeader(nVar=len(self.components), coords=coords)
fOut.initialize()
return fOut

def processSolutionForOutput(self, u):
if self.spectral_space:
return np.array(self.itransform(u).real)
else:
return np.array(u.real)


def compute_residual_DAE(self, stage=''):
"""
Expand Down
2 changes: 1 addition & 1 deletion pySDC/projects/GPU/configs/GS_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class GrayScott(Config):

def get_LogToFile(self, ranks=None):
import numpy as np
from pySDC.implementations.hooks.log_solution import LogToFileAfterXs as LogToFile
from pySDC.implementations.hooks.log_solution import LogToPickleFileAfterXS as LogToFile

LogToFile.path = f'{self.base_path}/data/'
LogToFile.file_name = f'{self.get_path(ranks=ranks)}-solution'
Expand Down
2 changes: 1 addition & 1 deletion pySDC/projects/GPU/configs/RBC_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class RayleighBenardRegular(Config):

def get_LogToFile(self, ranks=None):
import numpy as np
from pySDC.implementations.hooks.log_solution import LogToFileAfterXs as LogToFile
from pySDC.implementations.hooks.log_solution import LogToPickleFileAfterXS as LogToFile

LogToFile.path = f'{self.base_path}/data/'
LogToFile.file_name = f'{self.get_path(ranks=ranks)}-solution'
Expand Down
2 changes: 1 addition & 1 deletion pySDC/projects/GPU/tests/test_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def get_description(self, *args, **kwargs):
return desc

def get_LogToFile(self, ranks=None):
from pySDC.implementations.hooks.log_solution import LogToFileAfterXs as LogToFile
from pySDC.implementations.hooks.log_solution import LogToPickleFileAfterXS as LogToFile

LogToFile.path = './data/'
LogToFile.file_name = f'{self.get_path(ranks=ranks)}-solution'
Expand Down
Loading