Skip to content

Commit e2f235f

Browse files
add support for n2p2 data format (#627)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 18a0ed5 commit e2f235f

File tree

4 files changed

+288
-0
lines changed

4 files changed

+288
-0
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ The `System` or `LabeledSystem` can be constructed from the following file forma
9696
| ABACUS | STRU | True | True | LabeledSystem | 'abacus/relax' |
9797
| ase | structure | True | True | MultiSystems | 'ase/structure' |
9898
| DFTB+ | dftbplus | False | True | LabeledSystem | 'dftbplus' |
99+
| n2p2 | n2p2 | True | True | LabeledSystem | 'n2p2' |
99100

100101

101102
The Class `dpdata.MultiSystems` can read data from a dir which may contains many files of different systems, or from single xyz file which contains different systems.

dpdata/plugins/n2p2.py

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
import numpy as np
2+
3+
from dpdata.format import Format
4+
5+
from ..unit import EnergyConversion, ForceConversion, LengthConversion
6+
7+
length_convert = LengthConversion("bohr", "angstrom").value()
8+
energy_convert = EnergyConversion("hartree", "eV").value()
9+
force_convert = ForceConversion("hartree/bohr", "eV/angstrom").value()
10+
11+
12+
def match_indices(atype1, atype2):
13+
# Ensure atype2 is a numpy array for efficient operations
14+
atype2 = np.array(atype2)
15+
# Placeholder for matched indices
16+
matched_indices = []
17+
# Track used indices to handle duplicates
18+
used_indices = set()
19+
20+
# Iterate over each element in atype1
21+
for element in atype1:
22+
# Find all indices of the current element in atype2
23+
# np.where returns a tuple, so [0] is used to access the array of indices
24+
indices = np.where(atype2 == element)[0]
25+
26+
# Find the first unused index
27+
for index in indices:
28+
if index not in used_indices:
29+
# Add the index to the results and mark it as used
30+
matched_indices.append(index)
31+
used_indices.add(index)
32+
break # Move to the next element in atype1
33+
34+
return matched_indices
35+
36+
37+
@Format.register("n2p2")
38+
class N2P2Format(Format):
39+
"""n2p2.
40+
41+
This class support the conversion from and to the training data of n2p2 format.
42+
For more information about the n2p2 format, please refer to https://compphysvienna.github.io/n2p2/topics/cfg_file.html
43+
"""
44+
45+
def from_labeled_system(self, file_name, **kwargs):
46+
"""Read from n2p2 format.
47+
48+
Parameters
49+
----------
50+
file_name : str
51+
file name, i.e. the first argument
52+
**kwargs : dict
53+
keyword arguments that will be passed from the method
54+
55+
Returns
56+
-------
57+
data : dict
58+
system data, whose keys are defined in LabeledSystem.DTYPES
59+
"""
60+
cells = []
61+
coords = []
62+
atypes = []
63+
forces = []
64+
energies = []
65+
natom0 = None
66+
natoms0 = None
67+
atom_types0 = None
68+
with open(file_name) as file:
69+
for line in file:
70+
line = line.strip() # Remove leading/trailing whitespace
71+
if line.lower() == "begin":
72+
current_section = [] # Start a new section
73+
cell = []
74+
coord = []
75+
atype = []
76+
force = []
77+
energy = None
78+
elif line.lower() == "end":
79+
# If we are at the end of a section, process the section
80+
assert (
81+
len(coord) == len(atype) == len(force)
82+
), "Number of atoms, atom types, and forces must match."
83+
84+
# Check if the number of atoms is consistent across all frames
85+
natom = len(coord)
86+
if natom0 is None:
87+
natom0 = natom
88+
else:
89+
assert (
90+
natom == natom0
91+
), "The number of atoms in all frames must be the same."
92+
93+
# Check if the number of atoms of each type is consistent across all frames
94+
atype = np.array(atype)
95+
unique_dict = {element: None for element in atype}
96+
unique_atypes = np.array(list(unique_dict.keys()))
97+
unique_atypes_list = list(unique_atypes)
98+
ntypes = len(unique_atypes)
99+
natoms = [len(atype[atype == at]) for at in unique_atypes]
100+
if natoms0 is None:
101+
natoms0 = natoms
102+
else:
103+
assert (
104+
natoms == natoms0
105+
), "The number of atoms of each type in all frames must be the same."
106+
if atom_types0 is None:
107+
atom_types0 = atype
108+
atom_order = match_indices(atom_types0, atype)
109+
110+
cell = np.array(cell, dtype=float)
111+
coord = np.array(coord, dtype=float)[atom_order]
112+
force = np.array(force, dtype=float)[atom_order]
113+
114+
cells.append(cell)
115+
coords.append(coord)
116+
forces.append(force)
117+
energies.append(float(energy))
118+
119+
current_section = None # Reset for the next section
120+
elif current_section is not None:
121+
# If we are inside a section, append the line to the current section
122+
# current_section.append(line)
123+
line_contents = line.split()
124+
if line_contents[0] == "lattice":
125+
cell.append(line_contents[1:])
126+
elif line_contents[0] == "atom":
127+
coord.append(line_contents[1:4])
128+
atype.append(line_contents[4])
129+
force.append(line_contents[7:10])
130+
elif line_contents[0] == "energy":
131+
energy = line_contents[1]
132+
133+
atom_names = unique_atypes_list
134+
atom_numbs = natoms
135+
atom_types = np.zeros(len(atom_types0), dtype=int)
136+
for i in range(ntypes):
137+
atom_types[atom_types0 == unique_atypes_list[i]] = i
138+
139+
cells = np.array(cells) * length_convert
140+
coords = np.array(coords) * length_convert
141+
forces = np.array(forces) * force_convert
142+
energies = np.array(energies) * energy_convert
143+
144+
return {
145+
"atom_names": list(atom_names),
146+
"atom_numbs": list(atom_numbs),
147+
"atom_types": atom_types,
148+
"coords": coords,
149+
"cells": cells,
150+
"nopbc": False,
151+
"orig": np.zeros(3),
152+
"energies": energies,
153+
"forces": forces,
154+
}
155+
156+
def to_labeled_system(self, data, file_name, **kwargs):
157+
"""Write n2p2 format.
158+
159+
By default, LabeledSystem.to will fallback to System.to.
160+
161+
Parameters
162+
----------
163+
data : dict
164+
system data, whose keys are defined in LabeledSystem.DTYPES
165+
file_name : str
166+
file name, where the data will be written
167+
*args : list
168+
arguments that will be passed from the method
169+
**kwargs : dict
170+
keyword arguments that will be passed from the method
171+
"""
172+
buff = []
173+
nframe = len(data["energies"])
174+
natom = len(data["atom_types"])
175+
atom_names = data["atom_names"]
176+
for frame in range(nframe):
177+
coord = data["coords"][frame] / length_convert
178+
force = data["forces"][frame] / force_convert
179+
energy = data["energies"][frame] / energy_convert
180+
cell = data["cells"][frame] / length_convert
181+
atype = data["atom_types"]
182+
buff.append("begin")
183+
for i in range(3):
184+
buff.append(
185+
f"lattice {cell[i][0]:15.6f} {cell[i][1]:15.6f} {cell[i][2]:15.6f}"
186+
)
187+
for i in range(natom):
188+
buff.append(
189+
f"atom {coord[i][0]:15.6f} {coord[i][1]:15.6f} {coord[i][2]:15.6f} {atom_names[atype[i]]:>7} {0:15.6f} {0:15.6f} {force[i][0]:15.6e} {force[i][1]:15.6e} {force[i][2]:15.6e}"
190+
)
191+
buff.append(f"energy {energy:15.6f}")
192+
buff.append(f"charge {0:15.6f}")
193+
buff.append("end")
194+
with open(file_name, "w") as fp:
195+
fp.write("\n".join(buff))

tests/n2p2/input.data

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
begin
2+
lattice 18.897261 0.000000 0.000000
3+
lattice 0.000000 18.897261 0.000000
4+
lattice 0.000000 0.000000 18.897261
5+
atom 1.889726 0.000000 0.000000 O 0.000000 0.000000 9.723452e-03 0.000000e+00 0.000000e+00
6+
atom 0.000000 0.000000 2.834589 H 0.000000 0.000000 0.000000e+00 0.000000e+00 1.458518e-02
7+
atom 1.889726 0.000000 5.669178 H 0.000000 0.000000 9.723452e-03 0.000000e+00 2.917036e-02
8+
energy 0.044099
9+
charge 0.000000
10+
end
11+
begin
12+
lattice 18.897261 0.000000 0.000000
13+
lattice 0.000000 18.897261 0.000000
14+
lattice 0.000000 0.000000 18.897261
15+
atom 3.779452 1.889726 1.889726 O 0.000000 0.000000 4.861726e-02 3.889381e-02 3.889381e-02
16+
atom 1.889726 1.889726 4.724315 H 0.000000 0.000000 3.889381e-02 3.889381e-02 5.347899e-02
17+
atom 3.779452 1.889726 7.558904 H 0.000000 0.000000 4.861726e-02 3.889381e-02 6.806416e-02
18+
energy 0.084523
19+
charge 0.000000
20+
end

tests/test_n2p2.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import os
2+
import unittest
3+
4+
import numpy as np
5+
from context import dpdata
6+
7+
from dpdata.unit import EnergyConversion, ForceConversion, LengthConversion
8+
9+
length_convert = LengthConversion("bohr", "angstrom").value()
10+
energy_convert = EnergyConversion("hartree", "eV").value()
11+
force_convert = ForceConversion("hartree/bohr", "eV/angstrom").value()
12+
13+
14+
class TestN2P2(unittest.TestCase):
15+
def setUp(self):
16+
self.data_ref = {
17+
"atom_numbs": [1, 2],
18+
"atom_names": ["O", "H"],
19+
"atom_types": np.array([0, 1, 1]),
20+
"orig": np.array([0.0, 0.0, 0.0]),
21+
"cells": np.array(
22+
[
23+
[[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]],
24+
[[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]],
25+
]
26+
),
27+
"coords": np.array(
28+
[
29+
[[1.0, 0.0, 0.0], [0.0, 0.0, 1.5], [1.0, 0.0, 3.0]],
30+
[[2.0, 1.0, 1.0], [1.0, 1.0, 2.5], [2.0, 1.0, 4.0]],
31+
]
32+
),
33+
"energies": np.array([1.2, 2.3]),
34+
"forces": np.array(
35+
[
36+
[[0.5, 0.0, 0.0], [0.0, 0.0, 0.75], [0.5, 0.0, 1.5]],
37+
[[2.5, 2.0, 2.0], [2.0, 2.0, 2.75], [2.5, 2.0, 3.5]],
38+
]
39+
),
40+
}
41+
42+
def test_n2p2_from_labeled_system(self):
43+
data = dpdata.LabeledSystem("n2p2/input.data", fmt="n2p2")
44+
for key in self.data_ref:
45+
if key == "atom_numbs":
46+
self.assertEqual(data[key], self.data_ref[key])
47+
elif key == "atom_names":
48+
self.assertEqual(data[key], self.data_ref[key])
49+
elif key == "atom_types":
50+
np.testing.assert_array_equal(data[key], self.data_ref[key])
51+
else:
52+
np.testing.assert_array_almost_equal(
53+
data[key], self.data_ref[key], decimal=5
54+
)
55+
56+
def test_n2p2_to_labeled_system(self):
57+
output_file = "n2p2/output.data"
58+
data = dpdata.LabeledSystem.from_dict({"data": self.data_ref})
59+
data.to_n2p2(output_file)
60+
ref_file = "n2p2/input.data"
61+
with open(ref_file) as file1, open(output_file) as file2:
62+
file1_lines = file1.readlines()
63+
file2_lines = file2.readlines()
64+
65+
file1_lines = [line.strip("\n") for line in file1_lines]
66+
file2_lines = [line.strip("\n") for line in file2_lines]
67+
68+
self.assertListEqual(file1_lines, file2_lines)
69+
70+
def tearDown(self):
71+
if os.path.isfile("n2p2/output.data"):
72+
os.remove("n2p2/output.data")

0 commit comments

Comments
 (0)