Skip to content

Commit 26e4bcc

Browse files
authored
add sqm driver (#286)
1 parent 0a1c18b commit 26e4bcc

File tree

3 files changed

+103
-4
lines changed

3 files changed

+103
-4
lines changed

dpdata/plugins/amber.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1+
import tempfile
2+
import os
3+
import subprocess as sp
4+
15
import dpdata.amber.md
26
import dpdata.amber.sqm
37
from dpdata.format import Format
8+
from dpdata.driver import Driver
49

510

611
@Format.register("amber/md")
@@ -49,5 +54,70 @@ class SQMINFormat(Format):
4954
def to_system(self, data, fname=None, frame_idx=0, **kwargs):
5055
"""
5156
Generate input files for semi-emperical calculation in sqm software
57+
58+
Parameters
59+
----------
60+
data : dict
61+
system data
62+
fname : str
63+
output file name
64+
frame_idx : int, default=0
65+
index of frame to write
66+
67+
Other Parameters
68+
----------------
69+
**kwargs : dict
70+
valid parameters are:
71+
theory : str, default=dftb3
72+
level of theory. Options includes AM1, RM1, MNDO, PM3-PDDG, MNDO-PDDG,
73+
PM3-CARB1, MNDO/d, AM1/d, PM6, DFTB2, DFTB3
74+
charge : int, default=0
75+
total charge in electron units
76+
maxcyc : int, default=0
77+
maximum number of minimization cycles to allow. 0 represents a
78+
single-point calculation
79+
mult : int, default=1
80+
multiplicity. Only 1 is allowed.
5281
"""
5382
return dpdata.amber.sqm.make_sqm_in(data, fname, frame_idx, **kwargs)
83+
84+
85+
@Driver.register("sqm")
86+
class SQMDriver(Driver):
87+
"""AMBER sqm program driver.
88+
89+
Parameters
90+
----------
91+
sqm_exec : str, default=sqm
92+
path to sqm program
93+
**kwargs : dict
94+
other arguments to make input files. See :class:`SQMINFormat`
95+
96+
Examples
97+
--------
98+
Use DFTB3 method to calculate potential energy:
99+
>>> labeled_system = system.predict(theory="DFTB3", driver="sqm")
100+
>>> labeled_system['energies'][0]
101+
-15.41111246
102+
"""
103+
def __init__(self, sqm_exec: str="sqm", **kwargs: dict) -> None:
104+
self.sqm_exec = sqm_exec
105+
self.kwargs = kwargs
106+
107+
def label(self, data: dict) -> dict:
108+
ori_system = dpdata.System(data=data)
109+
labeled_system = dpdata.LabeledSystem()
110+
with tempfile.TemporaryDirectory() as d:
111+
for ii, ss in enumerate(ori_system):
112+
inp_fn = os.path.join(d, "%d.in" % ii)
113+
out_fn = os.path.join(d, "%d.out" % ii)
114+
ss.to("sqm/in", inp_fn, **self.kwargs)
115+
try:
116+
sp.check_output([*self.sqm_exec.split(), "-O", "-i", inp_fn, "-o", out_fn])
117+
except sp.CalledProcessError as e:
118+
with open(out_fn) as f:
119+
raise RuntimeError(
120+
"Run sqm failed! Output:\n" + f.read()
121+
) from e
122+
labeled_system.append(dpdata.LabeledSystem(out_fn, fmt="sqm/out"))
123+
return labeled_system.data

tests/comp_sys.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,8 @@ def test_atom_names(self):
1818
self.system_2.data['atom_names'])
1919

2020
def test_atom_types(self):
21-
self.assertEqual(self.system_1.data['atom_types'][0],
22-
self.system_2.data['atom_types'][0])
23-
self.assertEqual(self.system_1.data['atom_types'][1],
24-
self.system_2.data['atom_types'][1])
21+
np.testing.assert_array_equal(self.system_1.data['atom_types'],
22+
self.system_2.data['atom_types'])
2523

2624
def test_orig(self):
2725
for d0 in range(3) :

tests/test_sqm_driver.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import unittest
2+
import shutil
3+
4+
import numpy as np
5+
from context import dpdata
6+
from comp_sys import CompSys, IsNoPBC
7+
8+
9+
@unittest.skipIf(shutil.which("sqm") is None, "sqm is not installed")
10+
class TestSQMdriver(unittest.TestCase, CompSys, IsNoPBC):
11+
"""Test sqm with a hydrogen ion."""
12+
@classmethod
13+
def setUpClass(cls):
14+
cls.system_1 = dpdata.System(data={
15+
"atom_names": ["H"],
16+
"atom_numbs": [1],
17+
"atom_types": np.zeros((1,), dtype=int),
18+
"coords": np.zeros((1, 1, 3), dtype=np.float32),
19+
"cells": np.zeros((1, 3, 3), dtype=np.float32),
20+
"orig": np.zeros(3, dtype=np.float32),
21+
"nopbc": True,
22+
})
23+
cls.system_2 = cls.system_1.predict(theory="DFTB3", charge=1, driver="sqm")
24+
cls.places = 6
25+
26+
def test_energy(self):
27+
self.assertAlmostEqual(self.system_2['energies'].ravel()[0], 6.549447)
28+
29+
def test_forces(self):
30+
forces = self.system_2['forces']
31+
np.testing.assert_allclose(forces, np.zeros_like(forces))

0 commit comments

Comments
 (0)