Skip to content

Commit 4008687

Browse files
authored
add ASE's traj support (#614)
1 parent 46a8952 commit 4008687

File tree

10 files changed

+255
-2
lines changed

10 files changed

+255
-2
lines changed

dpdata/plugins/ase.py

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
try:
1010
import ase.io
1111
from ase.calculators.calculator import PropertyNotImplementedError
12+
from ase.io import Trajectory
1213

1314
if TYPE_CHECKING:
1415
from ase.optimize.optimize import Optimizer
@@ -43,7 +44,7 @@ def from_system(self, atoms: "ase.Atoms", **kwargs) -> dict:
4344
data dict
4445
"""
4546
symbols = atoms.get_chemical_symbols()
46-
atom_names = list(set(symbols))
47+
atom_names = list(dict.fromkeys(symbols))
4748
atom_numbs = [symbols.count(symbol) for symbol in atom_names]
4849
atom_types = np.array([atom_names.index(symbol) for symbol in symbols]).astype(
4950
int
@@ -187,6 +188,115 @@ def to_labeled_system(self, data, *args, **kwargs):
187188
return structures
188189

189190

191+
@Format.register("ase/traj")
192+
class ASETrajFormat(Format):
193+
"""Format for the ASE's trajectory format <https://wiki.fysik.dtu.dk/ase/ase/io/trajectory.html#module-ase.io.trajectory>`_ (ase).'
194+
a `traj' contains a sequence of frames, each of which is an `Atoms' object.
195+
"""
196+
197+
def from_system(
198+
self,
199+
file_name: str,
200+
begin: Optional[int] = 0,
201+
end: Optional[int] = None,
202+
step: Optional[int] = 1,
203+
**kwargs,
204+
) -> dict:
205+
"""Read ASE's trajectory file to `System` of multiple frames.
206+
207+
Parameters
208+
----------
209+
file_name : str
210+
ASE's trajectory file
211+
begin : int, optional
212+
begin frame index
213+
end : int, optional
214+
end frame index
215+
step : int, optional
216+
frame index step
217+
**kwargs : dict
218+
other parameters
219+
220+
Returns
221+
-------
222+
dict_frames: dict
223+
a dictionary containing data of multiple frames
224+
"""
225+
traj = Trajectory(file_name)
226+
sub_traj = traj[begin:end:step]
227+
dict_frames = ASEStructureFormat().from_system(sub_traj[0])
228+
for atoms in sub_traj[1:]:
229+
tmp = ASEStructureFormat().from_system(atoms)
230+
dict_frames["cells"] = np.append(dict_frames["cells"], tmp["cells"][0])
231+
dict_frames["coords"] = np.append(dict_frames["coords"], tmp["coords"][0])
232+
233+
## Correct the shape of numpy arrays
234+
dict_frames["cells"] = dict_frames["cells"].reshape(-1, 3, 3)
235+
dict_frames["coords"] = dict_frames["coords"].reshape(len(sub_traj), -1, 3)
236+
237+
return dict_frames
238+
239+
def from_labeled_system(
240+
self,
241+
file_name: str,
242+
begin: Optional[int] = 0,
243+
end: Optional[int] = None,
244+
step: Optional[int] = 1,
245+
**kwargs,
246+
) -> dict:
247+
"""Read ASE's trajectory file to `System` of multiple frames.
248+
249+
Parameters
250+
----------
251+
file_name : str
252+
ASE's trajectory file
253+
begin : int, optional
254+
begin frame index
255+
end : int, optional
256+
end frame index
257+
step : int, optional
258+
frame index step
259+
**kwargs : dict
260+
other parameters
261+
262+
Returns
263+
-------
264+
dict_frames: dict
265+
a dictionary containing data of multiple frames
266+
"""
267+
traj = Trajectory(file_name)
268+
sub_traj = traj[begin:end:step]
269+
270+
## check if the first frame has a calculator
271+
if sub_traj[0].calc is None:
272+
raise ValueError(
273+
"The input trajectory does not contain energies and forces, may not be a labeled system."
274+
)
275+
276+
dict_frames = ASEStructureFormat().from_labeled_system(sub_traj[0])
277+
for atoms in sub_traj[1:]:
278+
tmp = ASEStructureFormat().from_labeled_system(atoms)
279+
dict_frames["cells"] = np.append(dict_frames["cells"], tmp["cells"][0])
280+
dict_frames["coords"] = np.append(dict_frames["coords"], tmp["coords"][0])
281+
dict_frames["energies"] = np.append(
282+
dict_frames["energies"], tmp["energies"][0]
283+
)
284+
dict_frames["forces"] = np.append(dict_frames["forces"], tmp["forces"][0])
285+
if "virials" in tmp.keys() and "virials" in dict_frames.keys():
286+
dict_frames["virials"] = np.append(
287+
dict_frames["virials"], tmp["virials"][0]
288+
)
289+
290+
## Correct the shape of numpy arrays
291+
dict_frames["cells"] = dict_frames["cells"].reshape(-1, 3, 3)
292+
dict_frames["coords"] = dict_frames["coords"].reshape(len(sub_traj), -1, 3)
293+
dict_frames["forces"] = dict_frames["forces"].reshape(len(sub_traj), -1, 3)
294+
if "virials" in dict_frames.keys():
295+
dict_frames["virials"] = dict_frames["virials"].reshape(-1, 3, 3)
296+
297+
return dict_frames
298+
299+
190300
@Driver.register("ase")
191301
class ASEDriver(Driver):
192302
"""ASE Driver.

tests/ase_traj/MoS2.traj

21 KB
Binary file not shown.

tests/ase_traj/MoS2/box.raw

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
1.571723287959847148e+01 1.382976769974756167e-14 -4.335347714314398588e-27 -2.135776446678914908e+01 3.745820611488967700e+01 -5.045434369425391968e-16 2.142284677455918376e-24 8.545634906358754423e-15 2.312999999999935952e+01
2+
1.571734059897834079e+01 1.382976314475292595e-14 -4.334590580314998760e-27 -2.135791084402753981e+01 3.745820611488959173e+01 -5.045434369425390982e-16 2.155934654129174673e-24 8.545634906358737068e-15 2.312999999999930978e+01
3+
1.571733836538147244e+01 1.382976369106826543e-14 1.070818208280100644e-26 -2.135790780884721940e+01 3.745820611488941410e+01 -2.741620021286372525e-14 -1.227011684011214072e-23 -4.602381967266446384e-15 2.312999999999919609e+01

0 commit comments

Comments
 (0)