Skip to content

Commit d0a4a9b

Browse files
committed
state_params: usage in forcing/specific.py
1 parent 4b0a623 commit d0a4a9b

File tree

2 files changed

+49
-33
lines changed

2 files changed

+49
-33
lines changed

fluidsim/base/forcing/specific.py

Lines changed: 42 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -804,39 +804,51 @@ def __init__(self, sim):
804804
super().__init__(sim)
805805

806806
if mpi.rank == 0:
807-
path_input_forcing_state = self._forcing_state_file_path = (
808-
Path(sim.output.path_run) / "_forcing_state.txt"
809-
)
810-
807+
state_params = sim.state.state_params
811808
if (
812-
not path_input_forcing_state.exists()
813-
and sim.params.NEW_DIR_RESULTS
814-
and sim.params.init_fields.from_file.path != ""
809+
state_params is None
810+
or "forcing" not in state_params._tag_children
815811
):
812+
state_params = sim.state.get_state_params()
813+
state_params._set_child(
814+
"forcing",
815+
attribs={"t_last_change": None, "seed0": None, "seed1": None},
816+
)
817+
818+
# _forcing_state.txt is an old format (<=0.8.6)
819+
# this code is a bit complicated because we try to continue to
820+
# load correctly simulations using it.
821+
self.t_last_change = None
816822
path_input_forcing_state = (
817-
Path(sim.params.init_fields.from_file.path).parent.parent
818-
/ "_forcing_state.txt"
823+
Path(sim.output.path_run) / "_forcing_state.txt"
819824
)
820-
if not path_input_forcing_state.exists():
821-
warn(
822-
"Restarting a forced simulation but file "
823-
f"{path_input_forcing_state} does not exist."
825+
if (
826+
not path_input_forcing_state.exists()
827+
and sim.params.NEW_DIR_RESULTS
828+
and sim.params.init_fields.from_file.path != ""
829+
):
830+
path_input_forcing_state = (
831+
Path(sim.params.init_fields.from_file.path).parent.parent
832+
/ "_forcing_state.txt"
824833
)
825-
826-
if path_input_forcing_state.exists():
827-
lines = path_input_forcing_state.read_text().split("\n")
828-
t_last_change, seed0, seed1 = lines[-2].split()
829-
self.t_last_change = float(t_last_change)
830-
self._seed0 = int(seed0)
831-
self._seed1 = int(seed1)
834+
if path_input_forcing_state.exists():
835+
lines = path_input_forcing_state.read_text().split("\n")
836+
t_last_change, seed0, seed1 = lines[-2].split()
837+
self.t_last_change = float(t_last_change)
838+
self._seed0 = int(seed0)
839+
self._seed1 = int(seed1)
840+
self._update_sim_state()
832841
else:
842+
p_forcing = state_params.forcing
843+
self.t_last_change = p_forcing.t_last_change
844+
self._seed0 = p_forcing.seed0
845+
self._seed1 = p_forcing.seed1
846+
847+
if self.t_last_change is None:
833848
self.t_last_change = self.sim.time_stepping.t
834849
self._seed0 = np.random.randint(0, 2**31)
835850
self._seed1 = np.random.randint(0, 2**31)
836-
self._save_state()
837-
838-
if not self._forcing_state_file_path.exists():
839-
self._save_state()
851+
self._update_sim_state()
840852

841853
np.random.seed(self._seed0)
842854
self.forcing0 = self.compute_forcingc_raw()
@@ -868,19 +880,16 @@ def forcingc_raw_each_time(self, a_fft):
868880
self._seed1 = np.random.randint(0, 2**31)
869881
np.random.seed(self._seed1)
870882
self.forcing1 = self.compute_forcingc_raw()
871-
self._save_state()
883+
self._update_sim_state()
872884

873885
f_fft = self.forcingc_from_f0f1()
874886
return f_fft
875887

876-
def _save_state(self):
877-
if not self.params.output.HAS_TO_SAVE:
878-
return
879-
880-
self._forcing_state_file_path.write_text(
881-
"# do not modify by hand\n# t_last_change seed0 seed1\n"
882-
f"{self.t_last_change} {self._seed0} {self._seed1}\n"
883-
)
888+
def _update_sim_state(self):
889+
p_forcing = self.sim.state.state_params.forcing
890+
p_forcing.t_last_change = self.t_last_change
891+
p_forcing.seed0 = self._seed0
892+
p_forcing.seed1 = self._seed1
884893

885894
def forcingc_from_f0f1(self):
886895
"""Return a coarse forcing as a linear combination of 2 random arrays"""

fluidsim/solvers/ns2d/bouss/test_solver.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,13 @@ def test_forcing_output(self):
103103
path_run = mpi.comm.bcast(path_run)
104104

105105
sim3 = fls.load_state_phys_file(path_run, modif_save_params=False)
106+
assert (
107+
sim3.state.state_params.forcing == sim.state.state_params.forcing
108+
)
109+
assert (
110+
sim3.forcing.forcing_maker._seed0
111+
== sim.forcing.forcing_maker._seed0
112+
)
106113
sim3.params.time_stepping.t_end += 0.2
107114
sim3.time_stepping.start()
108115

0 commit comments

Comments
 (0)