Skip to content

Commit 9a03f77

Browse files
njzjzrobinzyb
andauthored
fix virial in HybridDriver (#604)
Based on #603 --------- Co-authored-by: robinzyb <[email protected]>
1 parent e43f00e commit 9a03f77

File tree

3 files changed

+5
-3
lines changed

3 files changed

+5
-3
lines changed

dpdata/driver.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,8 @@ def label(self, data: dict) -> dict:
163163
else:
164164
labeled_data["energies"] += lb_data["energies"]
165165
labeled_data["forces"] += lb_data["forces"]
166+
if "virials" in labeled_data and "virials" in lb_data:
167+
labeled_data["virials"] += lb_data["virials"]
166168
return labeled_data
167169

168170

tests/comp_sys.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ def test_virial(self):
9292
# if len(self.system_1['virials']) == 0:
9393
# self.assertEqual(len(self.system_1['virials']), 0)
9494
# return
95-
if "virials" not in self.system_1:
96-
self.assertFalse("virials" in self.system_2)
95+
if not self.system_1.has_virial():
96+
self.assertFalse(self.system_2.has_virial())
9797
return
9898
np.testing.assert_almost_equal(
9999
self.system_1["virials"],

tests/test_predict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def setUp(self):
7272
self.system_2 = dpdata.LabeledSystem(
7373
"poscars/deepmd.h2o.md", fmt="deepmd/raw", type_map=["O", "H"]
7474
)
75-
for pp in ("energies", "forces"):
75+
for pp in ("energies", "forces", "virials"):
7676
self.system_2.data[pp][:] = 3.0
7777

7878
self.places = 6

0 commit comments

Comments
 (0)