@@ -4,11 +4,19 @@
 This module contains the Dataset class for the USPS dataset with labels 0-6.
 """

+import bz2
+import hashlib
 from pathlib import Path
+from tempfile import TemporaryDirectory
+from urllib.request import urlretrieve

 import h5py as h5
 import numpy as np
+from PIL import Image
 from torch.utils.data import Dataset
+from torchvision import transforms
+
+from .datasources import USPS_SOURCE


 class USPSDataset0_6(Dataset):
@@ -28,7 +36,7 @@ class USPSDataset0_6(Dataset):

     Attributes
     ----------
-    path : pathlib.Path
+    filepath : pathlib.Path
         Path to the USPS dataset file.
     mode : str
         Mode of the dataset, either train or test.
@@ -63,6 +71,8 @@ class USPSDataset0_6(Dataset):
     6
     """

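+    # Both the train and test splits are cached in a single HDF5 file under data_path.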
+    filename = "usps.h5"
+
     def __init__(
         self,
         data_path: Path,
@@ -71,18 +81,97 @@ def __init__(
         download: bool = False,
     ):
         super().__init__()
-        self.path = data_path
+
+        path = data_path if isinstance(data_path, Path) else Path(data_path)
+        self.filepath = path / self.filename
         self.transform = transform
-        self.num_classes = 7
+        self.num_classes = 7  # 0-6
+        self.mode = "train" if train else "test"

-        if download:
-            raise NotImplementedError("Download functionality not implemented.")
+        # Download the dataset if a usable HDF5 file is not already present; the raw
+        # archive is fetched into a temporary directory so it is cleaned up automatically.
+        if download and not self._dataset_ok():
+            url, _, checksum = USPS_SOURCE[self.mode]
+
+            print(f"Downloading USPS dataset ({self.mode})...")
+            self.download(url, self.filepath, checksum, self.mode)

-        self.mode = "train" if train else "test"
         self.idx = self._index()

+    def _dataset_ok(self):
+        """Check if the dataset file exists and contains the required datasets."""
+
+        if not self.filepath.exists():
+            print(f"Dataset file {self.filepath} does not exist.")
+            return False
+
+        with h5.File(self.filepath, "r") as f:
+            for mode in ["train", "test"]:
+                if mode not in f:
+                    print(
+                        f"Dataset file {self.filepath} is missing the {mode} dataset."
+                    )
+                    return False
+
+        return True
+
+    def download(self, url, filepath, checksum, mode):
+        """Download the USPS dataset."""
+
+        def reporthook(blocknum, blocksize, totalsize):
+            """Report download progress."""
+            denom = 1024 * 1024
+            readsofar = blocknum * blocksize
+            if totalsize > 0:
+                percent = readsofar * 1e2 / totalsize
+                s = f"\r{int(percent):^3}% {readsofar / denom:.2f} of {totalsize / denom:.2f} MB"
+                print(s, end="", flush=True)
+                if readsofar >= totalsize:
+                    print()
+
+        # Download the dataset to a temporary file
+        with TemporaryDirectory() as tmpdir:
+            tmpdir = Path(tmpdir)
+            tmpfile = tmpdir / "usps.bz2"
+            urlretrieve(
+                url,
+                tmpfile,
+                reporthook=reporthook,
+            )
+
+            # Verify the integrity of the downloaded archive before converting it
+            if not self.check_integrity(tmpfile, checksum):
+                errmsg = (
+                    "The checksum of the downloaded file does "
+                    "not match the expected checksum."
+                )
+                raise ValueError(errmsg)
+
+            # Load the dataset and save it as an HDF5 file
+            with bz2.open(tmpfile) as fp:
+                raw = [line.decode().split() for line in fp.readlines()]
+
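+            # Each line is "<label> <index>:<value> ..."; keep only the pixel values
+            # and drop the feature indices.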
+            tmp = [[x.split(":")[-1] for x in data[1:]] for data in raw]
+
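+            # Pixel values come in as [-1, 1]; rescale to [0, 255] and cast to uint8.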
+            imgs = np.asarray(tmp, dtype=np.float32)
+            imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8)
+
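+            # Raw labels are 1-10; shift them down by one so the digits map to 0-9.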
+            targets = [int(d[0]) - 1 for d in raw]
+
+            with h5.File(self.filepath, "a") as f:
+                f.create_dataset(f"{mode}/data", data=imgs, dtype=np.float32)
+                f.create_dataset(f"{mode}/target", data=targets, dtype=np.int32)
+
+    @staticmethod
+    def check_integrity(filepath, checksum):
+        """Check the integrity of the USPS dataset file."""
+
+        file_hash = hashlib.md5(filepath.read_bytes()).hexdigest()
+
+        return checksum == file_hash
+
     def _index(self):
-        with h5.File(self.path, "r") as f:
+        with h5.File(self.filepath, "r") as f:
             labels = f[self.mode]["target"][:]

             # Get indices of samples with labels 0-6
@@ -92,8 +181,8 @@ def _index(self):
         return idx

     def _load_data(self, idx):
-        with h5.File(self.path, "r") as f:
-            data = f[self.mode]["data"][idx]
+        with h5.File(self.filepath, "r") as f:
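+            # Data is stored as float32 in the HDF5 file; cast back to uint8 for PIL.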
+            data = f[self.mode]["data"][idx].astype(np.uint8)
             label = f[self.mode]["target"][idx]

         return data, label
@@ -103,16 +192,33 @@ def __len__(self):

     def __getitem__(self, idx):
         data, target = self._load_data(self.idx[idx])
-
-        data = data.reshape(16, 16)
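+        # Rebuild the 16x16 grayscale image so torchvision transforms can be applied.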
+        data = Image.fromarray(data.reshape(16, 16), mode="L")

         # one hot encode the target
         target = np.eye(self.num_classes, dtype=np.float32)[target]

-        # Add channel dimension
-        data = np.expand_dims(data, axis=0)
-
         if self.transform:
             data = self.transform(data)

         return data, target
+
+
+if __name__ == "__main__":
+    # Example usage:
+    transform = transforms.Compose(
+        [
+            transforms.Resize((16, 16)),
+            transforms.ToTensor(),
+        ]
+    )
+
+    dataset = USPSDataset0_6(
+        data_path="data",
+        train=True,
+        download=False,
+        transform=transform,
+    )
+    print(len(dataset))
+    data, target = dataset[0]
+    print(data.shape)
+    print(target)