@@ -31,7 +31,7 @@ def __init__(
3131 """
3232 super ().__init__ ()
3333
34- self .data_path = data_path
34+ self .data_path = data_path / "SVHN"
3535 self .indexes = sample_ids
3636 self .split = "train" if train else "test"
3737
@@ -41,7 +41,7 @@ def __init__(
4141 if not os .path .exists (
4242 os .path .join (self .data_path , f"svhn_{ self .split } data.h5" )
4343 ):
44- self ._download_data (self .data_path )
44+ self ._create_h5py (self .data_path )
4545
4646 assert os .path .exists (
4747 os .path .join (self .data_path , f"svhn_{ self .split } data.h5" )
@@ -53,15 +53,14 @@ def __init__(
5353
5454 self .num_classes = len (np .unique (self .labels ))
5555
56- def _download_data (self , path : str ):
56+ def _create_h5py (self , path : str ):
5757 """
5858 Downloads the SVHN dataset to the specified directory.
5959 Args:
6060 path (str): The directory where the dataset will be downloaded.
6161 """
6262 print (f"Downloading SVHN data into { path } " )
6363
64- SVHN (path , split = self .split , download = True )
6564 data = loadmat (os .path .join (path , f"{ self .split } _32x32.mat" ))
6665
6766 images , labels = data ["X" ], data ["y" ]
0 commit comments