1- from torch .utils .data import Dataset
1+ import os
2+
23from scipy .io import loadmat
3- import os
4+ from torch . utils . data import Dataset
45from torchvision .datasets import SVHN
56
7+
68class SVHNDataset (Dataset ):
7- def __init__ (self ,
8- datapath : str ,
9- transforms = None ,
10- download_data = True ,
11- split = 'train' ):
9+ def __init__ (
10+ self , datapath : str , transforms = None , download_data = True , split = "train"
11+ ):
1212 """
1313 Initializes the SVHNDataset object.
1414 Args:
@@ -20,36 +20,38 @@ def __init__(self,
2020 AssertionError: If the split is not 'train' or 'test'.
2121 """
2222 super ().__init__ ()
23- assert split == ' train' or split == ' test'
24-
23+ assert split == " train" or split == " test"
24+
2525 if download_data :
2626 self ._download_data (datapath , split )
27-
28- data = loadmat (os .path .join (datapath , f' { split } _32x32.mat' ))
29-
27+
28+ data = loadmat (os .path .join (datapath , f" { split } _32x32.mat" ))
29+
3030 # Images on the form N x H x W x C
31- self .images = data ['X' ].transpose (3 , 1 , 0 , 2 )
32- self .labels = data ['y' ].flatten ()
31+ self .images = data ["X" ].transpose (3 , 1 , 0 , 2 )
32+ self .labels = data ["y" ].flatten ()
3333 self .labels [self .labels == 10 ] = 0
34-
34+
3535 self .transforms = transforms
36+
3637 def _download_data (self , path : str , split : str ):
3738 """
3839 Downloads the SVHN dataset.
3940 Args:
4041 path (str): The directory where the dataset will be downloaded.
4142 split (str): The dataset split to download, either 'train' or 'test'.
4243 """
43- print (f' Downloading SVHN data into { path } ' )
44- SVHN (path , split = split , download = True )
45-
44+ print (f" Downloading SVHN data into { path } " )
45+ SVHN (path , split = split , download = True )
46+
4647 def __len__ (self ):
4748 """
4849 Returns the number of samples in the dataset.
4950 Returns:
5051 int: The number of samples.
5152 """
5253 return len (self .labels )
54+
5355 def __getitem__ (self , index ):
5456 """
5557 Retrieves the image and label at the specified index.
@@ -59,8 +61,8 @@ def __getitem__(self, index):
5961 tuple: A tuple containing the image and its corresponding label.
6062 """
6163 img , lab = self .images [index ], self .labels [index ]
62-
64+
6365 if self .transforms is not None :
6466 img = self .transforms (img )
65-
66- return img , lab
67+
68+ return img , lab
0 commit comments