44This module contains the Dataset class for the USPS dataset with labels 0-6.
55"""
66
7+ import bz2
8+ import hashlib
79from pathlib import Path
10+ from tempfile import TemporaryDirectory , TemporaryFile
11+ from urllib .request import urlretrieve
812
913import h5py as h5
1014import numpy as np
1115from torch .utils .data import Dataset
1216
17+ from .datasources import USPS_SOURCE
18+
1319
1420class USPSDataset0_6 (Dataset ):
1521 """
@@ -28,7 +34,7 @@ class USPSDataset0_6(Dataset):
2834
2935 Attributes
3036 ----------
31- path : pathlib.Path
37+ filepath : pathlib.Path
3238 Path to the USPS dataset file.
3339 mode : str
3440 Mode of the dataset, either train or test.
@@ -63,6 +69,8 @@ class USPSDataset0_6(Dataset):
6369 6
6470 """
6571
72+ filename = "usps.h5"
73+
6674 def __init__ (
6775 self ,
6876 data_path : Path ,
@@ -71,18 +79,78 @@ def __init__(
7179 download : bool = False ,
7280 ):
7381 super ().__init__ ()
74- self .path = data_path
82+
83+ path = data_path if isinstance (data_path , Path ) else Path (data_path )
84+ self .filepath = path / self .filename
7585 self .transform = transform
76- self .num_classes = 7
86+ self .num_classes = 7 # 0-6
87+ self .mode = "train" if train else "test"
7788
89+ # When requested, (re-)download this split. The compressed archive is
90+ # fetched into a temporary directory so it is cleaned up automatically
7891 if download :
79- raise NotImplementedError ("Download functionality not implemented." )
92+ url , _ , checksum = USPS_SOURCE [self .mode ]
93+
94+ print (f"Downloading USPS dataset ({ self .mode } )..." )
95+ self .download (url , self .filepath , checksum , self .mode )
8096
81- self .mode = "train" if train else "test"
8297 self .idx = self ._index ()
8398
def download(self, url, filepath, checksum, mode):
    """Download one split of the USPS dataset and store it as HDF5.

    The bz2-compressed archive is fetched into a temporary directory
    (removed automatically afterwards), verified against ``checksum``,
    parsed, and written to ``filepath`` as the datasets ``{mode}/data``
    and ``{mode}/target``. Other groups already present in ``filepath``
    are left untouched; an existing ``mode`` group is replaced.

    Parameters
    ----------
    url : str
        Location of the bz2-compressed dataset file.
    filepath : pathlib.Path or str
        Destination HDF5 file. Missing parent directories are created.
    checksum : str
        Expected MD5 hex digest of the downloaded archive.
    mode : str
        Name of the HDF5 group to write, e.g. ``"train"`` or ``"test"``.

    Raises
    ------
    ValueError
        If the digest of the downloaded archive differs from ``checksum``.
    """

    def reporthook(blocknum, blocksize, totalsize):
        # urlretrieve reports an unknown size as a non-positive value;
        # there is nothing meaningful to display in that case
        if totalsize <= 0:
            return
        denom = 1024 * 1024
        # The final block is usually only partially filled, so clamp the
        # running total to keep the display from overshooting 100 %
        readsofar = min(blocknum * blocksize, totalsize)
        percent = readsofar * 1e2 / totalsize
        s = (
            f"\r{int(percent):^3}% {readsofar / denom:.2f} "
            f"of {totalsize / denom:.2f} MB"
        )
        print(s, end="", flush=True)
        if readsofar >= totalsize:
            print()

    filepath = Path(filepath)

    with TemporaryDirectory() as tmpdir:
        tmpfile = Path(tmpdir) / "usps.bz2"
        urlretrieve(
            url,
            tmpfile,
            reporthook=reporthook,
        )

        # Refuse to parse a corrupted or unexpected download
        if not self.check_integrity(tmpfile, checksum):
            errmsg = (
                "The checksum of the downloaded file does "
                "not match the expected checksum."
            )
            raise ValueError(errmsg)

        # Each non-empty line is "<label> <index>:<value> <index>:<value> ...";
        # blank lines are skipped so they cannot break the label lookup below
        with bz2.open(tmpfile, "rt", encoding="utf-8") as fp:
            raw = [line.split() for line in fp if line.strip()]

    # Keep only the value part of every "<index>:<value>" token
    tmp = [[x.split(":")[-1] for x in data[1:]] for data in raw]

    # Rescale to 0-255 (assumes the raw values lie in [-1, 1] - TODO confirm)
    imgs = np.asarray(tmp, dtype=np.float32)
    imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8)

    # Labels are shifted down by one (presumably 1-based in the source file)
    targets = [int(d[0]) - 1 for d in raw]

    # Append mode keeps a previously downloaded split intact; only the
    # group belonging to this split is replaced
    filepath.parent.mkdir(parents=True, exist_ok=True)
    with h5.File(filepath, "a") as f:
        if mode in f:
            del f[mode]
        f.create_dataset(f"{mode}/data", data=imgs, dtype=np.float32)
        f.create_dataset(f"{mode}/target", data=targets, dtype=np.int32)
143+
144+ @staticmethod
145+ def check_integrity (filepath , checksum ):
146+ """Check the integrity of the USPS dataset file."""
147+
148+ file_hash = hashlib .md5 (filepath .read_bytes ()).hexdigest ()
149+
150+ return checksum == file_hash
151+
84152 def _index (self ):
85- with h5 .File (self .path , "r" ) as f :
153+ with h5 .File (self .filepath , "r" ) as f :
86154 labels = f [self .mode ]["target" ][:]
87155
88156 # Get indices of samples with labels 0-6
@@ -92,7 +160,7 @@ def _index(self):
92160 return idx
93161
94162 def _load_data (self , idx ):
95- with h5 .File (self .path , "r" ) as f :
163+ with h5 .File (self .filepath , "r" ) as f :
96164 data = f [self .mode ]["data" ][idx ]
97165 label = f [self .mode ]["target" ][idx ]
98166
@@ -116,3 +184,11 @@ def __getitem__(self, idx):
116184 data = self .transform (data )
117185
118186 return data , target
187+
188+
if __name__ == "__main__":
    # Manual smoke test: download the training split into ./data, then
    # print the dataset size and the shape and label of the first sample
    usps_train = USPSDataset0_6(data_path="data", train=True, download=True)
    print(len(usps_train))
    first_sample, first_label = usps_train[0]
    print(first_sample.shape)
    print(first_label)
0 commit comments