Skip to content

Commit ed0a6a6

Browse files
authored
support converting ase.Atoms to System and LabeledSystem (#290)
#190 added `from_labeled_system`, but it just passes a dict and does nothing else. This commit rewrites it.
1 parent 6e73a62 commit ed0a6a6

File tree

3 files changed

+121
-38
lines changed

3 files changed

+121
-38
lines changed

docs/conf.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,4 +183,7 @@ def setup(app):
183183
intersphinx_mapping = {
184184
"numpy": ("https://docs.scipy.org/doc/numpy/", None),
185185
"python": ("https://docs.python.org/", None),
186+
"ase": ("https://wiki.fysik.dtu.dk/ase/", None),
187+
"monty": ("https://guide.materialsvirtuallab.org/monty/", None),
188+
"h5py": ("https://docs.h5py.org/en/stable/", None),
186189
}

dpdata/plugins/ase.py

Lines changed: 93 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -18,41 +18,101 @@ class ASEStructureFormat(Format):
1818
automatic detection fails.
1919
"""
2020

21-
def from_labeled_system(self, data, **kwargs):
22-
return data
23-
24-
def from_multi_systems(self, file_name, begin=None, end=None, step=None, ase_fmt=None, **kwargs):
21+
def from_system(self, atoms: "ase.Atoms", **kwargs) -> dict:
22+
"""Convert ase.Atoms to a System.
23+
24+
Parameters
25+
----------
26+
atoms : ase.Atoms
27+
an ASE Atoms, containing a structure
28+
29+
Returns
30+
-------
31+
dict
32+
data dict
33+
"""
34+
symbols = atoms.get_chemical_symbols()
35+
atom_names = list(set(symbols))
36+
atom_numbs = [symbols.count(symbol) for symbol in atom_names]
37+
atom_types = np.array([atom_names.index(symbol) for symbol in symbols]).astype(int)
38+
cells = atoms.cell[:]
39+
coords = atoms.get_positions()
40+
info_dict = {
41+
'atom_names': atom_names,
42+
'atom_numbs': atom_numbs,
43+
'atom_types': atom_types,
44+
'cells': np.array([cells]).astype('float32'),
45+
'coords': np.array([coords]).astype('float32'),
46+
'orig': [0,0,0],
47+
}
48+
return info_dict
49+
50+
def from_labeled_system(self, atoms: "ase.Atoms", **kwargs) -> dict:
51+
"""Convert ase.Atoms to a LabeledSystem. Energies and forces
52+
are calculated by the calculator.
53+
54+
Parameters
55+
----------
56+
atoms : ase.Atoms
57+
an ASE Atoms, containing a structure
58+
59+
Returns
60+
-------
61+
dict
62+
data dict
63+
64+
Raises
65+
------
66+
RuntimeError
67+
ASE will raise RuntimeError if the atoms does not
68+
have a calculator
69+
"""
70+
info_dict = self.from_system(atoms)
71+
try:
72+
energies = atoms.get_potential_energy(force_consistent=True)
73+
except PropertyNotImplementedError:
74+
energies = atoms.get_potential_energy()
75+
forces = atoms.get_forces()
76+
info_dict = {
77+
** info_dict,
78+
'energies': np.array([energies]).astype('float32'),
79+
'forces': np.array([forces]).astype('float32'),
80+
}
81+
try:
82+
stress = atoms.get_stress(False)
83+
except PropertyNotImplementedError:
84+
pass
85+
else:
86+
virials = np.array([-atoms.get_volume() * stress]).astype('float32')
87+
info_dict['virials'] = virials
88+
return info_dict
89+
90+
def from_multi_systems(self, file_name: str, begin: int=None, end: int=None, step: int=None, ase_fmt: str=None, **kwargs) -> "ase.Atoms":
91+
"""Convert a ASE supported file to ASE Atoms.
92+
93+
It will finally be converted to MultiSystems.
94+
95+
Parameters
96+
----------
97+
file_name : str
98+
path to file
99+
begin : int, optional
100+
begin frame index
101+
end : int, optional
102+
end frame index
103+
step : int, optional
104+
frame index step
105+
ase_fmt : str, optional
106+
ASE format. See the ASE documentation about supported formats
107+
108+
Yields
109+
------
110+
ase.Atoms
111+
ASE atoms in the file
112+
"""
25113
frames = ase.io.read(file_name, format=ase_fmt, index=slice(begin, end, step))
26114
for atoms in frames:
27-
symbols = atoms.get_chemical_symbols()
28-
atom_names = list(set(symbols))
29-
atom_numbs = [symbols.count(symbol) for symbol in atom_names]
30-
atom_types = np.array([atom_names.index(symbol) for symbol in symbols]).astype(int)
31-
32-
cells = atoms.cell[:]
33-
coords = atoms.get_positions()
34-
try:
35-
energies = atoms.get_potential_energy(force_consistent=True)
36-
except PropertyNotImplementedError:
37-
energies = atoms.get_potential_energy()
38-
forces = atoms.get_forces()
39-
info_dict = {
40-
'atom_names': atom_names,
41-
'atom_numbs': atom_numbs,
42-
'atom_types': atom_types,
43-
'cells': np.array([cells]).astype('float32'),
44-
'coords': np.array([coords]).astype('float32'),
45-
'energies': np.array([energies]).astype('float32'),
46-
'forces': np.array([forces]).astype('float32'),
47-
'orig': [0,0,0],
48-
}
49-
try:
50-
stress = atoms.get_stress(False)
51-
virials = np.array([-atoms.get_volume() * stress]).astype('float32')
52-
info_dict['virials'] = virials
53-
except PropertyNotImplementedError:
54-
pass
55-
yield info_dict
115+
yield atoms
56116

57117
def to_system(self, data, **kwargs):
58118
'''

tests/test_to_ase.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
from context import dpdata
55
from comp_sys import CompSys, IsPBC
66
try:
7-
from ase import Atoms
8-
from ase.io import write
9-
exist_module=True
10-
except Exception:
11-
exist_module=False
7+
from ase import Atoms
8+
from ase.io import write
9+
except ModuleNotFoundError:
10+
exist_module=False
11+
else:
12+
exist_module=True
13+
1214

1315
@unittest.skipIf(not exist_module,"skip test_ase")
1416
class TestASE(unittest.TestCase, CompSys, IsPBC):
@@ -24,6 +26,24 @@ def setUp(self):
2426
self.f_places = 6
2527
self.v_places = 6
2628

29+
30+
@unittest.skipIf(not exist_module, "skip test_ase")
31+
class TestFromASE(unittest.TestCase, CompSys, IsPBC):
32+
"""Test ASEStructureFormat.from_system"""
33+
def setUp(self):
34+
system_1 = dpdata.System()
35+
system_1.from_lammps_lmp(os.path.join('poscars', 'conf.lmp'), type_map = ['O', 'H'])
36+
atoms = system_1.to_ase_structure()[0]
37+
self.system_1 = system_1
38+
self.system_2 = dpdata.System(atoms, fmt="ase/structure")
39+
# assign the same type_map
40+
self.system_2.sort_atom_names(type_map=self.system_1.get_atom_names())
41+
self.places = 6
42+
self.e_places = 6
43+
self.f_places = 6
44+
self.v_places = 6
45+
46+
2747
if __name__ == '__main__':
2848
unittest.main()
2949

0 commit comments

Comments
 (0)