Skip to content

Commit bb75d0f

Browse files
authored
add Minimizer plugin and System.minimize method (#304)
This PR adds the `minimize` method to `System` and also the Minimizer plugins. Also add two minimizers: ASE and AMBER SQM.
1 parent c9deb10 commit bb75d0f

File tree

5 files changed

+242
-6
lines changed

5 files changed

+242
-6
lines changed

dpdata/driver.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,74 @@ def label(self, data: dict) -> dict:
147147
labeled_data['energies'] += lb_data ['energies']
148148
labeled_data['forces'] += lb_data ['forces']
149149
return labeled_data
150+
151+
152+
class Minimizer(ABC):
153+
"""The base class for a minimizer plugin. A minimizer can
154+
minimize geometry.
155+
"""
156+
__MinimizerPlugin = Plugin()
157+
158+
@staticmethod
159+
def register(key: str) -> Callable:
160+
"""Register a minimizer plugin. Used as decorators.
161+
162+
Parameter
163+
---------
164+
key: str
165+
key of the plugin.
166+
167+
Returns
168+
-------
169+
Callable
170+
decorator of a class
171+
172+
Examples
173+
--------
174+
>>> @Minimizer.register("some_minimizer")
175+
... class SomeMinimizer(Minimizer):
176+
... pass
177+
"""
178+
return Minimizer.__MinimizerPlugin.register(key)
179+
180+
@staticmethod
181+
def get_minimizer(key: str) -> "Minimizer":
182+
"""Get a minimizer plugin.
183+
184+
Parameter
185+
---------
186+
key: str
187+
key of the plugin.
188+
189+
Returns
190+
-------
191+
Minimizer
192+
the specific minimizer class
193+
194+
Raises
195+
------
196+
RuntimeError
197+
if the requested minimizer is not implemented
198+
"""
199+
try:
200+
return Minimizer.__MinimizerPlugin.plugins[key]
201+
except KeyError as e:
202+
raise RuntimeError('Unknown minimizer: ' + key) from e
203+
204+
def __init__(self, *args, **kwargs) -> None:
205+
"""Setup the minimizer."""
206+
207+
@abstractmethod
208+
def minimize(self, data: dict) -> dict:
209+
"""Minimize the geometry.
210+
211+
Parameters
212+
----------
213+
data : dict
214+
data with coordinates and atom types
215+
216+
Returns
217+
-------
218+
dict
219+
labeled data with minimized coordinates, energies, and forces
220+
"""

dpdata/plugins/amber.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import dpdata.amber.md
66
import dpdata.amber.sqm
77
from dpdata.format import Format
8-
from dpdata.driver import Driver
8+
from dpdata.driver import Driver, Minimizer
99

1010

1111
@Format.register("amber/md")
@@ -122,3 +122,21 @@ def label(self, data: dict) -> dict:
122122
) from e
123123
labeled_system.append(dpdata.LabeledSystem(out_fn, fmt="sqm/out"))
124124
return labeled_system.data
125+
126+
127+
@Minimizer.register("sqm")
128+
class SQMMinimizer(Minimizer):
129+
"""SQM minimizer.
130+
131+
Parameters
132+
----------
133+
maxcyc : int, default=1000
134+
maximun cycle to minimize
135+
"""
136+
def __init__(self, maxcyc=1000, *args, **kwargs) -> None:
137+
assert maxcyc > 0, "maxcyc should be more than 0 to minimize"
138+
self.driver = SQMDriver(maxcyc=maxcyc, **kwargs)
139+
140+
def minimize(self, data: dict) -> dict:
141+
# sqm has minimize feature
142+
return self.driver.label(data)

dpdata/plugins/ase.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1-
from dpdata.driver import Driver
1+
from typing import TYPE_CHECKING, Type
2+
from dpdata.driver import Driver, Minimizer
23
from dpdata.format import Format
34
import numpy as np
45
import dpdata
56
try:
67
import ase.io
78
from ase.calculators.calculator import PropertyNotImplementedError
9+
if TYPE_CHECKING:
10+
from ase.optimize.optimize import Optimizer
811
except ImportError:
912
pass
1013

