Skip to content

Commit d91bd1a

Browse files
authored
Merge pull request #31 from libAtoms/save_global_ind_of_samples
save info to reconstruct history of clones, as well as saving all snapshots
2 parents 020849b + 61b11bf commit d91bd1a

File tree

5 files changed

+57
-37
lines changed

5 files changed

+57
-37
lines changed

pymatnext/cli/sample.py

Lines changed: 31 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from pymatnext.ns import NS
1919
from pymatnext.params import check_fill_defaults
2020
from pymatnext.sample_params import param_defaults
21+
from pymatnext.sample_utils import truncate_file_first_col_iter
2122

2223
from pymatnext.loop_exit import NSLoopExit
2324

@@ -166,44 +167,24 @@ def sample(args, MPI, NS_comm, walker_comm):
166167
traj_interval = params_global["traj_interval"]
167168
sample_interval = params_global["sample_interval"]
168169
snapshot_interval = params_global["snapshot_interval"]
170+
snapshot_save_old = params_global["snapshot_save_old"]
169171
stdout_report_interval_s = params_global["stdout_report_interval_s"]
170172
step_size_tune_interval = params_step_size_tune["interval"]
171-
# WARNING: clone_history_file not restartable
172-
if params_global["clone_history"]:
173-
clone_history_file = open(f"{output_filename_prefix}.clone_history", "w")
174-
clone_history_file.write(f'# {{"fields": ["loop_iter", "clone_source", "clone_target"], "n_walkers": {ns.n_configs_global}}}\n')
175-
else:
176-
clone_history_file = None
177173

178174
ns_file_name = f"{output_filename_prefix}.NS_samples"
175+
clone_history_file_name = f"{output_filename_prefix}.clone_history" if params_global["clone_history"] else None
179176
traj_file_name = f"{output_filename_prefix}.traj{config_suffix}"
180177

181178
if NS_comm.rank == 0:
179+
# set up I/O
182180
if ns.snapshot_iter >= 0:
183-
# snapshot, truncate existing NS_samples and .traj.suffix files
184-
185-
# NOTE: does this code belong here? Maybe refactor to a function, maybe
186-
# move trajectory truncation into NSConfig or something?
187-
188-
# truncate .NS_samples file
189-
f_samples = open(ns_file_name, "r+")
190-
# skip header
191-
_ = f_samples.readline()
192-
line_i = None
193-
while True:
194-
line = f_samples.readline()
195-
if not line:
196-
raise RuntimeError(f"Failed to find enough lines in .NS_samples file (last line {line_i}) to reach snapshot iter {ns.snapshot_iter}")
197-
198-
line_i = int(line.split()[0])
199-
if line_i + sample_interval > ns.snapshot_iter:
200-
cur_pos = f_samples.tell()
201-
f_samples.truncate(cur_pos)
202-
break
181+
# snapshot, truncate existing .NS_samples, .clone_history, and .traj.<suffix> files
203182

204-
f_samples.close()
183+
truncate_file_first_col_iter(ns_file_name, n_header=1, sample_interval=sample_interval, max_iter=ns.snapshot_iter)
184+
truncate_file_first_col_iter(clone_history_file_name, n_header=1, sample_interval=1, max_iter=ns.snapshot_iter)
205185

206-
# truncate .traj.suffix file
186+
# NOTE: should move trajectory truncation into NSConfig, since it's config file-format specific
187+
# truncate .traj.<suffix> file
207188
f_configs = open(traj_file_name, "r+")
208189
while True:
209190
try:
@@ -220,16 +201,27 @@ def sample(args, MPI, NS_comm, walker_comm):
220201

221202
ns_file = open(ns_file_name, "a")
222203
traj_file = open(traj_file_name, "a")
204+
clone_history_file = open(clone_history_file_name, "a")
223205

224206
else:
225-
# run from start, open new .NS_samples and .traj.suffix files
226-
207+
# run from start, open new .NS_samples, .clone_history, and .traj.<suffix> files
208+
# write header as needed
227209
ns_file = open(ns_file_name, "w")
228210
header_dict = { "n_walkers": ns.n_configs_global, "n_cull": 1 }
229211
header_dict.update(ns.local_configs[0].header_dict())
230212
ns_file.write("# " + " ".join(json.dumps(header_dict, indent=0).splitlines()) + "\n")
231213

214+
if clone_history_file_name:
215+
clone_history_file = open(clone_history_file_name, "w")
216+
clone_history_file.write(f'# {{"fields": ["loop_iter", "clone_source", "clone_target"], "n_walkers": {ns.n_configs_global}}}\n')
217+
else:
218+
clone_history_file = None
219+
232220
traj_file = open(traj_file_name, "w")
221+
else:
222+
ns_file = None
223+
traj_file = None
224+
clone_history_file = None
233225

