|
| 1 | +import os |
| 2 | +import numpy as np |
| 3 | +from scipy.io import loadmat |
1 | 4 | from torch.utils.data import Dataset |
| 5 | +from torchvision.datasets import SVHN |
2 | 6 |
|
3 | 7 |
|
4 | | -class SVHN(Dataset): |
5 | | - def __init__(self): |
class SVHNDataset(Dataset):
    """Street View House Numbers dataset read from the cropped-digit .mat files.

    Images are held in memory as an N x H x W x C array and labels as a flat
    array in which the label value 10 has been remapped to 0.
    """

    def __init__(
        self,
        data_path: str,
        train: bool,
        transform=None,
        download: bool = True,
        nr_channels: int = 3,
    ):
        """
        Initializes the SVHNDataset object.
        Args:
            data_path (str): Directory containing ``{split}_32x32.mat``. If download is True, this is where the data will be downloaded.
            train (bool): If True the 'train' split is used, otherwise the 'test' split.
            transform: Optional callable applied to each image in ``__getitem__``.
            download (bool): If True, fetches the split into data_path via torchvision before loading it.
            nr_channels (int): If 1, images are averaged over the channel axis when retrieved; any other value leaves the images unchanged.
        """
        super().__init__()
        self.split = "train" if train else "test"

        if download:
            self._download_data(data_path)

        data = loadmat(os.path.join(data_path, f"{self.split}_32x32.mat"))

        # The .mat file stores X as H x W x C x N. Move the sample axis to the
        # front while keeping H before W, giving N x H x W x C. (Using
        # (3, 1, 0, 2) here would swap rows and columns of every image.)
        self.images = data["X"].transpose(3, 0, 1, 2)
        self.labels = data["y"].flatten()
        # SVHN stores the digit 0 under label 10; remap it so labels are 0-9.
        self.labels[self.labels == 10] = 0

        self.nr_channels = nr_channels
        self.transforms = transform

    def _download_data(self, path: str):
        """
        Downloads the SVHN split selected by ``self.split``.
        Args:
            path (str): The directory where the dataset will be downloaded.
        """
        print(f"Downloading SVHN data into {path}")

        # Instantiating the torchvision dataset with download=True is used
        # only for its side effect of fetching the .mat file into `path`.
        SVHN(path, split=self.split, download=True)

    def __len__(self):
        """
        Returns the number of samples in the dataset.
        Returns:
            int: The number of samples.
        """
        return len(self.labels)

    def __getitem__(self, index):
        """
        Retrieves the image and label at the specified index.
        Args:
            index (int): The index of the sample to retrieve.
        Returns:
            tuple: A tuple containing the image and its corresponding label.
        """
        img, lab = self.images[index], self.labels[index]

        if self.nr_channels == 1:
            # Collapse the channel axis (last axis of H x W x C) to a single
            # channel. np.mean returns a floating-point array, so the result
            # no longer has the integer dtype of the stored images.
            img = np.mean(img, axis=2, keepdims=True)

        if self.transforms is not None:
            img = self.transforms(img)

        return img, lab
0 commit comments