Skip to content

Commit 712afa5

Browse files
authored
Merge pull request #33 from libAtoms/one_walker
Allow for single walker runs, and fix underflow in analysis
2 parents 5d07556 + c9ca003 commit 712afa5

File tree

4 files changed

+48
-21
lines changed

4 files changed

+48
-21
lines changed

pymatnext/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.1.3"
1+
__version__ = "0.1.4"

pymatnext/analysis/utils.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def calc_log_a(iters, n_walkers, n_cull, discrete=False):
4444
return log_a
4545

4646

47-
def calc_Z_terms(beta, log_a, Es, flat_V_prior=False, N_atoms=None, Vs=None):
47+
def calc_log_Z_terms(beta, log_a, Es, flat_V_prior=False, N_atoms=None, Vs=None):
4848
"""Return the terms that sum to Z
4949
5050
Parameters
@@ -64,7 +64,7 @@ def calc_Z_terms(beta, log_a, Es, flat_V_prior=False, N_atoms=None, Vs=None):
6464
6565
Returns
6666
-------
67-
Z_term: list of terms that sum to Z, multipled by exp(-shift)
67+
log_Z_term: list of logs of terms that sum to Z, shifted by -log_shift
6868
log_shift: shift subtracted from each log(Z_term_true) to get log(Z_term)
6969
"""
7070
log_Z_term = log_a[:] - beta * Es[:]
@@ -75,9 +75,10 @@ def calc_Z_terms(beta, log_a, Es, flat_V_prior=False, N_atoms=None, Vs=None):
7575
log_Z_term += N_atoms * np.log(Vs[:])
7676

7777
log_shift = np.amax(log_Z_term[:])
78-
Z_term = np.exp(log_Z_term[:] - log_shift)
78+
# Z_term = np.exp(log_Z_term[:] - log_shift)
79+
log_Z_term -= log_shift
7980

80-
return (Z_term, log_shift)
81+
return (log_Z_term, log_shift)
8182

8283

