@@ -16,27 +16,28 @@ def __init__(
1616 nr_channels = 3 ,
1717 ):
1818 """
19- Initializes the SVHNDataset object.
19+ Initializes the SVHNDataset object for loading the Street View House Numbers (SVHN) dataset .
2020 Args:
21- data_path (str): Path to where the data lies. If download_data is set to True, this is where the data will be downloaded.
22- transforms: Torch composite of transformations which are to be applied to the dataset images.
23- download_data (bool): If True, downloads the dataset to the specified data_path.
24- split (str): The dataset split to use, either 'train' or 'test'.
21+ data_path (str): Path to where the data is stored. If `download` is set to True, this is where the data will be downloaded.
22+ train (bool): If True, loads the training split of the dataset; otherwise, loads the test split.
23+ transform (callable, optional): A function/transform to apply to the images.
24+ download (bool): If True, downloads the dataset to the specified `data_path`.
25+ nr_channels (int): Number of channels in the images. Default is 3 for RGB images.
2526 Raises:
2627 AssertionError: If the split is not 'train' or 'test'.
2728 """
2829 super ().__init__ ()
29- # assert split == "train" or split == "test"
3030 self .split = "train" if train else "test"
3131
3232 if download :
3333 self ._download_data (data_path )
3434
3535 data = loadmat (os .path .join (data_path , f"{ self .split } _32x32.mat" ))
3636
37- # Images on the form N x H x W x C
37+ # Reform images to the form N x H x W x C
3838 self .images = data ["X" ].transpose (3 , 1 , 0 , 2 )
3939 self .labels = data ["y" ].flatten ()
40+
4041 self .labels [self .labels == 10 ] = 0
4142
4243 self .nr_channels = nr_channels
@@ -45,13 +46,11 @@ def __init__(
4546
4647 def _download_data (self , path : str ):
4748 """
48- Downloads the SVHN dataset.
49+ Downloads the SVHN dataset to the specified directory .
4950 Args:
5051 path (str): The directory where the dataset will be downloaded.
51- split (str): The dataset split to download, either 'train' or 'test'.
5252 """
5353 print (f"Downloading SVHN data into { path } " )
54-
5554 SVHN (path , split = self .split , download = True )
5655
5756 def __len__ (self ):
@@ -74,7 +73,6 @@ def __getitem__(self, index):
7473
7574 if self .nr_channels == 1 :
7675 img = np .mean (img , axis = 2 , keepdims = True )
77-
7876 if self .transforms is not None :
7977 img = self .transforms (img )
8078
0 commit comments