@@ -204,3 +207,61 @@ def label(self, data: dict) -> dict:
204207
ls = dpdata.LabeledSystem(atoms, fmt="ase/structure", type_map=data['atom_names'])
205208
labeled_system.append(ls)
206209
return labeled_system.data
210+
211+
212+
@Minimizer.register("ase")
213+
class ASEMinimizer(Minimizer):
214+
"""ASE minimizer.
215+
216+
Parameters
217+
----------
218+
driver : Driver
219+
dpdata driver
220+
optimizer : type, optional
221+
ase optimizer class
222+
fmax : float, optional, default=5e-3
223+
force convergence criterion
224+
optimizer_kwargs : dict, optional
225+
other parameters for optimizer
226+
"""
227+
def __init__(self,
228+
driver: Driver,
229+
optimizer: Type["Optimizer"] = None,
230+
fmax: float = 5e-3,
231+
optimizer_kwargs: dict = {}) -> None:
232+
self.calculator = driver.ase_calculator
233+
if optimizer is None:
234+
from ase.optimize import LBFGS
235+
self.optimizer = LBFGS
236+
else:
237+
self.optimizer = optimizer
238+
self.optimizer_kwargs = {
239+
"logfile": None,
240+
**optimizer_kwargs.copy(),
241+
}
242+
self.fmax = fmax
243+
244+
def minimize(self, data: dict) -> dict:
245+
"""Minimize the geometry.
246+
247+
Parameters
248+
----------
249+
data : dict
250+
data with coordinates and atom types
251+
252+
Returns
253+
-------
254+
dict
255+
labeled data with minimized coordinates, energies, and forces
256+
"""
257+
system = dpdata.System(data=data)
258+
# list[Atoms]
259+
structures = system.to_ase_structure()
260+
labeled_system = dpdata.LabeledSystem()
261+
for atoms in structures:
262+
atoms.calc = self.calculator
263+
dyn = self.optimizer(atoms, **self.optimizer_kwargs)
264+
dyn.run(fmax=self.fmax)
265+
ls = dpdata.LabeledSystem(atoms, fmt="ase/structure", type_map=data['atom_names'])
266+
labeled_system.append(ls)
267+
return labeled_system.data

dpdata/system.py

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import dpdata.md.pbc
77
from copy import deepcopy
88
from enum import Enum, unique
9-
from typing import Any, Tuple
9+
from typing import Any, Tuple, Union
1010
from monty.json import MSONable
1111
from monty.serialization import loadfn,dumpfn
1212
from dpdata.periodic_table import Element
@@ -17,7 +17,7 @@
1717
import dpdata.plugins
1818
from dpdata.plugin import Plugin
1919
from dpdata.format import Format
20-
from dpdata.driver import Driver
20+
from dpdata.driver import Driver, Minimizer
2121

