Skip to content

Commit 3e96c2c

Browse files
authored
fix support for deepmd-kit 2.0.2 (#207)
fix #206.
1 parent b58aa63 commit 3e96c2c

File tree

1 file changed

+25
-11
lines changed

1 file changed

+25
-11
lines changed

dpdata/system.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -661,7 +661,7 @@ def predict(self, dp):
661661
662662
Returns
663663
-------
664-
labeled_sys LabeledSystem
664+
labeled_sys : LabeledSystem
665665
The labeled system.
666666
"""
667667
try:
@@ -679,19 +679,33 @@ def predict(self, dp):
679679

680680
labeled_sys = LabeledSystem()
681681

682-
for ss in self:
683-
coord = ss['coords'].reshape((-1,1))
684-
if not ss.nopbc:
685-
cell = ss['cells'].reshape((-1,1))
682+
if 'auto_batch_size' not in DeepPot.__init__.__code__.co_varnames:
683+
for ss in self:
684+
coord = ss['coords'].reshape((1, ss.get_natoms()*3))
685+
if not ss.nopbc:
686+
cell = ss['cells'].reshape((1, 9))
687+
else:
688+
cell = None
689+
e, f, v = dp.eval(coord, cell, atype)
690+
data = ss.data
691+
data['energies'] = e.reshape((1, 1))
692+
data['forces'] = f.reshape((1, ss.get_natoms(), 3))
693+
data['virials'] = v.reshape((1, 3, 3))
694+
this_sys = LabeledSystem.from_dict({'data': data})
695+
labeled_sys.append(this_sys)
696+
else:
697+
# since v2.0.2, auto batch size is supported
698+
coord = self.data['coords'].reshape((self.get_nframes(), self.get_natoms()*3))
699+
if not self.nopbc:
700+
cell = self.data['cells'].reshape((self.get_nframes(), 9))
686701
else:
687702
cell = None
688703
e, f, v = dp.eval(coord, cell, atype)
689-
data = ss.data
690-
data['energies'] = e.reshape((1, 1))
691-
data['forces'] = f.reshape((1, -1, 3))
692-
data['virials'] = v.reshape((1, 3, 3))
693-
this_sys = LabeledSystem.from_dict({'data': data})
694-
labeled_sys.append(this_sys)
704+
data = self.data.copy()
705+
data['energies'] = e.reshape((self.get_nframes(), 1))
706+
data['forces'] = f.reshape((self.get_nframes(), self.get_natoms(), 3))
707+
data['virials'] = v.reshape((self.get_nframes(), 3, 3))
708+
labeled_sys = LabeledSystem.from_dict({'data': data})
695709
return labeled_sys
696710

697711
def pick_atom_idx(self, idx, nopbc=None):

0 commit comments

Comments
 (0)