Skip to content

Commit 6e73a62

Browse files
authored
fix DP driver energy shape (#289)
The shape of energies should be (nframes,) instead of (nframes, 1)
1 parent ad73779 commit 6e73a62

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

dpdata/plugins/deepmd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def label(self, data: dict) -> dict:
163163
cell = None
164164
e, f, v = self.dp.eval(coord, cell, atype)
165165
data = ss.data
166-
data['energies'] = e.reshape((1, 1))
166+
data['energies'] = e.reshape((1,))
167167
data['forces'] = f.reshape((1, ss.get_natoms(), 3))
168168
data['virials'] = v.reshape((1, 3, 3))
169169
this_sys = dpdata.LabeledSystem.from_dict({'data': data})
@@ -178,7 +178,7 @@ def label(self, data: dict) -> dict:
178178
cell = None
179179
e, f, v = self.dp.eval(coord, cell, atype)
180180
data = ori_sys.data.copy()
181-
data['energies'] = e.reshape((ori_sys.get_nframes(), 1))
181+
data['energies'] = e.reshape((ori_sys.get_nframes(),))
182182
data['forces'] = f.reshape((ori_sys.get_nframes(), ori_sys.get_natoms(), 3))
183183
data['virials'] = v.reshape((ori_sys.get_nframes(), 3, 3))
184184
return data

0 commit comments

Comments
 (0)