234226
max_iter = params_global["max_iter"]
235227
if max_iter > 0:
@@ -250,7 +242,7 @@ def sample(args, MPI, NS_comm, walker_comm):
250242
global_ind_of_max = ns.global_ind(ns.rank_of_max, ns.local_ind_of_max)
251243

252244
# write quantities for max config which will be culled below
253-
if NS_comm.rank == 0 and sample_interval > 0 and loop_iter % sample_interval == 0:
245+
if ns_file and sample_interval > 0 and loop_iter % sample_interval == 0:
254246
ns_file.write(f"{loop_iter} {global_ind_of_max} {ns.max_val:.10f} " + " ".join([f"{quant:.10f}" for quant in ns.max_quants]) + "\n")
255247
ns_file.flush()
256248

@@ -265,14 +257,14 @@ def sample(args, MPI, NS_comm, walker_comm):
265257
global_ind_of_clone_source = (global_ind_of_max + 1 + ns.rng_global.integers(0, ns.n_configs_global - 1)) % ns.n_configs_global
266258
rank_of_clone_source, local_ind_of_clone_source = ns.local_ind(global_ind_of_clone_source)
267259

268-
if clone_history_file is not None:
260+
if clone_history_file:
269261
clone_history_file.write(f"{loop_iter} {global_ind_of_clone_source} {global_ind_of_max}\n")
270262
if loop_iter % 1000 == 1000 - 1:
271263
clone_history_file.flush()
272264

273265
# write max to traj file
274266
if traj_interval > 0 and loop_iter % traj_interval == 0:
275-
if NS_comm.rank == 0:
267+
if traj_file:
276268
# only head node writes
277269
if NS_comm.rank == ns.rank_of_max:
278270
# already local
@@ -326,12 +318,16 @@ def sample(args, MPI, NS_comm, walker_comm):
326318
# NOTE: should this be a time rather than iteration interval? That'd basically be straightforward,
327319
# except it would require an additional communication so all processes agree that it's time for a snapshot
328320
if loop_iter > 0 and snapshot_interval > 0 and loop_iter % snapshot_interval == 0:
329-
ns.snapshot(loop_iter, output_filename_prefix)
321+
ns.snapshot(loop_iter, output_filename_prefix, save_old=snapshot_save_old)
330322

331323
loop_iter += 1
332324

333-
if clone_history_file is not None:
325+
if ns_file:
326+
ns_file.close()
327+
if clone_history_file:
334328
clone_history_file.close()
329+
if traj_file:
330+
traj_file.close()
335331

336332

337333
def main(args_list=None, mpi_finalize=True):

pymatnext/ns.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -480,11 +480,11 @@ def snapshot(self, loop_iter, output_filename_prefix, save_old=2):
480480
output_filename_prefix: str
481481
initial part of filenames that will be written to
482482
save_old: int, default 2
483-
number of old snapshots to save
483+
number of old snapshots to save, negative to save all
484484
"""
485485
if self.comm.rank == 0:
486486
old_state_files = NS._old_state_files(output_filename_prefix)
487-
if len(old_state_files) > save_old - 1:
487+
if save_old >= 0 and len(old_state_files) > save_old - 1:
488488
old_state_files = old_state_files[:-(save_old-1)]
489489
else:
490490
old_state_files = []

pymatnext/ns_utils.py

100755100644
File mode changed.

pymatnext/sample_params.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"sample_interval": 1,
1212
"traj_interval": 100,
1313
"snapshot_interval": 10000,
14+
"snapshot_save_old": 2,
1415
"step_size_tune": {
1516
"interval": 1000,
1617
"n_configs": 1,

pymatnext/sample_utils.py

100755100644
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import warnings
2+
13
class NullComm():
24
"""Fake alternative to mpi4py.MPI.Comm for serial run, which implements needed subset
35
of mpi4py.MPI.Comm methods
@@ -36,3 +38,24 @@ class MPI:
3638

3739
def Finalize():
3840
return
41+
42+
43+
def truncate_file_first_col_iter(filename, n_header, sample_interval, max_iter):
44+
warnings.warn(f"Truncating {filename}")
45+
# truncate file after first col exceeds iteration
46+
with open(filename, "r+") as fd:
47+
# skip header
48+
for _ in range(n_header):
49+
_ = fd.readline()
50+
line_i = None
51+
while True:
52+
line = fd.readline()
53+
if not line:
54+
raise RuntimeError(f"Failed to find enough lines in {filename} (last line {line_i}) to reach snapshot iter {max_iter}")
55+
56+
line_i = int(line.split()[0])
57+
if line_i + sample_interval > max_iter:
58+
warnings.warn(f"Truncated {filename} at iter {line_i}")
59+
cur_pos = fd.tell()
60+
fd.truncate(cur_pos)
61+
break

0 commit comments

Comments
 (0)