Skip to content

Commit 0b93d91

Browse files
committed
add shuffle method to shuffle data
1 parent efdbb94 commit 0b93d91

File tree

3 files changed

+38
-0
lines changed

3 files changed

+38
-0
lines changed

dpdata/system.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -773,6 +773,14 @@ def nopbc(self):
773773
return True
774774
return False
775775

776+
def shuffle(self):
777+
"""Shuffle frames randomly."""
778+
idx = np.random.permutation(self.get_nframes())
779+
for ii in ['cells', 'coords']:
780+
self.data[ii] = self.data[ii][idx]
781+
return idx
782+
783+
776784
def get_cell_perturb_matrix(cell_pert_fraction):
777785
if cell_pert_fraction<0:
778786
raise RuntimeError('cell_pert_fraction can not be negative')
@@ -1131,6 +1139,14 @@ def sort_atom_types(self):
11311139
if ii in self.data:
11321140
self.data[ii] = self.data[ii][:, idx]
11331141

1142+
def shuffle(self):
1143+
"""Also shuffle labeled data e.g. energies and forces."""
1144+
idx = System.shuffle(self)
1145+
for ii in ['energies', 'forces', 'virials', 'atom_pref']:
1146+
if ii in self.data:
1147+
self.data[ii] = self.data[ii][idx]
1148+
return idx
1149+
11341150

11351151
class MultiSystems:
11361152
'''A set containing several systems.'''

tests/test_qe_pw_scf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def test_cell(self) :
3838
for ii in range(cell.shape[0]) :
3939
for jj in range(cell.shape[1]) :
4040
self.assertAlmostEqual(self.system_h2o.data['cells'][0][ii][jj], cell[ii][jj])
41+
fp.close()
4142

4243

4344
def test_coord(self) :

tests/test_shuffle.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import unittest
2+
from context import dpdata
3+
from comp_sys import CompLabeledSys, IsPBC
4+
5+
class TestDeepmdLoadRaw(unittest.TestCase, CompLabeledSys, IsPBC):
6+
def setUp (self) :
7+
original_system = dpdata.LabeledSystem('poscars/OUTCAR.h2o.md',
8+
fmt = 'vasp/outcar')
9+
original_system += original_system
10+
original_system += original_system
11+
original_system += original_system
12+
self.system_1 = dpdata.LabeledSystem()
13+
self.system_2 = original_system.copy()
14+
idx = self.system_2.shuffle()
15+
for ii in idx:
16+
self.system_1.append(original_system.sub_system(ii))
17+
18+
self.places = 6
19+
self.e_places = 6
20+
self.f_places = 6
21+
self.v_places = 6

0 commit comments

Comments
 (0)