11import os
22
3- import h5py
3+ import h5py
44import numpy as np
55from PIL import Image
66from scipy .io import loadmat
@@ -37,15 +37,16 @@ def __init__(
3737
3838 self .nr_channels = nr_channels
3939 self .transforms = transform
40-
41-
42- assert os .path .exists (os .path .join (self .data_path , f'svhn_{ self .split } data.h5' )), f'File svhn_{ self .split } data.h5 does not exists. Run download=True'
43- with h5py .File (os .path .join (self .data_path , f'svhn_{ self .split } data.h5' ), 'r' ) as h5f :
44- self .labels = h5f ['labels' ][:]
45-
40+
41+ assert os .path .exists (
42+ os .path .join (self .data_path , f"svhn_{ self .split } data.h5" )
43+ ), f"File svhn_{ self .split } data.h5 does not exists. Run download=True"
44+ with h5py .File (
45+ os .path .join (self .data_path , f"svhn_{ self .split } data.h5" ), "r"
46+ ) as h5f :
47+ self .labels = h5f ["labels" ][:]
48+
4649 self .num_classes = len (np .unique (self .labels ))
47-
48-
4950
5051 def _download_data (self , path : str ):
5152 """
@@ -55,17 +56,19 @@ def _download_data(self, path: str):
5556 """
5657 print (f"Downloading SVHN data into { path } " )
5758 SVHN (path , split = self .split , download = True )
58- data = loadmat (os .path .join (path , f' { self .split } _32x32.mat' ))
59+ data = loadmat (os .path .join (path , f" { self .split } _32x32.mat" ))
5960
60- images , labels = data ['X' ], data ['y' ]
61- images = images .transpose (3 ,1 , 0 , 2 )
61+ images , labels = data ["X" ], data ["y" ]
62+ images = images .transpose (3 , 1 , 0 , 2 )
6263 labels [labels == 10 ] = 0
6364 labels = labels .flatten ()
64-
65- with h5py .File (os .path .join (self .data_path , f'svhn_{ self .split } data.h5' ), 'w' ) as h5f :
66- h5f .create_dataset ('images' , data = images )
67- h5f .create_dataset ('labels' , data = labels )
68-
65+
66+ with h5py .File (
67+ os .path .join (self .data_path , f"svhn_{ self .split } data.h5" ), "w"
68+ ) as h5f :
69+ h5f .create_dataset ("images" , data = images )
70+ h5f .create_dataset ("labels" , data = labels )
71+
6972 def __len__ (self ):
7073 """
7174 Returns the number of samples in the dataset.
@@ -83,14 +86,15 @@ def __getitem__(self, index):
8386 tuple: A tuple containing the image and its corresponding label.
8487 """
8588 lab = self .labels [index ]
86- with h5py .File (os .path .join (self .data_path , f'svhn_{ self .split } data.h5' ), 'r' ) as h5f :
87- img = Image .fromarray (h5f ['images' ][index ])
88-
89+ with h5py .File (
90+ os .path .join (self .data_path , f"svhn_{ self .split } data.h5" ), "r"
91+ ) as h5f :
92+ img = Image .fromarray (h5f ["images" ][index ])
93+
8994 if self .nr_channels == 1 :
90- img = img .convert ('L' )
91-
95+ img = img .convert ("L" )
96+
9297 if self .transforms is not None :
9398 img = self .transforms (img )
9499
95100 return img , lab
96-
0 commit comments