CollaborativeCoding/dataloaders/uspsh5_7_9.py
52 additions & 16 deletions
@@ -8,26 +8,45 @@
 
 class USPSH5_Digit_7_9_Dataset(Dataset):
     """
-    Custom USPS dataset class that loads images with digits 7-9 from an .h5 file.
+    This class loads a subset of the USPS dataset, specifically images of digits 7, 8, and 9, from an HDF5 file.
+    It allows for applying transformations to the images and provides methods to retrieve images and their corresponding labels.
 
     Parameters
     ----------
-    h5_path : str
-        Path to the USPS `.h5` file.
+    data_path : str or Path
+        Path to the directory containing the USPS `.h5` file. This file should contain the data in the "train" or "test" group.
+
+    sample_ids : list of int
+        A list of sample indices to be used from the dataset. This allows for filtering or selecting a subset of the full dataset.
+
+    train : bool, optional, default=False
+        If `True`, the dataset is loaded in training mode (using the "train" group). If `False`, the dataset is loaded in test mode (using the "test" group).
 
     transform : callable, optional, default=None
-        A transform function to apply on images. If None, no transformation is applied.
+        A transformation function to apply to each image. If `None`, no transformation is applied. Typically used for data augmentation or normalization.
+
+    nr_channels : int, optional, default=1
+        The number of channels in the image. USPS images are typically grayscale, so this should generally be set to 1. This parameter allows for potential future flexibility.
 
     Attributes
     ----------
     images : numpy.ndarray
-        The filtered images corresponding to digits 7-9.
+        Array of images corresponding to digits 7, 8, and 9 from the USPS dataset. The images are loaded from the HDF5 file and filtered based on the labels.
 
     labels : numpy.ndarray
-        The filtered labels corresponding to digits 7-9.
+        Array of labels corresponding to the images. Only labels of digits 7, 8, and 9 are retained, and they are mapped to 0, 1, and 2 for classification tasks.
 
     transform : callable, optional
-        A transform function to apply to the images.
+        A transformation function to apply to the images. This is passed as an argument during initialization.
+
+    label_shift : function
+        A function to shift the labels for classification purposes. It maps the original labels (7, 8, 9) to (0, 1, 2).
+
+    label_restore : function
+        A function to restore the original labels (7, 8, 9) from the shifted labels (0, 1, 2).
+
+    num_classes : int
+        The number of unique labels in the dataset, which is 3 (for digits 7, 8, and 9).
     """
 
     def __init__(
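The `label_shift`, `label_restore`, and `num_classes` attributes documented above could be sketched roughly as follows. This is only an illustration of the 7, 8, 9 to 0, 1, 2 mapping the docstring describes; the function bodies and the `ORIGINAL_DIGITS` constant are assumptions, since the implementation itself is not part of this diff.

```python
# Hypothetical stand-ins for the label_shift / label_restore attributes named
# in the docstring; the PR's real implementation is not visible in this diff.
ORIGINAL_DIGITS = (7, 8, 9)


def label_shift(label):
    """Map an original digit (7, 8, 9) to a class index (0, 1, 2)."""
    return ORIGINAL_DIGITS.index(label)


def label_restore(index):
    """Map a class index (0, 1, 2) back to the original digit (7, 8, 9)."""
    return ORIGINAL_DIGITS[index]


raw_labels = [7, 9, 8, 7]
shifted = [label_shift(y) for y in raw_labels]    # class indices for training
restored = [label_restore(y) for y in shifted]    # round trip back to digits
num_classes = len(set(shifted))                   # distinct class indices seen
```

Shifting to a zero-based range is what most classification losses expect, and the restore function lets predictions be reported as the original digits.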
@@ -36,14 +55,25 @@ def __init__(
         super().__init__()
         """
         Initializes the USPS dataset by loading images and labels from the given `.h5` file.
-
+
+        The dataset is filtered to only include images of digits 7, 8, and 9, which are mapped to labels 0, 1, and 2 respectively for classification purposes.
+
         Parameters
         ----------
-        h5_path : str
-            Path to the USPS `.h5` file.
-
+        data_path : str or Path
+            Path to the directory containing the USPS `.h5` file.
+
+        sample_ids : list of int
+            List of sample indices to load from the dataset.
+
+        train : bool, optional, default=False
+            If `True`, loads the training data from the HDF5 file. If `False`, loads the test data.
+
         transform : callable, optional, default=None
-            A transform function to apply on images.
+            A function to apply transformations to the images. If None, no transformation is applied.
+
+        nr_channels : int, optional, default=1
+            The number of channels in the image. Defaults to 1 for grayscale images.
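To make the documented constructor arguments concrete, here is a minimal sketch of how `train`, `sample_ids`, and `transform` might interact. It is not the PR's code: `ToySubsetDataset` is a hypothetical class, a nested dict stands in for the HDF5 file, the `"data"` and `"target"` key names are assumed, and so is the choice to drop any sample id whose label falls outside 7 to 9.

```python
class ToySubsetDataset:
    """Hypothetical stand-in mirroring the documented constructor arguments."""

    DIGITS = (7, 8, 9)

    def __init__(self, store, sample_ids, train=False, transform=None):
        # `train` picks the group, as the docstring describes for the HDF5 file.
        group = store["train" if train else "test"]
        # Keep only requested samples whose label is 7, 8, or 9 (assumed behaviour).
        kept = [i for i in sample_ids if group["target"][i] in self.DIGITS]
        self.images = [group["data"][i] for i in kept]
        # Shift labels 7, 8, 9 to 0, 1, 2.
        self.labels = [group["target"][i] - 7 for i in kept]
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]


# In-memory stand-in for the file; real images would be 2-D arrays.
store = {
    "train": {"data": [[1], [2], [3], [4]], "target": [7, 3, 9, 8]},
    "test": {"data": [[5], [6]], "target": [8, 1]},
}
ds = ToySubsetDataset(
    store,
    sample_ids=[0, 1, 2],
    train=True,
    transform=lambda img: [2 * p for p in img],
)
```

Applying the transform lazily in `__getitem__` keeps the stored images untouched, which is the usual pattern when the transform is random augmentation.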