Skip to content

Commit b957358

Browse files
authored
Merge pull request #81 from pdobbelaere/ase-optim
Implement cell optimisations through ASE
2 parents f2bd5e5 + 660d5f5 commit b957358

File tree

13 files changed

+491
-78
lines changed

13 files changed

+491
-78
lines changed

psiflow/free_energy/phonons.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,15 @@
1313
import psiflow
1414
from psiflow.data import Dataset
1515
from psiflow.geometry import Geometry, mass_weight
16-
from psiflow.hamiltonians import Hamiltonian
17-
from psiflow.sampling.optimize import setup_forces, setup_sockets
18-
from psiflow.sampling.sampling import make_start_command, make_client_command
16+
from psiflow.hamiltonians import Hamiltonian, MixtureHamiltonian
17+
from psiflow.sampling.sampling import (
18+
setup_sockets,
19+
label_forces,
20+
make_force_xml,
21+
serialize_mixture,
22+
make_start_command,
23+
make_client_command
24+
)
1925
from psiflow.utils.apps import multiply
2026
from psiflow.utils.io import load_numpy, save_xml
2127
from psiflow.utils import TMP_COMMAND, CD_COMMAND
@@ -112,7 +118,6 @@ def _execute_ipi(
112118
TMP_COMMAND,
113119
CD_COMMAND,
114120
command_start,
115-
"sleep 3s",
116121
*commands_client,
117122
"wait",
118123
command_end,
@@ -133,8 +138,10 @@ def compute_harmonic(
133138
pos_shift: float = 0.01,
134139
energy_shift: float = 0.00095,
135140
) -> AppFuture:
136-
hamiltonians_map, forces = setup_forces(hamiltonian)
137-
sockets = setup_sockets(hamiltonians_map)
141+
hamiltonian: MixtureHamiltonian = 1 * hamiltonian
142+
names = label_forces(hamiltonian)
143+
sockets = setup_sockets(names)
144+
forces = make_force_xml(hamiltonian, names)
138145

139146
initialize = ET.Element("initialize", nbeads="1")
140147
start = ET.Element("file", mode="ase", cell_units="angstrom")
@@ -168,11 +175,10 @@ def compute_harmonic(
168175
input_future,
169176
Dataset([state]).extxyz,
170177
]
171-
inputs += [h.serialize_function(dtype="float64") for h in hamiltonians_map.values()]
178+
inputs += serialize_mixture(hamiltonian, dtype="float64")
172179

173-
hamiltonian_names = list(hamiltonians_map.keys())
174180
client_args = []
175-
for name in hamiltonian_names:
181+
for name in names:
176182
args = definition.get_client_args(name, 1, "vibrations")
177183
client_args.append(args)
178184
outputs = [
@@ -184,7 +190,7 @@ def compute_harmonic(
184190
resources = definition.wq_resources(1)
185191

186192
result = execute_ipi(
187-
hamiltonian_names,
193+
names,
188194
client_args,
189195
command_server,
190196
command_client,

psiflow/hamiltonians.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,12 @@ class MixtureHamiltonian(Hamiltonian):
114114

115115
def __init__(
116116
self,
117-
hamiltonians: list[Hamiltonian],
118-
coefficients: list[float],
117+
hamiltonians: Union[tuple, list][Hamiltonian],
118+
coefficients: Union[tuple, list][float],
119119
) -> None:
120-
self.hamiltonians = hamiltonians
121-
self.coefficients = coefficients
120+
assert len(hamiltonians) == len(coefficients)
121+
self.hamiltonians = list(hamiltonians)
122+
self.coefficients = list(coefficients)
122123

123124
def compute( # override compute for efficient batching
124125
self,

psiflow/sampling/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .metadynamics import Metadynamics # noqa: F401
2-
from .optimize import optimize, optimize_dataset # noqa: F401
2+
# from .optimize import optimize, optimize_dataset # noqa: F401
33
from .output import SimulationOutput # noqa: F401
44
from .sampling import sample # noqa: F401
55
from .walker import ReplicaExchange # noqa: F401

psiflow/sampling/_ase.py

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
"""
2+
Structure optimisation through ASE
3+
TODO: do we need to check for very large forces?
4+
TODO: what units are pressure?
5+
TODO: what to do when max_steps is reached before converging?
6+
TODO: timeout is duplicated code
7+
"""
8+
9+
import os
10+
import json
11+
import warnings
12+
import signal
13+
import argparse
14+
from pathlib import Path
15+
from types import SimpleNamespace
16+
17+
import ase
18+
import ase.io
19+
import numpy as np
20+
from ase.io.extxyz import save_calc_results
21+
from ase.calculators.calculator import Calculator, all_properties
22+
from ase.calculators.mixing import LinearCombinationCalculator
23+
from ase.optimize.precon import PreconLBFGS
24+
from ase.filters import FrechetCellFilter
25+
26+
from psiflow.geometry import Geometry
27+
from psiflow.functions import function_from_json, EnergyFunction
28+
from psiflow.sampling.utils import TimeoutException, timeout_handler
29+
30+
31+
ALLOWED_MODES: tuple[str, ...] = ('full', 'fix_volume', 'fix_shape', 'fix_cell')
32+
FILE_OUT: str = 'out.xyz'
33+
FILE_TRAJ: str = 'out.traj'
34+
35+
36+
class FunctionCalculator(Calculator):
37+
implemented_properties = ['energy', 'free_energy', 'forces', 'stress']
38+
39+
def __init__(self, function: EnergyFunction, **kwargs):
40+
super().__init__(**kwargs)
41+
self.function = function
42+
43+
def calculate(
44+
self,
45+
atoms=None,
46+
properties=all_properties,
47+
system_changes=None,
48+
):
49+
super().calculate(atoms, properties, system_changes)
50+
geometry = Geometry.from_atoms(self.atoms)
51+
self.results = self.function(geometry)
52+
self.results['free_energy'] = self.results['energy'] # required by optimiser
53+
54+
55+
def log_state(atoms: ase.Atoms) -> None:
56+
""""""
57+
def make_log(data: list[tuple[str]]):
58+
""""""
59+
txt = ['', 'Current atoms state:']
60+
txt += [f'{_[0]:<15}: {_[1]:<25}[{_[2]}]' for _ in data]
61+
txt += 'End', ''
62+
print(*txt, sep='\n')
63+
64+
data = []
65+
if atoms.calc:
66+
energy, max_force = atoms.get_potential_energy(), np.linalg.norm(atoms.get_forces(), axis=0).max()
67+
else:
68+
energy, max_force = [np.nan] * 2
69+
data += ('Energy', f'{energy:.2f}', 'eV'), ('Max. force', f'{max_force:.2E}', 'eV/A')
70+
71+
if not all(atoms.pbc):
72+
make_log(data)
73+
return
74+
75+
volume, cell = atoms.get_volume(), atoms.get_cell().cellpar().round(3)
76+
data += ('Cell volume', f'{atoms.get_volume():.2f}', 'A^3'),
77+
data += ('Box norms', str(cell[:3])[1:-1], 'A'), ('Box angles', str(cell[3:])[1:-1], 'degrees')
78+
79+
make_log(data)
80+
return
81+
82+
83+
def get_dof_filter(atoms: ase.Atoms, mode: str, pressure: float) -> ase.Atoms | FrechetCellFilter:
84+
""""""
85+
if mode == 'fix_cell':
86+
if pressure:
87+
warnings.warn('Ignoring external pressure..')
88+
return atoms
89+
kwargs = {'mask': [True] * 6, 'scalar_pressure': pressure} # enable cell DOFs
90+
if mode == 'fix_shape':
91+
kwargs['hydrostatic_strain'] = True
92+
if mode == 'fix_volume':
93+
kwargs['constant_volume'] = True
94+
if pressure:
95+
warnings.warn('Ignoring applied pressure during fixed volume optimisation..')
96+
return FrechetCellFilter(atoms, **kwargs)
97+
98+
99+
def run(args: SimpleNamespace):
100+
""""""
101+
config = json.load(Path(args.input_config).open('r'))
102+
103+
atoms = ase.io.read(args.start_xyz)
104+
if not any(atoms.pbc):
105+
atoms.center(vacuum=0) # optimiser mysteriously requires a nonzero unit cell
106+
if config['mode'] != 'fix_cell':
107+
config['mode'] = 'fix_cell'
108+
warnings.warn('Molecular structure is not periodic. Ignoring cell..')
109+
110+
# construct calculator by combining hamiltonians
111+
assert args.path_hamiltonian is not None
112+
print('Making calculator from:', *config['forces'], sep='\n')
113+
functions = [function_from_json(p) for p in args.path_hamiltonian]
114+
calc = LinearCombinationCalculator(
115+
[FunctionCalculator(f) for f in functions],
116+
[float(h['weight']) for h in config['forces']]
117+
)
118+
119+
atoms.calc = calc
120+
dof = get_dof_filter(atoms, config['mode'], config['pressure'])
121+
opt = PreconLBFGS(dof, trajectory=FILE_TRAJ if config['keep_trajectory'] else None)
122+
123+
print(f"pid: {os.getpid()}")
124+
print(f"CPU affinity: {os.sched_getaffinity(os.getpid())}")
125+
log_state(atoms)
126+
try:
127+
opt.run(fmax=config['f_max'], steps=config['max_steps'])
128+
except TimeoutException:
129+
print('OPTIMISATION TIMEOUT')
130+
# TODO: what to do here?
131+
return
132+
133+
log_state(atoms)
134+
save_calc_results(atoms, calc_prefix='', remove_atoms_calc=True)
135+
if not any(atoms.pbc):
136+
atoms.cell = None # remove meaningless cell
137+
ase.io.write(FILE_OUT, atoms)
138+
print('OPTIMISATION SUCCESSFUL')
139+
return
140+
141+
142+
def clean(args: SimpleNamespace):
143+
""""""
144+
from psiflow.data.utils import _write_frames
145+
146+
geometry = Geometry.load(FILE_OUT)
147+
_write_frames(geometry, outputs=[args.output_xyz])
148+
if Path(FILE_TRAJ).is_file():
149+
traj = [at for at in ase.io.trajectory.Trajectory(FILE_TRAJ)]
150+
geometries = [Geometry.from_atoms(at) for at in traj]
151+
_write_frames(*geometries, outputs=[args.output_traj])
152+
print('FILES MOVED')
153+
return
154+
155+
156+
def main():
157+
signal.signal(signal.SIGTERM, timeout_handler)
158+
parser = argparse.ArgumentParser()
159+
subparsers = parser.add_subparsers(help='what to do', dest='action')
160+
run_parser = subparsers.add_parser("run")
161+
run_parser.set_defaults(func=run)
162+
run_parser.add_argument(
163+
"--path_hamiltonian",
164+
action='extend',
165+
nargs='*',
166+
type=str,
167+
)
168+
run_parser.add_argument(
169+
"--input_config",
170+
type=str,
171+
default=None,
172+
)
173+
run_parser.add_argument(
174+
"--start_xyz",
175+
type=str,
176+
default=None,
177+
)
178+
clean_parser = subparsers.add_parser("clean")
179+
clean_parser.set_defaults(func=clean)
180+
clean_parser.add_argument(
181+
"--output_xyz",
182+
type=str,
183+
default=None,
184+
)
185+
clean_parser.add_argument(
186+
"--output_traj",
187+
type=str,
188+
default=None,
189+
)
190+
args = parser.parse_args()
191+
args.func(args)
192+
193+

0 commit comments

Comments
 (0)