1212
1313import h5py as h5
1414import numpy as np
15+ from PIL import Image
1516from torch .utils .data import Dataset
17+ from torchvision import transforms
1618
1719from .datasources import USPS_SOURCE
1820
@@ -88,18 +90,36 @@ def __init__(
8890
8991 # Download the dataset if it does not exist in a temporary directory
9092 # to automatically clean up the downloaded file
91- if download :
93+ if download and not self . _dataset_ok () :
9294 url , _ , checksum = USPS_SOURCE [self .mode ]
9395
9496 print (f"Downloading USPS dataset ({ self .mode } )..." )
9597 self .download (url , self .filepath , checksum , self .mode )
9698
9799 self .idx = self ._index ()
98100
101+ def _dataset_ok (self ):
102+ """Check if the dataset file exists and contains the required datasets."""
103+
104+ if not self .filepath .exists ():
105+ print (f"Dataset file { self .filepath } does not exist." )
106+ return False
107+
108+ with h5 .File (self .filepath , "r" ) as f :
109+ for mode in ["train" , "test" ]:
110+ if mode not in f :
111+ print (
112+ f"Dataset file { self .filepath } is missing the { mode } dataset."
113+ )
114+ return False
115+
116+ return True
117+
99118 def download (self , url , filepath , checksum , mode ):
100119 """Download the USPS dataset."""
101120
102121 def reporthook (blocknum , blocksize , totalsize ):
122+ """Report download progress."""
103123 denom = 1024 * 1024
104124 readsofar = blocknum * blocksize
105125 if totalsize > 0 :
@@ -109,6 +129,7 @@ def reporthook(blocknum, blocksize, totalsize):
109129 if readsofar >= totalsize :
110130 print ()
111131
132+ # Download the dataset to a temporary file
112133 with TemporaryDirectory () as tmpdir :
113134 tmpdir = Path (tmpdir )
114135 tmpfile = tmpdir / "usps.bz2"
@@ -137,7 +158,7 @@ def reporthook(blocknum, blocksize, totalsize):
137158
138159 targets = [int (d [0 ]) - 1 for d in raw ]
139160
140- with h5 .File (self .filepath , "w " ) as f :
161+ with h5 .File (self .filepath , "a " ) as f :
141162 f .create_dataset (f"{ mode } /data" , data = imgs , dtype = np .float32 )
142163 f .create_dataset (f"{ mode } /target" , data = targets , dtype = np .int32 )
143164
@@ -161,7 +182,7 @@ def _index(self):
161182
162183 def _load_data (self , idx ):
163184 with h5 .File (self .filepath , "r" ) as f :
164- data = f [self .mode ]["data" ][idx ]
185+ data = f [self .mode ]["data" ][idx ]. astype ( np . uint8 )
165186 label = f [self .mode ]["target" ][idx ]
166187
167188 return data , label
@@ -171,23 +192,32 @@ def __len__(self):
171192
172193 def __getitem__ (self , idx ):
173194 data , target = self ._load_data (self .idx [idx ])
174-
175- data = data .reshape (16 , 16 )
195+ data = Image .fromarray (data , mode = "L" )
176196
177197 # one hot encode the target
178198 target = np .eye (self .num_classes , dtype = np .float32 )[target ]
179199
180- # Add channel dimension
181- data = np .expand_dims (data , axis = 0 )
182-
183200 if self .transform :
184201 data = self .transform (data )
185202
186203 return data , target
187204
188205
189206if __name__ == "__main__" :
190- dataset = USPSDataset0_6 (data_path = "data" , train = True , download = True )
207+ # Example usage:
208+ transform = transforms .Compose (
209+ [
210+ transforms .Resize ((16 , 16 )),
211+ transforms .ToTensor (),
212+ ]
213+ )
214+
215+ dataset = USPSDataset0_6 (
216+ data_path = "data" ,
217+ train = True ,
218+ download = False ,
219+ transform = transform ,
220+ )
191221 print (len (dataset ))
192222 data , target = dataset [0 ]
193223 print (data .shape )
0 commit comments