Skip to content

Commit 6a04084

Browse files
committed
Revert "Finished the SVHN dataset"
This reverts commit 3605671.
1 parent 6fb8ede commit 6a04084

File tree

2 files changed

+7
-67
lines changed

2 files changed

+7
-67
lines changed

test.py

Lines changed: 0 additions & 4 deletions
This file was deleted.

utils/dataloaders/svhn.py

Lines changed: 7 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,12 @@
1-
import os
2-
from scipy.io import loadmat
3-
import torch as th
4-
from torchvision import transforms
51
from torch.utils.data import Dataset
6-
from torchvision.datasets import SVHN
72

8-
class SVHNDataset(Dataset):
9-
def __init__(self,
10-
datapath: str,
11-
transforms=None,
12-
download_data=True,
13-
split='train'):
14-
"""
15-
Initializes the SVHNDataset object.
16-
Args:
17-
datapath (str): Path to where the data lies. If download_data is set to True, this is where the data will be downloaded.
18-
transforms: Torch composite of transformations which are to be applied to the dataset images.
19-
download_data (bool): If True, downloads the dataset to the specified datapath.
20-
split (str): The dataset split to use, either 'train' or 'test'.
21-
Raises:
22-
AssertionError: If the split is not 'train' or 'test'.
23-
"""
3+
4+
class SVHN(Dataset):
5+
def __init__(self):
246
super().__init__()
25-
assert split == 'train' or split == 'test'
26-
27-
if download_data:
28-
self._download_data(datapath, split)
29-
30-
data = loadmat(os.path.join(datapath, f'{split}_32x32.mat'))
31-
32-
# Images on the form N x H x W x C
33-
self.images = data['X'].transpose(3, 1, 0, 2)
34-
self.labels = data['y'].flatten()
35-
self.labels[self.labels == 10] = 0
36-
37-
self.transforms = transforms
38-
def _download_data(self, path: str, split: str):
39-
"""
40-
Downloads the SVHN dataset.
41-
Args:
42-
path (str): The directory where the dataset will be downloaded.
43-
split (str): The dataset split to download, either 'train' or 'test'.
44-
"""
45-
print(f'Downloading SVHN data into {path}')
46-
SVHN(path, split=split, download=True)
47-
7+
488
def __len__(self):
49-
"""
50-
Returns the number of samples in the dataset.
51-
Returns:
52-
int: The number of samples.
53-
"""
54-
return len(self.labels)
9+
return
10+
5511
def __getitem__(self, index):
56-
"""
57-
Retrieves the image and label at the specified index.
58-
Args:
59-
index (int): The index of the sample to retrieve.
60-
Returns:
61-
tuple: A tuple containing the image and its corresponding label.
62-
"""
63-
img, lab = self.images[index], self.labels[index]
64-
65-
if self.transforms is not None:
66-
img = self.transforms(img)
67-
68-
return img, lab
12+
return

0 commit comments

Comments
 (0)