1- from torch .utils .data import Dataset
1+ from torch .utils .data import Dataset , random_split
22
3- from .dataloaders import MNISTDataset0_3 , USPSDataset0_6 , USPSH5_Digit_7_9_Dataset
3+ from .dataloaders import (
4+ Downloader ,
5+ MNISTDataset0_3 ,
6+ USPSDataset0_6 ,
7+ USPSH5_Digit_7_9_Dataset ,
8+ )
49
510
6- def load_data (dataset : str , * args , ** kwargs ) -> Dataset :
def filter_labels(samples: list, wanted_labels: list) -> list:
    """
    Keep only the entries of ``samples`` that occur in ``wanted_labels``.

    Args
    ----
    samples : list
        Values to filter. Their relative order is preserved.
    wanted_labels : list
        Values to keep. Any container supporting the ``in`` operator works
        (a ``range`` is accepted as well as a list).

    Returns
    -------
    list
        New list with every element ``x`` of ``samples`` for which
        ``x in wanted_labels`` is true. Duplicates are kept.
    """
    # Membership is tested with ``in`` directly on the caller's container
    # (no conversion to a set) so elements are matched exactly as before,
    # including values that are not hashable.
    return [sample for sample in samples if sample in wanted_labels]
13+
14+
15+ def load_data (dataset : str , * args , ** kwargs ) -> tuple :
716 """
8- Load the dataset based on the dataset name.
17+ load the dataset based on the dataset name.
918
1019 Args
1120 ----
@@ -18,8 +27,8 @@ def load_data(dataset: str, *args, **kwargs) -> Dataset:
1827
1928 Returns
2029 -------
21- dataset : torch.utils.data.Dataset
22- Dataset object .
30+ tuple
31+ Tuple of train, validation and test datasets .
2332
2433 Raises
2534 ------
@@ -28,17 +37,54 @@ def load_data(dataset: str, *args, **kwargs) -> Dataset:
2837
2938 Examples
3039 --------
31- >>> from utils import load_data
32- >>> dataset = load_data ("usps_0-6", data_path="data", train=True, download=True)
33- >>> len(dataset )
34- 5460
40+ >>> from utils import setup_data
41+ >>> train, val, test = setup_data ("usps_0-6", data_path="data", train=True, download=True)
42+ >>> len(train), len(val), len(test )
43+ (4914, 546, 1782)
3544 """
45+
3646 match dataset .lower ():
3747 case "usps_0-6" :
38- return USPSDataset0_6 ( * args , ** kwargs )
39- case "mnist_0-3" :
40- return MNISTDataset0_3 ( * args , ** kwargs )
48+ dataset = USPSDataset0_6
49+ train_samples , test_samples = Downloader . usps ( * args , ** kwargs )
50+ labels = range ( 7 )
4151 case "usps_7-9" :
42- return USPSH5_Digit_7_9_Dataset (* args , ** kwargs )
52+ dataset = USPSH5_Digit_7_9_Dataset
53+ train_samples , test_samples = Downloader .usps (* args , ** kwargs )
54+ labels = range (7 , 10 )
55+ case "mnist_0-3" :
56+ dataset = MNISTDataset0_3
57+ train_samples , test_samples = Downloader .mnist (* args , ** kwargs )
58+ labels = range (4 )
4359 case _:
4460 raise NotImplementedError (f"Dataset: { dataset } not implemented." )
61+
62+ val_size = kwargs .get ("val_size" , 0.1 )
63+
64+ train_samples = filter_labels (train_samples , labels )
65+ test_samples = filter_labels (test_samples , labels )
66+
67+ train_samples , val_samples = random_split (train_samples , [1 - val_size , val_size ])
68+
69+ train = dataset (
70+ * args ,
71+ sample_ids = train_samples ,
72+ train = True ,
73+ ** kwargs ,
74+ )
75+
76+ val = dataset (
77+ * args ,
78+ sample_ids = val_samples ,
79+ train = True ,
80+ ** kwargs ,
81+ )
82+
83+ test = dataset (
84+ * args ,
85+ sample_ids = test_samples ,
86+ train = False ,
87+ ** kwargs ,
88+ )
89+
90+ return train , val , test
0 commit comments