11from torch .utils .data import Dataset
2+ from scipy .io import loadmat
3+ import os
4+ from torchvision .datasets import SVHN
25
3-
4- class SVHN (Dataset ):
5- def __init__ (self ):
6+ class SVHNDataset (Dataset ):
7+ def __init__ (self ,
8+ datapath : str ,
9+ transforms = None ,
10+ download_data = True ,
11+ split = 'train' ):
12+ """
13+ Initializes the SVHNDataset object.
14+ Args:
15+ datapath (str): Path to where the data lies. If download_data is set to True, this is where the data will be downloaded.
16+ transforms: Torch composite of transformations which are to be applied to the dataset images.
17+ download_data (bool): If True, downloads the dataset to the specified datapath.
18+ split (str): The dataset split to use, either 'train' or 'test'.
19+ Raises:
20+ AssertionError: If the split is not 'train' or 'test'.
21+ """
622 super ().__init__ ()
7-
23+ assert split == 'train' or split == 'test'
24+
25+ if download_data :
26+ self ._download_data (datapath , split )
27+
28+ data = loadmat (os .path .join (datapath , f'{ split } _32x32.mat' ))
29+
30+ # Images on the form N x H x W x C
31+ self .images = data ['X' ].transpose (3 , 1 , 0 , 2 )
32+ self .labels = data ['y' ].flatten ()
33+ self .labels [self .labels == 10 ] = 0
34+
35+ self .transforms = transforms
36+ def _download_data (self , path : str , split : str ):
37+ """
38+ Downloads the SVHN dataset.
39+ Args:
40+ path (str): The directory where the dataset will be downloaded.
41+ split (str): The dataset split to download, either 'train' or 'test'.
42+ """
43+ print (f'Downloading SVHN data into { path } ' )
44+ SVHN (path , split = split , download = True )
45+
846 def __len__ (self ):
9- return
10-
47+ """
48+ Returns the number of samples in the dataset.
49+ Returns:
50+ int: The number of samples.
51+ """
52+ return len (self .labels )
1153 def __getitem__ (self , index ):
12- return
54+ """
55+ Retrieves the image and label at the specified index.
56+ Args:
57+ index (int): The index of the sample to retrieve.
58+ Returns:
59+ tuple: A tuple containing the image and its corresponding label.
60+ """
61+ img , lab = self .images [index ], self .labels [index ]
62+
63+ if self .transforms is not None :
64+ img = self .transforms (img )
65+
66+ return img , lab
0 commit comments