@@ -26,7 +26,7 @@ class USPSDataset0_6(Dataset):
2626 Args
2727 ----
2828 data_path : pathlib.Path
29- Path to the USPS dataset file .
29+ Path to the data directory .
3030 train : bool, optional
3131 Mode of the dataset.
3232 transform : callable, optional
@@ -60,18 +60,29 @@ class USPSDataset0_6(Dataset):
6060
6161 Examples
6262 --------
63+ >>> from torchvision import transforms
6364 >>> from src.datahandlers import USPSDataset0_6
64- >>> dataset = USPSDataset0_6(path="data/usps.h5", mode="train")
65+ >>> transform = transforms.Compose([
66+ ... transforms.Resize((16, 16)),
67+ ... transforms.ToTensor()
68+ ... ])
69+ >>> dataset = USPSDataset0_6(
70+ ... data_path="data",
71+ ... transform=transform
72+ ... download=True,
73+ ... train=True,
74+ ... )
6575 >>> len(dataset)
6676 5460
6777 >>> data, target = dataset[0]
6878 >>> data.shape
69- (16, 16)
79+ (1, 16, 16)
7080 >>> target
71- 6
81+ tensor([1., 0., 0., 0., 0., 0., 0.])
7282 """
7383
7484 filename = "usps.h5"
85+ num_classes = 7
7586
7687 def __init__ (
7788 self ,
@@ -85,7 +96,6 @@ def __init__(
8596 path = data_path if isinstance (data_path , Path ) else Path (data_path )
8697 self .filepath = path / self .filename
8798 self .transform = transform
88- self .num_classes = 7 # 0-6
8999 self .mode = "train" if train else "test"
90100
91101 # Download the dataset if it does not exist in a temporary directory
@@ -116,7 +126,24 @@ def _dataset_ok(self):
116126 return True
117127
118128 def download (self , url , filepath , checksum , mode ):
119- """Download the USPS dataset."""
129+ """Download the USPS dataset, and save it as an HDF5 file.
130+
131+ Args
132+ ----
133+ url : str
134+ URL to download the dataset from.
135+ filepath : pathlib.Path
136+ Path to save the downloaded dataset.
137+ checksum : str
138+ MD5 checksum of the downloaded file.
139+ mode : str
140+ Mode of the dataset, either train or test.
141+
142+ Raises
143+ ------
144+ ValueError
145+ If the checksum of the downloaded file does not match the expected checksum.
146+ """
120147
121148 def reporthook (blocknum , blocksize , totalsize ):
122149 """Report download progress."""
@@ -164,7 +191,20 @@ def reporthook(blocknum, blocksize, totalsize):
164191
165192 @staticmethod
166193 def check_integrity (filepath , checksum ):
167- """Check the integrity of the USPS dataset file."""
194+ """Check the integrity of the USPS dataset file.
195+
196+ Args
197+ ----
198+ filepath : pathlib.Path
199+ Path to the USPS dataset file.
200+ checksum : str
201+ MD5 checksum of the dataset file.
202+
203+ Returns
204+ -------
205+ bool
206+ True if the checksum of the file matches the expected checksum, False otherwise
207+ """
168208
169209 file_hash = hashlib .md5 (filepath .read_bytes ()).hexdigest ()
170210
0 commit comments