|
3 | 3 | import os
|
4 | 4 | from copy import deepcopy
|
5 | 5 | from enum import Enum, unique
|
6 |
| -from typing import Any, Tuple, Union |
| 6 | +from typing import Any, Dict, Optional, Tuple, Union |
7 | 7 |
|
8 | 8 | import numpy as np
|
9 | 9 | from monty.json import MSONable
|
@@ -1584,6 +1584,63 @@ def correction(self, hl_sys: "MultiSystems"):
|
1584 | 1584 | corrected_sys.append(ll_ss.correction(hl_ss))
|
1585 | 1585 | return corrected_sys
|
1586 | 1586 |
|
def train_test_split(
    self, test_size: Union[float, int], seed: Optional[int] = None
) -> Tuple["MultiSystems", "MultiSystems", Dict[str, np.ndarray]]:
    """Split systems into random train and test subsets.

    Test frames are drawn without replacement from the pooled frames of
    all systems, so an individual system may contribute any number of
    frames (including none) to either subset.

    Parameters
    ----------
    test_size : float or int
        If float, should be between 0.0 and 1.0 and represent the proportion
        of the dataset to include in the test split (rounded down).
        If int, represents the absolute number of test samples.
        NumPy floating and integer scalars are accepted as well.
    seed : int, default=None
        Random seed

    Returns
    -------
    MultiSystems
        The training set
    MultiSystems
        The testing set
    Dict[str, np.ndarray]
        The bool array of training and testing sets for each system.
        False for training set and True for testing set.

    Raises
    ------
    ValueError
        If a float ``test_size`` is outside [0, 1], or an int ``test_size``
        is outside [0, number of frames].
    RuntimeError
        If ``test_size`` is neither a float nor an int.
    """
    nframes = self.get_nframes()
    # Validate with explicit raises: ``assert`` is stripped under ``python -O``.
    if isinstance(test_size, (float, np.floating)):
        if not 0 <= test_size <= 1:
            raise ValueError(
                f"test_size as a float should be in [0, 1], got {test_size}"
            )
        test_size = int(np.floor(test_size * nframes))
    elif isinstance(test_size, (int, np.integer)):
        if not 0 <= test_size <= nframes:
            raise ValueError(
                f"test_size as an int should be in [0, {nframes}], got {test_size}"
            )
        test_size = int(test_size)
    else:
        raise RuntimeError("test_size should be float or int")

    # Draw the test frames over the pooled frame index of all systems.
    rng = np.random.default_rng(seed=seed)
    test_idx = rng.choice(nframes, test_size, replace=False)
    select_test = np.zeros(nframes, dtype=bool)
    select_test[test_idx] = True

    train_systems = MultiSystems(type_map=self.atom_names)
    test_systems = MultiSystems(type_map=self.atom_names)
    test_system_idx = {}

    # Walk the systems in dict order, carving the global mask into one slice
    # per system. A running offset (rather than zip-unpacking names/sizes)
    # also works when there are no systems at all.
    # NOTE(review): assumes get_nframes() equals the sum of len(system) over
    # self.systems -- confirm.
    start = 0
    for name, system in self.systems.items():
        stop = start + len(system)
        test_mask = select_test[start:stop]
        start = stop

        # Only non-empty subsets are appended to the new MultiSystems.
        sub_train = self[name][np.logical_not(test_mask)]
        if len(sub_train):
            train_systems.append(sub_train)
        sub_test = self[name][test_mask]
        if len(sub_test):
            test_systems.append(sub_test)
        # The mask is recorded for every system, even if it has no test frames.
        test_system_idx[name] = test_mask
    return train_systems, test_systems, test_system_idx
| 1643 | + |
1587 | 1644 |
|
1588 | 1645 | def get_cls_name(cls: object) -> str:
|
1589 | 1646 | """Returns the fully qualified name of a class, such as `np.ndarray`.
|
|
0 commit comments