8384
def analyse_T(T, Es, E_shift, Vs, extra_vals, log_a, flat_V_prior, N_atoms, kB, n_extra_DOF, p_entropy_min=5.0, sum_f=np.sum):
@@ -113,6 +114,8 @@ def analyse_T(T, Es, E_shift, Vs, extra_vals, log_a, flat_V_prior, N_atoms, kB,
113114
added analytically to energies and specific heats
114115
p_entropy_min: float, default 5
115116
minimum value of entropy of probability distribution that indicates a problem (poor sampling, e.g. with P reweighting)
117+
sum_f: function
118+
function with API equivalent to np.sum, e.g. more accurate sum
116119
117120
Returns
118121
-------
@@ -121,7 +124,8 @@ def analyse_T(T, Es, E_shift, Vs, extra_vals, log_a, flat_V_prior, N_atoms, kB,
121124
beta = 1.0 / (kB * T)
122125

123126
# Z_term here is actually Z_term_true * exp(-log_shift)
124-
(Z_term, log_shift) = calc_Z_terms(beta, log_a, Es, flat_V_prior, N_atoms, Vs)
127+
(log_Z_term, log_shift) = calc_log_Z_terms(beta, log_a, Es, flat_V_prior, N_atoms, Vs)
128+
Z_term = np.exp(log_Z_term)
125129

126130
# Note that
127131
# Z_term = Z_term_true * exp(-log_shift)
@@ -133,9 +137,9 @@ def analyse_T(T, Es, E_shift, Vs, extra_vals, log_a, flat_V_prior, N_atoms, kB,
133137

134138
if N_atoms is not None:
135139
N = sum_f(Z_term * N_atoms) / Z_term_sum
136-
n_extra_DOF * N
137140

138-
U = n_extra_DOF / (2.0 * beta) + U_pot + E_shift
141+
U_extra_DOF = n_extra_DOF / (2.0 * beta)
142+
U = U_pot + U_extra_DOF + E_shift
139143

140144
Cvp = n_extra_DOF * kB / 2.0 + kB * beta * beta * (sum(Z_term * Es**2) / Z_term_sum - U_pot**2)
141145

@@ -154,25 +158,27 @@ def analyse_T(T, Es, E_shift, Vs, extra_vals, log_a, flat_V_prior, N_atoms, kB,
154158
# undo shift of Z_term
155159
log_Z = np.log(Z_term_sum) + log_shift
156160

161+
# to make sure that Helmholtz_F approaches U at T -> 0,
157162
# we want last Z term to have w = 1, so we define a factor f which scales it correctly
158163
# f exp(log_shift) Z_term[-1] = 1.0 * exp(-beta Es[-1])
159164
# f = exp(-beta Es[-1] - log_shift) / Z_term[-1]
160165
# log(f) = -beta Es[-1] - log_shift - log(Z_term[-1])
161-
log_f = -beta * Es[-1] - log_shift - np.log(Z_term[-1])
166+
log_f = -beta * Es[-1] - log_shift - log_Z_term[-1]
162167
# this factor rescales every term in Z
163168
log_Z += log_f
164169

165170
# also add the E_shift
166-
Helmholtz_F = -log_Z / beta + E_shift
171+
Helmholtz_F = -log_Z / beta + U_extra_DOF + E_shift
167172

168173
mode_config = np.argmax(Z_term)
169174

170-
171175
results_dict = {'log_Z': log_Z,
172176
'FG': Helmholtz_F,
173177
'U': U,
174178
'S': (U - Helmholtz_F) * beta,
175179
'Cvp': Cvp}
180+
if N_atoms is not None:
181+
results_dict['N'] = N
176182

177183
if Vs is not None:
178184
results_dict['V'] = V

pymatnext/cli/ns_analyse.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,10 @@ def main():
4444
pressure_g.add_argument('--delta_P', help="""delta pressure to use for reweighting (works best with flat V prior)""", type=float)
4545
p.add_argument('--entropy', '-S', action='store_true', help="""compute and print entropy (relative to entropy of lowest T structure""")
4646
p.add_argument('--probability_entropy_minimum', type=float, help="""probability entropy mininum that indicates a problem with sampling""", default=5.0)
47-
p.add_argument('--plot', '-p', nargs='*', help="""column names to plot, or optionally 'log(colname)'. """
48-
"""If no column names provided, list allowed names and abort""")
47+
p.add_argument('--plot', '-p', nargs='*', help="""column names to plot. Python expression can be used, with '{colname}' """
48+
"""replaced by the value of the quantity, and 'natoms' by the number of atoms. """
49+
"""Plotted on a semi-log axis if name or expression is enclosed by 'log(..)'. """
50+
"""If no column names provided, list allowed names and abort.""")
4951
p.add_argument('--plot_together', help="""output filename for combined plot""")
5052
p.add_argument('--plot_together_filenames', action='store_true', help="""show filenames in combined plot""")
5153
p.add_argument('--plot_twinx_spacing', type=float, help="""spacing for extra twinx y axes""", default=0.15)
@@ -85,13 +87,17 @@ def main():
8587
ax = {}
8688

8789
def colname(colname_str):
88-
m = re.match(r'log\(([^)]*)\)$', colname_str)
90+
# strip off optional log()
91+
m = re.match(r'^log\((.*)\)$', colname_str)
92+
if m:
93+
colname_str = m.group(1)
94+
# extract col name
95+
m = re.match(r'\{([^}]*)\}', colname_str)
8996
if m:
9097
return m.group(1)
9198
else:
9299
return colname_str
93100

94-
95101
for infile_i, infile in enumerate(args.infile):
96102
iters = []
97103
Es = []
@@ -346,8 +352,17 @@ def str_format(fmt):
346352
fig = Figure()
347353
ax = {}
348354
for field_i, pfield in enumerate(args.plot):
349-
# should this be done here? should it be more general, e.g. eval()?
350-
col_log = pfield.startswith('log')
355+
# check for log axis
356+
col_log = re.match(r'^log\(.*\)$', pfield)
357+
358+
# check for arb math
359+
m = re.match(r'^(.*){([^}]*)}(.*)$', pfield)
360+
if m:
361+
expr = m.group(1) + m.group(2) + m.group(3)
362+
c = colname(pfield)
363+
plot_data[c] = np.asarray([eval(expr, {}, {'natoms': results_dict['N'], c: v}) for v in plot_data[c]])
364+
365+
# replace pfield with colname
351366
pfield = colname(pfield)
352367
if len(ax) == 0:
353368
ax[pfield] = fig.add_subplot()
@@ -391,11 +406,10 @@ def do_plot_sections(pfield, linestyle, color, label):
391406
section_start = T_i
392407
plot_section_start = max(section_start - 1, 0)
393408
plot_section_end = len(plot_data['T'])
394-
if not got_label:
395-
use_label = label
409+
396410
pp(plot_data['T'][plot_section_start:plot_section_end], plot_data[pfield][plot_section_start:plot_section_end],
397411
linestyle if valid_Ts_bool[section_start] else ':',
398-
color=color, label=use_label)
412+
color=color, label=None if got_label else label)
399413

400414
ax[pfield].set_ylabel(header_col(pfield))
401415

pymatnext/cli/sample.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,12 @@ def sample(args, MPI, NS_comm, walker_comm):
233233
print("params = ", end="")
234234
pprint.pprint(params, sort_dicts=False)
235235

236+
# Only for single-walker runs: allow the culled config to be the source of the clone
237+
if ns.n_configs_global == 1:
238+
clone_index_exclude = 0
239+
else:
240+
clone_index_exclude = 1
241+
236242
time_prev_stdout_report = time.time()
237243
for loop_iter in loop_iterable:
238244
if exit_cond(ns, loop_iter):
@@ -254,7 +260,8 @@ def sample(args, MPI, NS_comm, walker_comm):
254260
adjust_factor=params_step_size_tune["adjust_factor"])
255261

256262
# pick random config as source for clone.
257-
global_ind_of_clone_source = (global_ind_of_max + 1 + ns.rng_global.integers(0, ns.n_configs_global - 1)) % ns.n_configs_global
263+
global_ind_of_clone_source = (global_ind_of_max + clone_index_exclude +
264+
ns.rng_global.integers(0, ns.n_configs_global - clone_index_exclude)) % ns.n_configs_global
258265
rank_of_clone_source, local_ind_of_clone_source = ns.local_ind(global_ind_of_clone_source)
259266

260267
if clone_history_file:

0 commit comments

Comments
 (0)