@@ -73,6 +73,7 @@ def __init__(self, params_ns, comm, MPI, random_seed, params_configs, output_fil
7373 self .max_val = None
7474
7575 self .local_configs = []
76+ self .extra_config = False
7677
7778 old_state_files = NS ._old_state_files (output_filename_prefix )
7879 if len (old_state_files ) > 0 :
@@ -202,9 +203,14 @@ def init_configs(self, params_configs, configs_file=None, extra=False):
202203 self .local_configs = []
203204 if self .comm .rank == 0 :
204205 for config_i , new_config in enumerate (new_configs_generator ()):
206+ if config_i == 0 :
207+ first_config = new_config
205208 if config_i >= self .n_configs_global :
206209 raise RuntimeError (f"Got too many configs (expected { self .n_configs_global } ) from new config generator { new_configs_generator } " )
207210
211+ # Check that all step sizes are the same. Maybe instead we should just copy from first?
212+ assert new_config .step_size == first_config .step_size , f"Mismatched step size for config { config_i } { new_config .step_size } != 0 { first_config .step_size } "
213+
208214 target_rank = config_i // self .max_n_configs_local
209215 if target_rank == self .comm .rank :
210216 self .local_configs .append (new_config )
@@ -352,6 +358,13 @@ def step_size_tune(self, n_configs=1, min_accept_rate=0.25, max_accept_rate=0.5,
352358 print ("step_size_tune initial" , name , "size" , size , "max" , max_size , "freq" , freq )
353359 first_iter = False
354360
361+ # It looks like the following should always give the same values, hence exit
362+ # condition, on all MPI tasks, but this is not guaranteed and can lead to deadlocks
363+ # in the allreduce. The reason is that the value of done_i in the loop depends
364+ # on the value returned from _tune_from_accept_rate, which depends on the previous
365+ # step size, and if those are inconsistent between MPI tasks (as in
366+ # https://github.com/libAtoms/pymatnext/issues/20), a deadlock may occur.
367+ # Only fix is to make sure this doesn't happen (https://github.com/libAtoms/pymatnext/pull/23)
355368 done = []
356369 for param_i in range (n_params ):
357370 if accept_freq [param_i ][0 ] > 0 :
@@ -369,6 +382,9 @@ def step_size_tune(self, n_configs=1, min_accept_rate=0.25, max_accept_rate=0.5,
369382 new_step_size = {k : v * m for k , v , m in zip (step_size_names , step_size , max_step_size )}
370383 for ns_config in self .local_configs :
371384 ns_config .step_size = new_step_size
385+ # make sure that config used as buffer also has correct step_size
386+ if self .extra_config :
387+ self .extra_config .step_size = new_step_size
372388
373389 # if self.comm.rank == 0:
374390 # print("step_size_tune done", list(zip(done, accept_freq, step_size)))
0 commit comments