|
1 | | -import os |
2 | | -from scipy.io import loadmat |
3 | | -import torch as th |
4 | | -from torchvision import transforms |
5 | 1 | from torch.utils.data import Dataset |
6 | | -from torchvision.datasets import SVHN |
7 | 2 |
|
8 | | -class SVHNDataset(Dataset): |
9 | | - def __init__(self, |
10 | | - datapath: str, |
11 | | - transforms=None, |
12 | | - download_data=True, |
13 | | - split='train'): |
14 | | - """ |
15 | | - Initializes the SVHNDataset object. |
16 | | - Args: |
17 | | - datapath (str): Path to where the data lies. If download_data is set to True, this is where the data will be downloaded. |
18 | | - transforms: Torch composite of transformations which are to be applied to the dataset images. |
19 | | - download_data (bool): If True, downloads the dataset to the specified datapath. |
20 | | - split (str): The dataset split to use, either 'train' or 'test'. |
21 | | - Raises: |
22 | | - AssertionError: If the split is not 'train' or 'test'. |
23 | | - """ |
| 3 | + |
| 4 | +class SVHN(Dataset): |
| 5 | + def __init__(self): |
24 | 6 | super().__init__() |
25 | | - assert split == 'train' or split == 'test' |
26 | | - |
27 | | - if download_data: |
28 | | - self._download_data(datapath, split) |
29 | | - |
30 | | - data = loadmat(os.path.join(datapath, f'{split}_32x32.mat')) |
31 | | - |
32 | | - # Images on the form N x H x W x C |
33 | | - self.images = data['X'].transpose(3, 1, 0, 2) |
34 | | - self.labels = data['y'].flatten() |
35 | | - self.labels[self.labels == 10] = 0 |
36 | | - |
37 | | - self.transforms = transforms |
38 | | - def _download_data(self, path: str, split: str): |
39 | | - """ |
40 | | - Downloads the SVHN dataset. |
41 | | - Args: |
42 | | - path (str): The directory where the dataset will be downloaded. |
43 | | - split (str): The dataset split to download, either 'train' or 'test'. |
44 | | - """ |
45 | | - print(f'Downloading SVHN data into {path}') |
46 | | - SVHN(path, split=split, download=True) |
47 | | - |
| 7 | + |
48 | 8 | def __len__(self): |
49 | | - """ |
50 | | - Returns the number of samples in the dataset. |
51 | | - Returns: |
52 | | - int: The number of samples. |
53 | | - """ |
54 | | - return len(self.labels) |
| 9 | + return |
| 10 | + |
55 | 11 | def __getitem__(self, index): |
56 | | - """ |
57 | | - Retrieves the image and label at the specified index. |
58 | | - Args: |
59 | | - index (int): The index of the sample to retrieve. |
60 | | - Returns: |
61 | | - tuple: A tuple containing the image and its corresponding label. |
62 | | - """ |
63 | | - img, lab = self.images[index], self.labels[index] |
64 | | - |
65 | | - if self.transforms is not None: |
66 | | - img = self.transforms(img) |
67 | | - |
68 | | - return img, lab |
| 12 | + return |
0 commit comments