11import os
2+
23import numpy as np
34from scipy .io import loadmat
45from torch .utils .data import Dataset
78
89class SVHNDataset (Dataset ):
910 def __init__ (
10- self ,
11- data_path : str ,
11+ self ,
12+ data_path : str ,
1213 train : bool ,
13- transform = None ,
14- download :bool = True ,
15- nr_channels = 3
16- ):
14+ transform = None ,
15+ download : bool = True ,
16+ nr_channels = 3 ,
17+ ):
1718 """
1819 Initializes the SVHNDataset object.
1920 Args:
@@ -26,8 +27,8 @@ def __init__(
2627 """
2728 super ().__init__ ()
2829 # assert split == "train" or split == "test"
29- self .split = ' train' if train else ' test'
30-
30+ self .split = " train" if train else " test"
31+
3132 if download :
3233 self ._download_data (data_path )
3334
@@ -37,7 +38,7 @@ def __init__(
3738 self .images = data ["X" ].transpose (3 , 1 , 0 , 2 )
3839 self .labels = data ["y" ].flatten ()
3940 self .labels [self .labels == 10 ] = 0
40-
41+
4142 self .nr_channels = nr_channels
4243 self .transforms = transform
4344
@@ -49,7 +50,7 @@ def _download_data(self, path: str):
4950 split (str): The dataset split to download, either 'train' or 'test'.
5051 """
5152 print (f"Downloading SVHN data into { path } " )
52-
53+
5354 SVHN (path , split = self .split , download = True )
5455
5556 def __len__ (self ):
@@ -72,7 +73,7 @@ def __getitem__(self, index):
7273
7374 if self .nr_channels == 1 :
7475 img = np .mean (img , axis = 2 , keepdims = True )
75-
76+
7677 if self .transforms is not None :
7778 img = self .transforms (img )
7879
0 commit comments