1+ from torch .utils .data import Dataset
2+ import numpy as np
3+ import h5py
4+ from torchvision import transforms
5+ from PIL import Image
6+ import torch
7+
8+
9+ class USPSH5_Digit_7_9_Dataset (Dataset ):
10+ """
11+ Custom USPS dataset class that loads images with digits 7-9 from an .h5 file.
12+
13+ Parameters
14+ ----------
15+ h5_path : str
16+ Path to the USPS `.h5` file.
17+
18+ transform : callable, optional, default=None
19+ A transform function to apply on images. If None, no transformation is applied.
20+
21+ Attributes
22+ ----------
23+ images : numpy.ndarray
24+ The filtered images corresponding to digits 7-9.
25+
26+ labels : numpy.ndarray
27+ The filtered labels corresponding to digits 7-9.
28+
29+ transform : callable, optional
30+ A transform function to apply to the images.
31+ """
32+
33+ def __init__ (self , h5_path , mode , transform = None ):
34+ super ().__init__ ()
35+ """
36+ Initializes the USPS dataset by loading images and labels from the given `.h5` file.
37+
38+ Parameters
39+ ----------
40+ h5_path : str
41+ Path to the USPS `.h5` file.
42+
43+ transform : callable, optional, default=None
44+ A transform function to apply on images.
45+ """
46+
47+ self .transform = transform
48+ self .mode = mode
49+ self .h5_path = h5_path
50+ # Load the dataset from the HDF5 file
51+ with h5py .File (self .h5_path , "r" ) as hf :
52+ images = hf [self .mode ]["data" ][:]
53+ labels = hf [self .mode ]["target" ][:]
54+
55+ # Filter only digits 7, 8, and 9
56+ mask = np .isin (labels , [7 , 8 , 9 ])
57+ self .images = images [mask ]
58+ self .labels = labels [mask ]
59+
60+ def __len__ (self ):
61+ """
62+ Returns the total number of samples in the dataset.
63+
64+ Returns
65+ -------
66+ int
67+ The number of images in the dataset.
68+ """
69+ return len (self .images )
70+
71+ def __getitem__ (self , id ):
72+ """
73+ Returns a sample from the dataset given an index.
74+
75+ Parameters
76+ ----------
77+ idx : int
78+ The index of the sample to retrieve.
79+
80+ Returns
81+ -------
82+ tuple
83+ - image (PIL Image): The image at the specified index.
84+ - label (int): The label corresponding to the image.
85+ """
86+ # Convert to PIL Image (USPS images are typically grayscale 16x16)
87+ image = Image .fromarray (self .images [id ].astype (np .uint8 ), mode = "L" )
88+ label = int (self .labels [id ]) # Convert label to integer
89+
90+ if self .transform :
91+ image = self .transform (image )
92+
93+ return image , label
94+
95+
96+ def main ():
97+ # Example Usage:
98+ transform = transforms .Compose ([
99+ transforms .Resize ((16 , 16 )), # Ensure images are 16x16
100+ transforms .ToTensor (),
101+ transforms .Normalize ((0.5 ,), (0.5 ,)) # Normalize to [-1, 1]
102+ ])
103+
104+ # Load the dataset
105+ dataset = USPSH5_Digit_7_9_Dataset (h5_path = "C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git/usps.h5" , mode = "train" , transform = transform )
106+ data_loader = torch .utils .data .DataLoader (dataset , batch_size = 2 , shuffle = True )
107+ batch = next (iter (data_loader )) # grab a batch from the dataloader
108+ img , label = batch
109+ print (img .shape )
110+ print (label .shape )
111+
112+ # Check dataset size
113+ print (f"Dataset size: { len (dataset )} " )
114+
115+ if __name__ == '__main__' :
116+ main ()
0 commit comments