|
7 | 7 |
|
8 | 8 | class SVHNDataset(Dataset): |
9 | 9 | def __init__( |
10 | | - self, datapath: str, |
11 | | - transforms=None, |
12 | | - download_data=True, |
| 10 | + self, |
| 11 | + data_path: str, |
| 12 | + train: bool, |
| 13 | + transform=None, |
| 14 | + download:bool=True, |
13 | 15 | nr_channels=3 |
14 | 16 | ): |
15 | 17 | """ |
16 | 18 | Initializes the SVHNDataset object. |
17 | 19 | Args: |
18 | | - datapath (str): Path to where the data lies. If download_data is set to True, this is where the data will be downloaded. |
| 20 | + data_path (str): Path to where the data lies. If download is set to True, this is where the data will be downloaded. |
19 | 21 | transform: Torch composite of transformations which are to be applied to the dataset images. |
20 | | - download_data (bool): If True, downloads the dataset to the specified datapath. |
| 22 | + download (bool): If True, downloads the dataset to the specified data_path. |
21 | 23 | train (bool): Selects the dataset split: True uses 'train', False uses 'test'. |
22 | 24 | Raises: |
23 | 25 | AssertionError: If the split is not 'train' or 'test'. |
24 | 26 | """ |
25 | 27 | super().__init__() |
26 | 28 | # assert split == "train" or split == "test" |
| 29 | + self.split = 'train' if train else 'test' |
| 30 | + |
| 31 | + if download: |
| 32 | + self._download_data(data_path) |
27 | 33 |
|
28 | | - if download_data: |
29 | | - self._download_data(datapath) |
30 | | - |
31 | | - data = loadmat(os.path.join(datapath, f"train_32x32.mat")) |
| 34 | + data = loadmat(os.path.join(data_path, f"{self.split}_32x32.mat")) |
32 | 35 |
|
33 | 36 | # Images on the form N x H x W x C |
34 | 37 | self.images = data["X"].transpose(3, 1, 0, 2) |
35 | 38 | self.labels = data["y"].flatten() |
36 | 39 | self.labels[self.labels == 10] = 0 |
37 | 40 |
|
38 | 41 | self.nr_channels = nr_channels |
39 | | - self.transforms = transforms |
| 42 | + self.transforms = transform |
40 | 43 |
|
41 | | - def _download_data(self, path: str, split: str): |
| 44 | + def _download_data(self, path: str): |
42 | 45 | """ |
43 | 46 | Downloads the SVHN dataset. |
44 | 47 | Args: |
45 | 48 | path (str): The directory where the dataset will be downloaded. |
46 | 49 | The split to download ('train' or 'test') is taken from self.split. |
47 | 50 | """ |
48 | 51 | print(f"Downloading SVHN data into {path}") |
49 | | - SVHN(path, split='train', download=True) |
| 52 | + |
| 53 | + SVHN(path, split=self.split, download=True) |
50 | 54 |
|
51 | 55 | def __len__(self): |
52 | 56 | """ |
|
0 commit comments