11import os
22
3+ import h5py
34import numpy as np
5+ from PIL import Image
46from scipy .io import loadmat
57from torch .utils .data import Dataset
68from torchvision .datasets import SVHN
@@ -27,22 +29,23 @@ def __init__(
2729 AssertionError: If the split is not 'train' or 'test'.
2830 """
2931 super ().__init__ ()
32+ self .data_path = data_path
3033 self .split = "train" if train else "test"
3134
3235 if download :
3336 self ._download_data (data_path )
3437
35- data = loadmat (os .path .join (data_path , f"{ self .split } _32x32.mat" ))
36-
37- # Reform images to the form N x H x W x C
38- self .images = data ["X" ].transpose (3 , 1 , 0 , 2 )
39- self .labels = data ["y" ].flatten ()
40-
41- self .labels [self .labels == 10 ] = 0
42-
4338 self .nr_channels = nr_channels
4439 self .transforms = transform
40+
41+
42+ assert os .path .exists (os .path .join (self .data_path , f'svhn_{ self .split } data.h5' )), f'File svhn_{ self .split } data.h5 does not exists. Run download=True'
43+ with h5py .File (os .path .join (self .data_path , f'svhn_{ self .split } data.h5' ), 'r' ) as h5f :
44+ self .labels = h5f ['labels' ][:]
45+
4546 self .num_classes = len (np .unique (self .labels ))
47+
48+
4649
4750 def _download_data (self , path : str ):
4851 """
@@ -52,7 +55,17 @@ def _download_data(self, path: str):
5255 """
5356 print (f"Downloading SVHN data into { path } " )
5457 SVHN (path , split = self .split , download = True )
58+ data = loadmat (os .path .join (path , f'{ self .split } _32x32.mat' ))
5559
60+ images , labels = data ['X' ], data ['y' ]
61+ images = images .transpose (3 ,1 ,0 ,2 )
62+ labels [labels == 10 ] = 0
63+ labels = labels .flatten ()
64+
65+ with h5py .File (os .path .join (self .data_path , f'svhn_{ self .split } data.h5' ), 'w' ) as h5f :
66+ h5f .create_dataset ('images' , data = images )
67+ h5f .create_dataset ('labels' , data = labels )
68+
5669 def __len__ (self ):
5770 """
5871 Returns the number of samples in the dataset.
@@ -69,11 +82,15 @@ def __getitem__(self, index):
6982 Returns:
7083 tuple: A tuple containing the image and its corresponding label.
7184 """
72- img , lab = self .images [index ], self .labels [index ]
73-
85+ lab = self .labels [index ]
86+ with h5py .File (os .path .join (self .data_path , f'svhn_{ self .split } data.h5' ), 'r' ) as h5f :
87+ img = Image .fromarray (h5f ['images' ][index ])
88+
7489 if self .nr_channels == 1 :
75- img = np .mean (img , axis = 2 , keepdims = True )
90+ img = img .convert ('L' )
91+
7692 if self .transforms is not None :
7793 img = self .transforms (img )
7894
7995 return img , lab
96+
0 commit comments