Skip to content

Commit 94741be

Browse files
add psi4/inp format (#564)
Signed-off-by: Jinzhe Zeng <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 4536aa6 commit 94741be

File tree

4 files changed

+145
-1
lines changed

4 files changed

+145
-1
lines changed

dpdata/plugins/psi4.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22

33
from dpdata.format import Format
4+
from dpdata.psi4.input import write_psi4_input
45
from dpdata.psi4.output import read_psi4_output
56
from dpdata.unit import EnergyConversion, ForceConversion, LengthConversion
67

@@ -50,3 +51,53 @@ def from_labeled_system(self, file_name: str, **kwargs) -> dict:
5051
"orig": np.zeros(3),
5152
"nopbc": True,
5253
}
54+
55+
56+
@Format.register("psi4/inp")
57+
class PSI4InputFormat(Format):
58+
"""Psi4 input file."""
59+
60+
def to_system(
61+
self,
62+
data: dict,
63+
file_name: str,
64+
method: str,
65+
basis: str,
66+
charge: int = 0,
67+
multiplicity: int = 1,
68+
frame_idx=0,
69+
**kwargs,
70+
):
71+
"""Write PSI4 input.
72+
73+
Parameters
74+
----------
75+
data : dict
76+
system data
77+
file_name : str
78+
file name
79+
method : str
80+
computational method
81+
basis : str
82+
basis set; see https://psicode.org/psi4manual/master/basissets_tables.html
83+
charge : int, default=0
84+
charge of system
85+
multiplicity : int, default=1
86+
multiplicity of system
87+
frame_idx : int, default=0
88+
The index of the frame to dump
89+
**kwargs
90+
keyword arguments
91+
"""
92+
types = np.array(data["atom_names"])[data["atom_types"]]
93+
with open(file_name, "w") as fout:
94+
fout.write(
95+
write_psi4_input(
96+
types=types,
97+
coords=data["coords"][frame_idx],
98+
method=method,
99+
basis=basis,
100+
charge=charge,
101+
multiplicity=multiplicity,
102+
)
103+
)

dpdata/psi4/input.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import numpy as np
2+
3+
# Angston is used in Psi4 by default
4+
template = """molecule {{
5+
{atoms:s}
6+
{charge:d} {multiplicity:d}
7+
}}
8+
set basis {basis:s}
9+
set gradient_write on
10+
G, wfn = gradient("WB97M-D3BJ", return_wfn=True)
11+
wfn.energy()
12+
wfn.gradient().print_out()
13+
"""
14+
15+
16+
def write_psi4_input(
17+
types: np.ndarray,
18+
coords: np.ndarray,
19+
method: str,
20+
basis: str,
21+
charge: int = 0,
22+
multiplicity: int = 1,
23+
) -> str:
24+
"""Write Psi4 input file.
25+
26+
Parameters
27+
----------
28+
types : np.ndarray
29+
atomic symbols
30+
coords : np.ndarray
31+
atomic coordinates
32+
method : str
33+
computational method
34+
basis : str
35+
basis set; see https://psicode.org/psi4manual/master/basissets_tables.html
36+
charge : int, default=0
37+
charge of system
38+
multiplicity : int, default=1
39+
multiplicity of system
40+
41+
Returns
42+
-------
43+
str
44+
content of Psi4 input file
45+
"""
46+
return template.format(
47+
atoms="\n".join(
48+
[
49+
"{:s} {:16.9f} {:16.9f} {:16.9f}".format(*ii)
50+
for ii in zip(types, *coords.T)
51+
]
52+
),
53+
charge=charge,
54+
multiplicity=multiplicity,
55+
method=method,
56+
basis=basis,
57+
)

tests/comp_sys.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ def test_cell(self):
4747
def test_coord(self):
4848
self.assertEqual(self.system_1.get_nframes(), self.system_2.get_nframes())
4949
# think about direct coord
50+
if self.system_1.nopbc:
51+
# nopbc doesn't need to test cells
52+
return
5053
tmp_cell = self.system_1.data["cells"]
5154
tmp_cell = np.reshape(tmp_cell, [-1, 3])
5255
tmp_cell_norm = np.reshape(np.linalg.norm(tmp_cell, axis=1), [-1, 1, 3])

tests/test_psi4.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1+
import tempfile
2+
import textwrap
13
import unittest
24

35
import numpy as np
46
from comp_sys import CompLabeledSys, IsNoPBC
57
from context import dpdata
68

79

8-
class TestDeepmdLoadDumpHDF5(unittest.TestCase, CompLabeledSys, IsNoPBC):
10+
class TestPsi4Output(unittest.TestCase, CompLabeledSys, IsNoPBC):
911
def setUp(self):
1012
length_convert = dpdata.unit.LengthConversion("bohr", "angstrom").value()
1113
energy_convert = dpdata.unit.EnergyConversion("hartree", "eV").value()
@@ -60,3 +62,34 @@ def setUp(self):
6062
self.e_places = 6
6163
self.f_places = 6
6264
self.v_places = 6
65+
66+
67+
class TestPsi4Input(unittest.TestCase):
68+
def test_psi4_input(self):
69+
system = dpdata.LabeledSystem("psi4/psi4.out", fmt="psi4/out")
70+
with tempfile.NamedTemporaryFile("r") as f:
71+
system.to_psi4_inp(f.name, method="WB97M-D3BJ", basis="def2-TZVPPD")
72+
content = f.read()
73+
self.assertEqual(
74+
content,
75+
textwrap.dedent(
76+
"""\
77+
molecule {
78+
C 0.692724290 -0.280972290 0.149966626
79+
C -0.690715864 0.280527594 -0.157432416
80+
H 1.241584247 -0.707702380 -0.706342645
81+
H 0.492994289 -1.086482665 0.876517411
82+
H 1.330104482 0.478557878 0.633157279
83+
H -1.459385451 -0.498843014 -0.292862623
84+
H -0.623545813 0.873377524 -1.085142510
85+
H -1.005665735 0.946387574 0.663566976
86+
0 1
87+
}
88+
set basis def2-TZVPPD
89+
set gradient_write on
90+
G, wfn = gradient("WB97M-D3BJ", return_wfn=True)
91+
wfn.energy()
92+
wfn.gradient().print_out()
93+
"""
94+
),
95+
)

0 commit comments

Comments
 (0)