Skip to content

Commit 199afc1

Browse files
thangcktpre-commit-ci[bot]njzjz
authored
improve ASE traj (#633)
add functions to convert from others formats to ASE traj format <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Enhanced system and labeled system handling with new parameters and functionalities. - Improved stress calculations by modifying method parameters. - **Bug Fixes** - Corrected return types for several methods to ensure consistency and reliability. - **Tests** - Added new test cases to validate system setups and trajectory file operations. - **Chores** - Updated Python version matrix in workflow to include only version 3.11. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <[email protected]> Signed-off-by: C. Thang Nguyen <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jinzhe Zeng <[email protected]>
1 parent c5b36bb commit 199afc1

File tree

2 files changed

+70
-5
lines changed

2 files changed

+70
-5
lines changed

dpdata/plugins/ase.py

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
3+
import os
4+
from typing import TYPE_CHECKING, Generator
45

56
import numpy as np
67

@@ -94,7 +95,7 @@ def from_labeled_system(self, atoms: ase.Atoms, **kwargs) -> dict:
9495
"forces": np.array([forces]),
9596
}
9697
try:
97-
stress = atoms.get_stress(False)
98+
stress = atoms.get_stress(voigt=False)
9899
except PropertyNotImplementedError:
99100
pass
100101
else:
@@ -110,7 +111,7 @@ def from_multi_systems(
110111
step: int | None = None,
111112
ase_fmt: str | None = None,
112113
**kwargs,
113-
) -> ase.Atoms:
114+
) -> Generator[ase.Atoms, None, None]:
114115
"""Convert a ASE supported file to ASE Atoms.
115116
116117
It will finally be converted to MultiSystems.
@@ -140,7 +141,7 @@ def from_multi_systems(
140141
frames = ase.io.read(file_name, format=ase_fmt, index=slice(begin, end, step))
141142
yield from frames
142143

143-
def to_system(self, data, **kwargs):
144+
def to_system(self, data, **kwargs) -> list[ase.Atoms]:
144145
"""Convert System to ASE Atom obj."""
145146
from ase import Atoms
146147

@@ -158,7 +159,7 @@ def to_system(self, data, **kwargs):
158159

159160
return structures
160161

161-
def to_labeled_system(self, data, *args, **kwargs):
162+
def to_labeled_system(self, data, *args, **kwargs) -> list[ase.Atoms]:
162163
"""Convert System to ASE Atoms object."""
163164
from ase import Atoms
164165
from ase.calculators.singlepoint import SinglePointCalculator
@@ -300,6 +301,46 @@ def from_labeled_system(
300301

301302
return dict_frames
302303

304+
def to_system(self, data, file_name: str = "confs.traj", **kwargs) -> None:
305+
"""Convert System to ASE Atoms object.
306+
307+
Parameters
308+
----------
309+
file_name : str
310+
path to file
311+
"""
312+
from ase.io import Trajectory
313+
314+
if os.path.isfile(file_name):
315+
os.remove(file_name)
316+
317+
list_atoms = ASEStructureFormat().to_system(data, **kwargs)
318+
traj = Trajectory(file_name, "a")
319+
_ = [traj.write(atom) for atom in list_atoms]
320+
traj.close()
321+
return
322+
323+
def to_labeled_system(
324+
self, data, file_name: str = "labeled_confs.traj", *args, **kwargs
325+
) -> None:
326+
"""Convert System to ASE Atoms object.
327+
328+
Parameters
329+
----------
330+
file_name : str
331+
path to file
332+
"""
333+
from ase.io import Trajectory
334+
335+
if os.path.isfile(file_name):
336+
os.remove(file_name)
337+
338+
list_atoms = ASEStructureFormat().to_labeled_system(data, *args, **kwargs)
339+
traj = Trajectory(file_name, "a")
340+
_ = [traj.write(atom) for atom in list_atoms]
341+
traj.close()
342+
return
343+
303344

304345
@Driver.register("ase")
305346
class ASEDriver(Driver):

tests/test_ase_traj.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,5 +67,29 @@ def setUp(self):
6767
self.v_places = 4
6868

6969

70+
@unittest.skipIf(skip_ase, "skip ase related test. install ase to fix")
71+
class TestASEtraj4(unittest.TestCase, CompSys, IsPBC):
72+
def setUp(self):
73+
self.system_1 = dpdata.System("ase_traj/MoS2", fmt="deepmd")
74+
self.system_1.to(file_name="ase_traj/tmp.traj", fmt="ase/traj")
75+
self.system_2 = dpdata.System("ase_traj/tmp.traj", fmt="ase/traj")
76+
self.places = 6
77+
self.e_places = 6
78+
self.f_places = 6
79+
self.v_places = 4
80+
81+
82+
@unittest.skipIf(skip_ase, "skip ase related test. install ase to fix")
83+
class TestASEtraj4Labeled(unittest.TestCase, CompLabeledSys, IsPBC):
84+
def setUp(self):
85+
self.system_1 = dpdata.LabeledSystem("ase_traj/MoS2", fmt="deepmd")
86+
self.system_1.to(file_name="ase_traj/tmp1.traj", fmt="ase/traj")
87+
self.system_2 = dpdata.LabeledSystem("ase_traj/tmp1.traj", fmt="ase/traj")
88+
self.places = 6
89+
self.e_places = 6
90+
self.f_places = 6
91+
self.v_places = 4
92+
93+
7094
if __name__ == "__main__":
7195
unittest.main()

0 commit comments

Comments
 (0)