Skip to content

Commit e4919ca

Browse files
authored
dpdata driver <--> ase calculator (#302)
* add ase calculator * add ASEDriver * add ase_calculator property
1 parent 2bbfc9b commit e4919ca

File tree

4 files changed

+149
-3
lines changed

4 files changed

+149
-3
lines changed

dpdata/ase_calculator.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from typing import List, Optional, TYPE_CHECKING
2+
3+
from ase.calculators.calculator import (
4+
Calculator, all_changes, PropertyNotImplementedError
5+
)
6+
7+
import dpdata
8+
from .driver import Driver
9+
10+
if TYPE_CHECKING:
11+
from ase import Atoms
12+
13+
14+
class DPDataCalculator(Calculator):
15+
"""Implementation of ASE deepmd calculator based on a driver.
16+
17+
Parameters
18+
----------
19+
driver : Driver
20+
dpdata driver
21+
"""
22+
23+
name = "dpdata"
24+
implemented_properties = [
25+
"energy", "free_energy", "forces", "virial", "stress"]
26+
27+
def __init__(
28+
self,
29+
driver: Driver,
30+
**kwargs
31+
) -> None:
32+
Calculator.__init__(self, label=Driver.__name__, **kwargs)
33+
self.driver = driver
34+
35+
def calculate(
36+
self,
37+
atoms: Optional["Atoms"] = None,
38+
properties: List[str] = ["energy", "forces"],
39+
system_changes: List[str] = all_changes,
40+
):
41+
"""Run calculation with a driver.
42+
43+
Parameters
44+
----------
45+
atoms : Optional[Atoms], optional
46+
atoms object to run the calculation on, by default None
47+
properties : List[str], optional
48+
unused, only for function signature compatibility,
49+
by default ["energy", "forces"]
50+
system_changes : List[str], optional
51+
unused, only for function signature compatibility, by default all_changes
52+
"""
53+
if atoms is not None:
54+
self.atoms = atoms.copy()
55+
56+
system = dpdata.System(self.atoms, fmt="ase/structure")
57+
data = system.predict(driver=self.driver).data
58+
59+
self.results['energy'] = data['energies'][0]
60+
# see https://gitlab.com/ase/ase/-/merge_requests/2485
61+
self.results['free_energy'] = data['energies'][0]
62+
self.results['forces'] = data['forces'][0]
63+
if 'virials' in data:
64+
self.results['virial'] = data['virials'][0].reshape(3, 3)
65+
66+
# convert virial into stress for lattice relaxation
67+
if "stress" in properties:
68+
if sum(atoms.get_pbc()) > 0:
69+
# the usual convention (tensile stress is positive)
70+
# stress = -virial / volume
71+
stress = -0.5 * (data['virials'][0].copy() + data['virials'][0].copy().T) / \
72+
atoms.get_volume()
73+
# Voigt notation
74+
self.results['stress'] = stress.flat[[0, 4, 8, 5, 2, 1]]
75+
else:
76+
raise PropertyNotImplementedError

dpdata/driver.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
"""Driver plugin system."""
2-
from typing import Callable, List, Union
2+
from typing import Callable, List, Union, TYPE_CHECKING
33
from .plugin import Plugin
44
from abc import ABC, abstractmethod
55

6+
if TYPE_CHECKING:
7+
import ase
68

79
class Driver(ABC):
810
"""The base class for a driver plugin. A driver can
@@ -79,6 +81,12 @@ def label(self, data: dict) -> dict:
7981
"""
8082
return NotImplemented
8183

84+
@property
85+
def ase_calculator(self) -> "ase.calculators.calculator.Calculator":
86+
"""Returns an ase calculator based on this driver."""
87+
from .ase_calculator import DPDataCalculator
88+
return DPDataCalculator(self)
89+
8290

8391
@Driver.register("hybrid")
8492
class HybridDriver(Driver):

dpdata/plugins/ase.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from dpdata.driver import Driver
12
from dpdata.format import Format
23
import numpy as np
4+
import dpdata
35
try:
46
import ase.io
57
from ase.calculators.calculator import PropertyNotImplementedError
@@ -162,3 +164,42 @@ def to_labeled_system(self, data, *args, **kwargs):
162164
structures.append(structure)
163165

164166
return structures
167+
168+
169+
@Driver.register("ase")
170+
class ASEDriver(Driver):
171+
"""ASE Driver.
172+
173+
Parameters
174+
----------
175+
calculator : ase.calculators.calculator.Calculato
176+
ASE calculator
177+
"""
178+
179+
def __init__(self, calculator: "ase.calculators.calculator.Calculator") -> None:
180+
"""Setup the driver."""
181+
self.calculator = calculator
182+
183+
def label(self, data: dict) -> dict:
184+
"""Label a system data. Returns new data with energy, forces, and virials.
185+
186+
Parameters
187+
----------
188+
data : dict
189+
data with coordinates and atom types
190+
191+
Returns
192+
-------
193+
dict
194+
labeled data with energies and forces
195+
"""
196+
# convert data to ase data
197+
system = dpdata.System(data=data)
198+
# list[Atoms]
199+
structures = system.to_ase_structure()
200+
labeled_system = dpdata.LabeledSystem()
201+
for atoms in structures:
202+
atoms.calc = self.calculator
203+
ls = dpdata.LabeledSystem(atoms, fmt="ase/structure", type_map=data['atom_names'])
204+
labeled_system.append(ls)
205+
return labeled_system.data

tests/test_predict.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
import unittest
22
import numpy as np
33

4-
from comp_sys import CompLabeledSys
4+
from comp_sys import CompLabeledSys, IsPBC
55
from context import dpdata
6+
try:
7+
import ase
8+
except ModuleNotFoundError:
9+
skip_ase = True
10+
else:
11+
skip_ase = False
612

713

814
@dpdata.driver.Driver.register("zero")
@@ -17,7 +23,7 @@ def label(self, data):
1723

1824

1925
@dpdata.driver.Driver.register("one")
20-
class ZeroDriver(dpdata.driver.Driver):
26+
class OneDriver(dpdata.driver.Driver):
2127
def label(self, data):
2228
nframes = data['coords'].shape[0]
2329
natoms = data['coords'].shape[1]
@@ -69,3 +75,18 @@ def setUp(self) :
6975
self.e_places = 6
7076
self.f_places = 6
7177
self.v_places = 6
78+
79+
80+
@unittest.skipIf(skip_ase,"skip ase related test. install ase to fix")
81+
class TestASEtraj1(unittest.TestCase, CompLabeledSys, IsPBC):
82+
def setUp (self) :
83+
ori_sys = dpdata.LabeledSystem('poscars/deepmd.h2o.md',
84+
fmt = 'deepmd/raw',
85+
type_map = ['O', 'H'])
86+
one_driver = OneDriver()
87+
self.system_1 = ori_sys.predict(driver=one_driver)
88+
self.system_2 = ori_sys.predict(one_driver.ase_calculator, driver="ase")
89+
self.places = 6
90+
self.e_places = 6
91+
self.f_places = 6
92+
self.v_places = 4

0 commit comments

Comments
 (0)