Skip to content

Commit b1f13f6

Browse files
add train_test_split method (#459)
Split a MultiSystems into training and test subsets. --------- Signed-off-by: Jinzhe Zeng <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent c09e0ee commit b1f13f6

File tree

2 files changed

+83
-1
lines changed

2 files changed

+83
-1
lines changed

dpdata/system.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
from copy import deepcopy
55
from enum import Enum, unique
6-
from typing import Any, Tuple, Union
6+
from typing import Any, Dict, Optional, Tuple, Union
77

88
import numpy as np
99
from monty.json import MSONable
@@ -1584,6 +1584,63 @@ def correction(self, hl_sys: "MultiSystems"):
15841584
corrected_sys.append(ll_ss.correction(hl_ss))
15851585
return corrected_sys
15861586

1587+
def train_test_split(
    self, test_size: Union[float, int], seed: Optional[int] = None
) -> Tuple["MultiSystems", "MultiSystems", Dict[str, np.ndarray]]:
    """Split systems into random train and test subsets.

    The frames of all systems are pooled, ``test_size`` of them are drawn
    without replacement, and every system is then divided according to
    which of its frames were drawn.

    Parameters
    ----------
    test_size : float or int
        If float, should be between 0.0 and 1.0 and represent the proportion
        of the dataset to include in the test split (rounded down to a whole
        number of frames).
        If int, represents the absolute number of test samples.
    seed : int, default=None
        Random seed

    Returns
    -------
    MultiSystems
        The training set
    MultiSystems
        The testing set
    Dict[str, np.ndarray]
        The bool array of training and testing sets for each system.
        False for training set and True for testing set.

    Raises
    ------
    ValueError
        If a float ``test_size`` is outside [0, 1] or an int ``test_size``
        is outside [0, number of frames].
    RuntimeError
        If ``test_size`` is neither a float nor an int.
    """
    nframes = self.get_nframes()
    # Validate with explicit exceptions: ``assert`` is stripped under ``-O``.
    # NumPy scalar types are accepted alongside the Python builtins.
    if isinstance(test_size, (float, np.floating)):
        if not 0 <= test_size <= 1:
            raise ValueError(
                f"test_size as a float should be in [0, 1], got {test_size}"
            )
        test_size = int(np.floor(test_size * nframes))
    elif isinstance(test_size, (int, np.integer)):
        if not 0 <= test_size <= nframes:
            raise ValueError(
                f"test_size as an int should be in [0, {nframes}], got {test_size}"
            )
        test_size = int(test_size)
    else:
        raise RuntimeError("test_size should be float or int")

    # Draw the global indices of the test frames.
    rng = np.random.default_rng(seed=seed)
    test_idx = rng.choice(nframes, test_size, replace=False)
    select_test = np.zeros(nframes, dtype=bool)
    select_test[test_idx] = True

    # Flatten the systems dict: system ``ii`` owns the global frame range
    # [system_idx[ii], system_idx[ii + 1]). Names and sizes are built
    # separately so that an empty collection does not break unpacking.
    system_names = list(self.systems)
    system_sizes = [len(self.systems[name]) for name in system_names]
    system_idx = np.concatenate(([0], np.cumsum(system_sizes, dtype=int)))

    # Make new systems; sub-systems left with no frames are not appended.
    train_systems = MultiSystems(type_map=self.atom_names)
    test_systems = MultiSystems(type_map=self.atom_names)
    test_system_idx = {}
    for ii, nn in enumerate(system_names):
        mask = select_test[system_idx[ii] : system_idx[ii + 1]]
        system = self[nn]
        sub_train = system[np.logical_not(mask)]
        if len(sub_train):
            train_systems.append(sub_train)
        sub_test = system[mask]
        if len(sub_test):
            test_systems.append(sub_test)
        test_system_idx[nn] = mask
    return train_systems, test_systems, test_system_idx
1643+
15871644

15881645
def get_cls_name(cls: object) -> str:
15891646
"""Returns the fully qualified name of a class, such as `np.ndarray`.

tests/test_split_dataset.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import unittest
2+
3+
import numpy as np
4+
from context import dpdata
5+
6+
7+
class TestSplitDataset(unittest.TestCase):
    """Check the frame counts produced by ``MultiSystems.train_test_split``."""

    def setUp(self):
        # Fill a MultiSystems with ten copies of one labeled system.
        self.systems = dpdata.MultiSystems()
        template = dpdata.LabeledSystem("poscars/OUTCAR.h2o.md", fmt="vasp/outcar")
        for _ in range(10):
            self.systems.append(template.copy())

    def test_split_dataset(self):
        train, test, test_idx = self.systems.train_test_split(0.2)

        total = self.systems.get_nframes()
        expected_train = int(np.floor(total * 0.8))
        expected_test = int(np.floor(total * 0.2))
        # Number of frames flagged True across all per-system masks.
        flagged = sum(np.count_nonzero(mask) for mask in test_idx.values())

        self.assertEqual(train.get_nframes(), expected_train)
        self.assertEqual(test.get_nframes(), expected_test)
        self.assertEqual(flagged, expected_test)

0 commit comments

Comments
 (0)