Skip to content

Commit 3605671

Browse files
committed
Finished the SVHN dataset
1 parent 47f419c commit 3605671

File tree

2 files changed

+67
-7
lines changed

2 files changed

+67
-7
lines changed

test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from torch.utils.data import Dataset
2+
from torchvision.datasets import SVHN
3+
4+
# Ad-hoc check: construct torchvision's SVHN with root 'data/' and download=True.
# The result is discarded, so this line is run only for its side effects.
# NOTE(review): presumably this fetches the default split into data/ when it is
# not already there - confirm against the torchvision SVHN documentation.
SVHN('data/', download=True)

utils/dataloaders/svhn.py

Lines changed: 63 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,68 @@
1+
import os
2+
from scipy.io import loadmat
3+
import torch as th
4+
from torchvision import transforms
15
from torch.utils.data import Dataset
6+
from torchvision.datasets import SVHN
27

3-
4-
class SVHN(Dataset):
5-
def __init__(self):
8+
class SVHNDataset(Dataset):
    """Map-style dataset over one split of the cropped-digit SVHN ``.mat`` files.

    The whole split is read from ``<datapath>/<split>_32x32.mat`` at
    construction time and kept in memory as NumPy arrays; samples are then
    served by index.
    """

    def __init__(self,
                 datapath: str,
                 transforms=None,
                 download_data=True,
                 split='train'):
        """
        Initializes the SVHNDataset object.

        Args:
            datapath (str): Path to where the data lies. If download_data is
                set to True, this is where the data will be downloaded.
            transforms: Callable (e.g. a torchvision ``Compose``) applied to
                each image in ``__getitem__``. ``None`` returns the raw array.
                Note that inside this method the name shadows the
                module-level ``torchvision.transforms`` import.
            download_data (bool): If True, downloads the dataset to the
                specified datapath before loading it.
            split (str): The dataset split to use, either 'train' or 'test'.

        Raises:
            AssertionError: If the split is not 'train' or 'test'.
        """
        super().__init__()

        # Raised explicitly (instead of a bare ``assert``) so the check is
        # not stripped when Python runs with -O; the exception type callers
        # may already catch is kept.
        if split not in ('train', 'test'):
            raise AssertionError(
                f"split must be 'train' or 'test', got {split!r}")

        if download_data:
            self._download_data(datapath, split)

        data = loadmat(os.path.join(datapath, f'{split}_32x32.mat'))

        # The .mat file stores X as H x W x C x N. Move the sample axis to
        # the front while keeping rows before columns, so that images are
        # N x H x W x C. (Swapping axes 0 and 1 here would silently
        # transpose every image, which the square 32x32 shape hides.)
        self.images = data['X'].transpose(3, 0, 1, 2)

        # y is loaded as an N x 1 column; flatten() returns a 1-D copy, so
        # the in-place relabelling below does not touch the loaded dict.
        self.labels = data['y'].flatten()
        # SVHN encodes the digit 0 as class 10; remap it so labels are 0..9.
        self.labels[self.labels == 10] = 0

        self.transforms = transforms

    def _download_data(self, path: str, split: str):
        """
        Downloads the SVHN dataset.

        Args:
            path (str): The directory where the dataset will be downloaded.
            split (str): The dataset split to download, either 'train' or
                'test'.
        """
        print(f'Downloading SVHN data into {path}')
        # The torchvision dataset object is built only for its download side
        # effect and is discarded immediately.
        SVHN(path, split=split, download=True)

    def __len__(self) -> int:
        """
        Returns the number of samples in the dataset.

        Returns:
            int: The number of samples.
        """
        return len(self.labels)

    def __getitem__(self, index):
        """
        Retrieves the image and label at the specified index.

        Args:
            index (int): The index of the sample to retrieve.

        Returns:
            tuple: ``(image, label)``. The image is the H x W x C array for
            that sample, passed through ``self.transforms`` when one was
            given; the label is the matching entry of ``self.labels``.
        """
        img, lab = self.images[index], self.labels[index]

        if self.transforms is not None:
            img = self.transforms(img)

        return img, lab

0 commit comments

Comments
 (0)