Skip to content

Commit 42be217

Browse files
authored
Merge pull request #85 from njzjz/predict
add predict method
2 parents 4cc5fc4 + 331062c commit 42be217

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

dpdata/system.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -839,6 +839,44 @@ def shuffle(self):
839839
self.data[ii] = self.data[ii][idx]
840840
return idx
841841

842+
def predict(self, dp):
843+
"""
844+
Predict energies and forces by deepmd-kit.
845+
846+
Parameters
847+
----------
848+
dp : deepmd.DeepPot or str
849+
The deepmd-kit potential class or the filename of the model.
850+
851+
Returns
852+
-------
853+
labeled_sys LabeledSystem
854+
The labeled system.
855+
"""
856+
import deepmd.DeepPot as DeepPot
857+
if not isinstance(dp, DeepPot):
858+
dp = DeepPot(dp)
859+
type_map = dp.get_type_map()
860+
ori_sys = self.copy()
861+
ori_sys.sort_atom_names(type_map=type_map)
862+
atype = ori_sys['atom_types']
863+
864+
labeled_sys = LabeledSystem()
865+
866+
for ss in self:
867+
coord = ss['coords'].reshape((-1,1))
868+
if not ss.nopbc:
869+
cell = ss['cells'].reshape((-1,1))
870+
else:
871+
cell = None
872+
e, f, v = dp.eval(coord, cell, atype)
873+
data = ss.data
874+
data['energies'] = e.reshape((1, 1))
875+
data['forces'] = f.reshape((1, -1, 3))
876+
data['virials'] = v.reshape((1, 3, 3))
877+
this_sys = LabeledSystem.from_dict({'data': data})
878+
labeled_sys.append(this_sys)
879+
return labeled_sys
842880

843881
def get_cell_perturb_matrix(cell_pert_fraction):
844882
if cell_pert_fraction<0:

0 commit comments

Comments
 (0)