2222
from dpdata.utils import (
2323
elements_index_map,
@@ -869,6 +869,28 @@ def predict(self, *args: Any, driver: str="dp", **kwargs: Any) -> "LabeledSystem
869869
data = driver.label(self.data.copy())
870870
return LabeledSystem(data=data)
871871

872+
def minimize(self, *args: Any, minimizer: Union[str, Minimizer], **kwargs: Any) -> "LabeledSystem":
873+
"""Minimize the geometry.
874+
875+
Parameters
876+
----------
877+
*args : iterable
878+
Arguments passing to the minimizer
879+
minimizer : str or Minimizer
880+
The assigned minimizer
881+
**kwargs : dict
882+
Other arguments passing to the minimizer
883+
884+
Returns
885+
-------
886+
labeled_sys : LabeledSystem
887+
A new labeled system.
888+
"""
889+
if not isinstance(minimizer, Minimizer):
890+
minimizer = Minimizer.get_minimizer(minimizer)(*args, **kwargs)
891+
data = minimizer.minimize(self.data.copy())
892+
return LabeledSystem(data=data)
893+
872894
def pick_atom_idx(self, idx, nopbc=None):
873895
"""Pick atom index
874896
@@ -1308,10 +1330,43 @@ def predict(self, *args: Any, driver="dp", **kwargs: Any) -> "MultiSystems":
13081330
"""
13091331
if not isinstance(driver, Driver):
13101332
driver = Driver.get_driver(driver)(*args, **kwargs)
1311-
new_multisystems = dpdata.MultiSystems()
1333+
new_multisystems = dpdata.MultiSystems(type_map=self.atom_names)
13121334
for ss in self:
13131335
new_multisystems.append(ss.predict(*args, driver=driver, **kwargs))
13141336
return new_multisystems
1337+
1338+
def minimize(self, *args: Any, minimizer: Union[str, Minimizer], **kwargs: Any) -> "MultiSystems":
1339+
"""
1340+
Minimize geometry by a minimizer.
1341+
1342+
Parameters
1343+
----------
1344+
*args : iterable
1345+
Arguments passing to the minimizer
1346+
minimizer : str or Minimizer
1347+
The assigned minimizer
1348+
**kwargs : dict
1349+
Other arguments passing to the minimizer
1350+
1351+
Returns
1352+
-------
1353+
MultiSystems
1354+
A new labeled MultiSystems.
1355+
1356+
Examples
1357+
--------
1358+
Minimize a system using ASE BFGS along with a DP driver:
1359+
>>> from dpdata.driver import Driver
1360+
>>> from ase.optimize import BFGS
1361+
>>> driver = driver.get_driver("dp")("some_model.pb")
1362+
>>> some_system.minimize(minimizer="ase", driver=driver, optimizer=BFGS, fmax=1e-5)
1363+
"""
1364+
if not isinstance(minimizer, Minimizer):
1365+
minimizer = Minimizer.get_minimizer(minimizer)(*args, **kwargs)
1366+
new_multisystems = dpdata.MultiSystems(type_map=self.atom_names)
1367+
for ss in self:
1368+
new_multisystems.append(ss.minimize(*args, minimizer=minimizer, **kwargs))
1369+
return new_multisystems
13151370

13161371
def pick_atom_idx(self, idx, nopbc=None):
13171372
"""Pick atom index

tests/test_predict.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def setUp(self) :
7878

7979

8080
@unittest.skipIf(skip_ase,"skip ase related test. install ase to fix")
81-
class TestASEtraj1(unittest.TestCase, CompLabeledSys, IsPBC):
81+
class TestASEDriver(unittest.TestCase, CompLabeledSys, IsPBC):
8282
def setUp (self) :
8383
ori_sys = dpdata.LabeledSystem('poscars/deepmd.h2o.md',
8484
fmt = 'deepmd/raw',
@@ -90,3 +90,34 @@ def setUp (self) :
9090
self.e_places = 6
9191
self.f_places = 6
9292
self.v_places = 4
93+
94+
95+
@unittest.skipIf(skip_ase, "skip ase related test. install ase to fix")
96+
class TestMinimize(unittest.TestCase, CompLabeledSys, IsPBC):
97+
def setUp (self) :
98+
ori_sys = dpdata.LabeledSystem('poscars/deepmd.h2o.md',
99+
fmt = 'deepmd/raw',
100+
type_map = ['O', 'H'])
101+
zero_driver = ZeroDriver()
102+
self.system_1 = ori_sys.predict(driver=zero_driver)
103+
self.system_2 = ori_sys.minimize(driver=zero_driver, minimizer="ase")
104+
self.places = 6
105+
self.e_places = 6
106+
self.f_places = 6
107+
self.v_places = 4
108+
109+
110+
@unittest.skipIf(skip_ase, "skip ase related test. install ase to fix")
111+
class TestMinimizeMultiSystems(unittest.TestCase, CompLabeledSys, IsPBC):
112+
def setUp (self) :
113+
ori_sys = dpdata.LabeledSystem('poscars/deepmd.h2o.md',
114+
fmt = 'deepmd/raw',
115+
type_map = ['O', 'H'])
116+
multi_sys = dpdata.MultiSystems(ori_sys)
117+
zero_driver = ZeroDriver()
118+
self.system_1 = list(multi_sys.predict(driver=zero_driver).systems.values())[0]
119+
self.system_2 = list(multi_sys.minimize(driver=zero_driver, minimizer="ase").systems.values())[0]
120+
self.places = 6
121+
self.e_places = 6
122+
self.f_places = 6
123+
self.v_places = 4

0 commit comments

Comments
 (0)