Skip to content

Commit 0447615

Browse files
authored
Merge pull request #338 from hsulab/fixAseCalculatorStress
fix an error in stress by ase interface
2 parents 1a6cd1c + a24971f commit 0447615

File tree

1 file changed

+17
-5
lines changed

1 file changed

+17
-5
lines changed

source/train/calculator.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,15 @@
2424
```
2525
"""
2626

27-
from ase.calculators.calculator import Calculator, all_changes
27+
from ase.calculators.calculator import (
28+
Calculator, all_changes, PropertyNotImplementedError
29+
)
2830
import deepmd.DeepPot as DeepPot
2931

3032

3133
class DP(Calculator):
3234
name = "DP"
33-
implemented_properties = ["energy", "forces", "stress"]
35+
implemented_properties = ["energy", "forces", "virial", "stress"]
3436

3537
def __init__(self, model, label="DP", type_dict=None, **kwargs):
3638
Calculator.__init__(self, label=label, **kwargs)
@@ -40,7 +42,7 @@ def __init__(self, model, label="DP", type_dict=None, **kwargs):
4042
else:
4143
self.type_dict = dict(zip(self.dp.get_type_map(), range(self.dp.get_ntypes())))
4244

43-
def calculate(self, atoms=None, properties=["energy", "forces", "stress"], system_changes=all_changes):
45+
def calculate(self, atoms=None, properties=["energy", "forces", "virial"], system_changes=all_changes):
4446
coord = atoms.get_positions().reshape([1, -1])
4547
if sum(atoms.get_pbc())>0:
4648
cell = atoms.get_cell().reshape([1, -1])
@@ -49,7 +51,17 @@ def calculate(self, atoms=None, properties=["energy", "forces", "stress"], syste
4951
symbols = atoms.get_chemical_symbols()
5052
atype = [self.type_dict[k] for k in symbols]
5153
e, f, v = self.dp.eval(coords=coord, cells=cell, atom_types=atype)
52-
self.results['energy'] = e[0]
54+
self.results['energy'] = e[0][0]
5355
self.results['forces'] = f[0]
54-
self.results['stress'] = v[0]
56+
self.results['virial'] = v[0].reshape(3,3)
5557

58+
# convert virial into stress for lattice relaxation
59+
if "stress" in properties:
60+
if sum(atoms.get_pbc()) > 0:
61+
# the usual convention (tensile stress is positive)
62+
# stress = -virial / volume
63+
stress = -0.5*(v[0].copy()+v[0].copy().T) / atoms.get_volume()
64+
# Voigt notation
65+
self.results['stress'] = stress.flat[[0,4,8,5,2,1]]
66+
else:
67+
raise PropertyNotImplementedError

0 commit comments

Comments
 (0)