11import os
2-
2+ import numpy as np
33from scipy .io import loadmat
44from torch .utils .data import Dataset
55from torchvision .datasets import SVHN
66
77
88class SVHNDataset (Dataset ):
99 def __init__ (
10- self , datapath : str ,
11- transforms = None ,
12- download_data = True ,
13- split = "train"
10+ self ,
11+ data_path : str ,
12+ train : bool ,
13+ transform = None ,
14+ download :bool = True ,
15+ nr_channels = 3
1416 ):
1517 """
1618 Initializes the SVHNDataset object.
1719 Args:
18- datapath (str): Path to where the data lies. If download_data is set to True, this is where the data will be downloaded.
20+ data_path (str): Path to where the data lies. If download_data is set to True, this is where the data will be downloaded.
1921 transforms: Torch composite of transformations which are to be applied to the dataset images.
20- download_data (bool): If True, downloads the dataset to the specified datapath .
22+ download_data (bool): If True, downloads the dataset to the specified data_path .
2123 split (str): The dataset split to use, either 'train' or 'test'.
2224 Raises:
2325 AssertionError: If the split is not 'train' or 'test'.
2426 """
2527 super ().__init__ ()
26- assert split == "train" or split == "test"
27-
28- if download_data :
29- self ._download_data (datapath , split )
28+ # assert split == "train" or split == "test"
29+ self .split = 'train' if train else 'test'
30+
31+ if download :
32+ self ._download_data (data_path )
3033
31- data = loadmat (os .path .join (datapath , f"{ split } _32x32.mat" ))
34+ data = loadmat (os .path .join (data_path , f"{ self . split } _32x32.mat" ))
3235
3336 # Images on the form N x H x W x C
3437 self .images = data ["X" ].transpose (3 , 1 , 0 , 2 )
3538 self .labels = data ["y" ].flatten ()
3639 self .labels [self .labels == 10 ] = 0
40+
41+ self .nr_channels = nr_channels
42+ self .transforms = transform
3743
38- self .transforms = transforms
39-
40- def _download_data (self , path : str , split : str ):
44+ def _download_data (self , path : str ):
4145 """
4246 Downloads the SVHN dataset.
4347 Args:
4448 path (str): The directory where the dataset will be downloaded.
4549 split (str): The dataset split to download, either 'train' or 'test'.
4650 """
4751 print (f"Downloading SVHN data into { path } " )
48- SVHN (path , split = split , download = True )
52+
53+ SVHN (path , split = self .split , download = True )
4954
5055 def __len__ (self ):
5156 """
@@ -65,6 +70,9 @@ def __getitem__(self, index):
6570 """
6671 img , lab = self .images [index ], self .labels [index ]
6772
73+ if self .nr_channels == 1 :
74+ img = np .mean (img , axis = 2 , keepdims = True )
75+
6876 if self .transforms is not None :
6977 img = self .transforms (img )
7078
0 commit comments