Skip to content

Commit 442a918

Browse files
committed
fix sub_system shape
1 parent 8ef6758 commit 442a918

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

dpdata/system.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -282,8 +282,8 @@ def sub_system(self, f_idx) :
282282
tmp = System()
283283
for ii in ['atom_numbs', 'atom_names', 'atom_types', 'orig'] :
284284
tmp.data[ii] = self.data[ii]
285-
tmp.data['cells'] = self.data['cells'][f_idx]
286-
tmp.data['coords'] = self.data['coords'][f_idx]
285+
tmp.data['cells'] = self.data['cells'][f_idx].reshape(-1, 3, 3)
286+
tmp.data['coords'] = self.data['coords'][f_idx].reshape(-1, self.data['coords'].shape[1], 3)
287287
return tmp
288288

289289

@@ -863,10 +863,10 @@ def sub_system(self, f_idx) :
863863
"""
864864
tmp_sys = LabeledSystem()
865865
tmp_sys.data = System.sub_system(self, f_idx).data
866-
tmp_sys.data['energies'] = self.data['energies'][f_idx]
867-
tmp_sys.data['forces'] = self.data['forces'][f_idx]
866+
tmp_sys.data['energies'] = np.atleast_1d(self.data['energies'][f_idx])
867+
tmp_sys.data['forces'] = self.data['forces'][f_idx].reshape(-1, self.data['forces'].shape[1], 3)
868868
if 'virials' in self.data:
869-
tmp_sys.data['virials'] = self.data['virials'][f_idx]
869+
tmp_sys.data['virials'] = self.data['virials'][f_idx].reshape(-1, 3, 3)
870870
return tmp_sys
871871

872872

0 commit comments

Comments
 (0)