1818from pymatnext .ns import NS
1919from pymatnext .params import check_fill_defaults
2020from pymatnext .sample_params import param_defaults
21+ from pymatnext .sample_utils import truncate_file_first_col_iter
2122
2223from pymatnext .loop_exit import NSLoopExit
2324
@@ -166,44 +167,24 @@ def sample(args, MPI, NS_comm, walker_comm):
166167 traj_interval = params_global ["traj_interval" ]
167168 sample_interval = params_global ["sample_interval" ]
168169 snapshot_interval = params_global ["snapshot_interval" ]
170+ snapshot_save_old = params_global ["snapshot_save_old" ]
169171 stdout_report_interval_s = params_global ["stdout_report_interval_s" ]
170172 step_size_tune_interval = params_step_size_tune ["interval" ]
171- # WARNING: clone_history_file not restartable
172- if params_global ["clone_history" ]:
173- clone_history_file = open (f"{ output_filename_prefix } .clone_history" , "w" )
174- clone_history_file .write (f'# {{"fields": ["loop_iter", "clone_source", "clone_target"], "n_walkers": { ns .n_configs_global } }}\n ' )
175- else :
176- clone_history_file = None
177173
178174 ns_file_name = f"{ output_filename_prefix } .NS_samples"
175+ clone_history_file_name = f"{ output_filename_prefix } .clone_history" if params_global ["clone_history" ] else None
179176 traj_file_name = f"{ output_filename_prefix } .traj{ config_suffix } "
180177
181178 if NS_comm .rank == 0 :
179+ # set up I/O
182180 if ns .snapshot_iter >= 0 :
183- # snapshot, truncate existing NS_samples and .traj.suffix files
184-
185- # NOTE: does this code belong here? Maybe refactor to a function, maybe
186- # move trajectory truncation into NSConfig or something?
187-
188- # truncate .NS_samples file
189- f_samples = open (ns_file_name , "r+" )
190- # skip header
191- _ = f_samples .readline ()
192- line_i = None
193- while True :
194- line = f_samples .readline ()
195- if not line :
196- raise RuntimeError (f"Failed to find enough lines in .NS_samples file (last line { line_i } ) to reach snapshot iter { ns .snapshot_iter } " )
197-
198- line_i = int (line .split ()[0 ])
199- if line_i + sample_interval > ns .snapshot_iter :
200- cur_pos = f_samples .tell ()
201- f_samples .truncate (cur_pos )
202- break
181+ # snapshot, truncate existing .NS_samples, .clone_history, and .traj.<suffix> files
203182
204- f_samples .close ()
183+ truncate_file_first_col_iter (ns_file_name , n_header = 1 , sample_interval = sample_interval , max_iter = ns .snapshot_iter )
184+ truncate_file_first_col_iter (clone_history_file_name , n_header = 1 , sample_interval = 1 , max_iter = ns .snapshot_iter )
205185
206- # truncate .traj.suffix file
186+ # NOTE: should move trajectory truncation into NSConfig, since it's config file-format specific
187+ # truncate .traj.<suffix> file
207188 f_configs = open (traj_file_name , "r+" )
208189 while True :
209190 try :
@@ -220,16 +201,27 @@ def sample(args, MPI, NS_comm, walker_comm):
220201
221202 ns_file = open (ns_file_name , "a" )
222203 traj_file = open (traj_file_name , "a" )
204+ clone_history_file = open (clone_history_file_name , "a" )
223205
224206 else :
225- # run from start, open new .NS_samples and .traj.suffix files
226-
207+ # run from start, open new .NS_samples, .clone_history, and .traj.< suffix> files
208+ # write header as needed
227209 ns_file = open (ns_file_name , "w" )
228210 header_dict = { "n_walkers" : ns .n_configs_global , "n_cull" : 1 }
229211 header_dict .update (ns .local_configs [0 ].header_dict ())
230212 ns_file .write ("# " + " " .join (json .dumps (header_dict , indent = 0 ).splitlines ()) + "\n " )
231213
214+ if clone_history_file_name :
215+ clone_history_file = open (clone_history_file_name , "w" )
216+ clone_history_file .write (f'# {{"fields": ["loop_iter", "clone_source", "clone_target"], "n_walkers": { ns .n_configs_global } }}\n ' )
217+ else :
218+ clone_history_file = None
219+
232220 traj_file = open (traj_file_name , "w" )
221+ else :
222+ ns_file = None
223+ traj_file = None
224+ clone_history_file = None
233225
234226 max_iter = params_global ["max_iter" ]
235227 if max_iter > 0 :
@@ -250,7 +242,7 @@ def sample(args, MPI, NS_comm, walker_comm):
250242 global_ind_of_max = ns .global_ind (ns .rank_of_max , ns .local_ind_of_max )
251243
252244 # write quantities for max config which will be culled below
253- if NS_comm . rank == 0 and sample_interval > 0 and loop_iter % sample_interval == 0 :
245+ if ns_file and sample_interval > 0 and loop_iter % sample_interval == 0 :
254246 ns_file .write (f"{ loop_iter } { global_ind_of_max } { ns .max_val :.10f} " + " " .join ([f"{ quant :.10f} " for quant in ns .max_quants ]) + "\n " )
255247 ns_file .flush ()
256248
@@ -265,14 +257,14 @@ def sample(args, MPI, NS_comm, walker_comm):
265257 global_ind_of_clone_source = (global_ind_of_max + 1 + ns .rng_global .integers (0 , ns .n_configs_global - 1 )) % ns .n_configs_global
266258 rank_of_clone_source , local_ind_of_clone_source = ns .local_ind (global_ind_of_clone_source )
267259
268- if clone_history_file is not None :
260+ if clone_history_file :
269261 clone_history_file .write (f"{ loop_iter } { global_ind_of_clone_source } { global_ind_of_max } \n " )
270262 if loop_iter % 1000 == 1000 - 1 :
271263 clone_history_file .flush ()
272264
273265 # write max to traj file
274266 if traj_interval > 0 and loop_iter % traj_interval == 0 :
275- if NS_comm . rank == 0 :
267+ if traj_file :
276268 # only head node writes
277269 if NS_comm .rank == ns .rank_of_max :
278270 # already local
@@ -326,12 +318,16 @@ def sample(args, MPI, NS_comm, walker_comm):
326318 # NOTE: should this be a time rather than iteration interval? That'd basically be straightforward,
327319 # except it would require an additional communication so all processes agree that it's time for a snapshot
328320 if loop_iter > 0 and snapshot_interval > 0 and loop_iter % snapshot_interval == 0 :
329- ns .snapshot (loop_iter , output_filename_prefix )
321+ ns .snapshot (loop_iter , output_filename_prefix , save_old = snapshot_save_old )
330322
331323 loop_iter += 1
332324
333- if clone_history_file is not None :
325+ if ns_file :
326+ ns_file .close ()
327+ if clone_history_file :
334328 clone_history_file .close ()
329+ if traj_file :
330+ traj_file .close ()
335331
336332
337333def main (args_list = None , mpi_finalize = True ):
0 commit comments