Skip to content

Commit 203ace1

Browse files
committed
Updated my documentation
1 parent 80e89eb commit 203ace1

File tree

3 files changed

+192
-69
lines changed

3 files changed

+192
-69
lines changed

CollaborativeCoding/dataloaders/uspsh5_7_9.py

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,26 +8,45 @@
88

99
class USPSH5_Digit_7_9_Dataset(Dataset):
1010
"""
11-
Custom USPS dataset class that loads images with digits 7-9 from an .h5 file.
11+
This class loads a subset of the USPS dataset, specifically images of digits 7, 8, and 9, from an HDF5 file.
12+
It allows for applying transformations to the images and provides methods to retrieve images and their corresponding labels.
1213
1314
Parameters
1415
----------
15-
h5_path : str
16-
Path to the USPS `.h5` file.
16+
data_path : str or Path
17+
Path to the directory containing the USPS `.h5` file. This file should contain the data in the "train" or "test" group.
18+
19+
sample_ids : list of int
20+
A list of sample indices to be used from the dataset. This allows for filtering or selecting a subset of the full dataset.
21+
22+
train : bool, optional, default=False
23+
If `True`, the dataset is loaded in training mode (using the "train" group). If `False`, the dataset is loaded in test mode (using the "test" group).
1724
1825
transform : callable, optional, default=None
19-
A transform function to apply on images. If None, no transformation is applied.
26+
A transformation function to apply to each image. If `None`, no transformation is applied. Typically used for data augmentation or normalization.
27+
28+
nr_channels : int, optional, default=1
29+
The number of channels in the image. USPS images are typically grayscale, so this should generally be set to 1. This parameter allows for potential future flexibility.
2030
2131
Attributes
2232
----------
2333
images : numpy.ndarray
24-
The filtered images corresponding to digits 7-9.
34+
Array of images corresponding to digits 7, 8, and 9 from the USPS dataset. The images are loaded from the HDF5 file and filtered based on the labels.
2535
2636
labels : numpy.ndarray
27-
The filtered labels corresponding to digits 7-9.
37+
Array of labels corresponding to the images. Only labels of digits 7, 8, and 9 are retained, and they are mapped to 0, 1, and 2 for classification tasks.
2838
2939
transform : callable, optional
30-
A transform function to apply to the images.
40+
A transformation function to apply to the images. This is passed as an argument during initialization.
41+
42+
label_shift : function
43+
A function to shift the labels for classification purposes. It maps the original labels (7, 8, 9) to (0, 1, 2).
44+
45+
label_restore : function
46+
A function to restore the original labels (7, 8, 9) from the shifted labels (0, 1, 2).
47+
48+
num_classes : int
49+
The number of unique labels in the dataset, which is 3 (for digits 7, 8, and 9).
3150
"""
3251

3352
def __init__(
@@ -36,14 +55,25 @@ def __init__(
3655
super().__init__()
3756
"""
3857
Initializes the USPS dataset by loading images and labels from the given `.h5` file.
39-
58+
59+
The dataset is filtered to only include images of digits 7, 8, and 9, which are mapped to labels 0, 1, and 2 respectively for classification purposes.
60+
4061
Parameters
4162
----------
42-
h5_path : str
43-
Path to the USPS `.h5` file.
44-
63+
data_path : str or Path
64+
Path to the directory containing the USPS `.h5` file.
65+
66+
sample_ids : list of int
67+
List of sample indices to load from the dataset.
68+
69+
train : bool, optional, default=False
70+
If `True`, loads the training data from the HDF5 file. If `False`, loads the test data.
71+
4572
transform : callable, optional, default=None
46-
A transform function to apply on images.
73+
A function to apply transformations to the images. If None, no transformation is applied.
74+
75+
nr_channels : int, optional, default=1
76+
The number of channels in the image. Defaults to 1 for grayscale images.
4777
"""
4878
self.filename = "usps.h5"
4979
path = data_path if isinstance(data_path, Path) else Path(data_path)
@@ -72,27 +102,33 @@ def __len__(self):
72102
"""
73103
Returns the total number of samples in the dataset.
74104
105+
PyTorch samplers and the default `DataLoader` options rely on this method to determine the size of the dataset.
106+
75107
Returns
76108
-------
77109
int
78-
The number of images in the dataset.
110+
The number of images in the dataset (after filtering for digits 7, 8, and 9).
79111
"""
112+
80113
return len(self.images)
81114

82115
def __getitem__(self, id):
83116
"""
84117
Returns a sample from the dataset given an index.
85118
119+
This method is required for PyTorch's Dataset class, as it allows indexing into the dataset to retrieve specific samples.
120+
86121
Parameters
87122
----------
88123
id : int
89-
The index of the sample to retrieve.
124+
The index of the sample to retrieve from the dataset.
90125
91126
Returns
92127
-------
93128
tuple
94-
- image (PIL Image): The image at the specified index.
95-
- label (int): The label corresponding to the image.
129+
A tuple containing:
130+
- image (PIL Image): The image at the specified index.
131+
- label (int): The label corresponding to the image, shifted to be in the range [0, 2] for classification.
96132
"""
97133
# Convert to PIL Image (USPS images are typically grayscale 16x16)
98134
image = Image.fromarray(self.images[id].astype(np.uint8), mode="L")

CollaborativeCoding/metrics/F1.py

Lines changed: 63 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,67 @@
55

66
class F1Score(nn.Module):
77
"""
8-
F1 Score implementation with support for both macro and micro averaging.
9-
This class computes the F1 score during training using either macro or micro averaging.
8+
Computes the F1 score for classification tasks with support for both macro and micro averaging.
9+
10+
This class allows you to compute the F1 score during training or evaluation. You can select between two methods of averaging:
11+
- **Micro Averaging**: Computes the F1 score globally, treating each individual prediction as equally important.
12+
- **Macro Averaging**: Computes the F1 score for each class individually and then averages the scores.
13+
1014
Parameters
1115
----------
1216
num_classes : int
1317
The number of classes in the classification task.
1418
15-
macro_averaging : bool, default=False
16-
If True, computes the macro-averaged F1 score. If False, computes the micro-averaged F1 score.
19+
macro_averaging : bool, optional, default=False
20+
If True, computes the macro-averaged F1 score. If False, computes the micro-averaged F1 score. Default is micro averaging.
21+
22+
Attributes
23+
----------
24+
num_classes : int
25+
The number of classes in the classification task.
26+
27+
macro_averaging : bool
28+
A flag to determine whether to compute the macro-averaged or micro-averaged F1 score.
29+
30+
y_true : list
31+
A list of true-label tensors, one appended per `forward` call, accumulated until `__reset__` is called.
32+
33+
y_pred : list
34+
A list of predicted-label tensors accumulated across `forward` calls until `__reset__` is called.
35+
36+
Methods
37+
-------
38+
forward(target, preds)
39+
Stores predictions and true labels for computing the F1 score during training or evaluation.
40+
41+
compute_f1()
42+
Computes and returns the F1 score based on the stored predictions and true labels.
43+
44+
_micro_F1(target, preds)
45+
Computes the micro-averaged F1 score based on the global true positive, false positive, and false negative counts.
46+
47+
_macro_F1(target, preds)
48+
Computes the macro-averaged F1 score by calculating the F1 score per class and then averaging across all classes.
49+
50+
__returnmetric__()
51+
Computes and returns the F1 score (Micro or Macro) as specified.
52+
53+
__reset__()
54+
Resets the stored predictions and true labels, preparing for the next batch or epoch.
1755
"""
1856

1957
def __init__(self, num_classes, macro_averaging=False):
58+
"""
59+
Initializes the F1Score object with the number of classes and averaging mode.
60+
61+
Parameters
62+
----------
63+
num_classes : int
64+
The number of classes in the classification task.
65+
66+
macro_averaging : bool, optional, default=False
67+
If True, compute the macro-averaged F1 score. If False, compute the micro-averaged F1 score.
68+
"""
2069
super().__init__()
2170
self.num_classes = num_classes
2271
self.macro_averaging = macro_averaging
@@ -25,14 +74,15 @@ def __init__(self, num_classes, macro_averaging=False):
2574

2675
def forward(self, target, preds):
2776
"""
28-
Stores predictions and targets for computing the F1 score.
77+
Stores the true labels and predictions to compute the F1 score.
2978
3079
Parameters
3180
----------
32-
preds : torch.Tensor
33-
Predicted logits (shape: [batch_size, num_classes]).
3481
target : torch.Tensor
3582
True labels (shape: [batch_size]).
83+
84+
preds : torch.Tensor
85+
Predicted logits (shape: [batch_size, num_classes]).
3686
"""
3787
preds = torch.argmax(preds, dim=-1) # Convert logits to class indices
3888
self.y_true.append(target.detach())
@@ -47,7 +97,7 @@ def compute_f1(self):
4797
Returns
4898
-------
4999
torch.Tensor
50-
The computed F1 score.
100+
The computed F1 score. Returns NaN if no predictions or targets are available.
51101
"""
52102
if not self.y_true or not self.y_pred: # Check if empty
53103
return torch.tensor(np.nan)
@@ -63,7 +113,7 @@ def compute_f1(self):
63113
)
64114

65115
def _micro_F1(self, target, preds):
66-
"""Computes Micro F1 Score (global TP, FP, FN)."""
116+
"""Computes the Micro-averaged F1 score (global TP, FP, FN)."""
67117
tp = torch.sum(preds == target).float()
68118
fp = torch.sum(preds != target).float()
69119
fn = fp # Since all errors are either FP or FN
@@ -75,7 +125,7 @@ def _micro_F1(self, target, preds):
75125
return f1
76126

77127
def _macro_F1(self, target, preds):
78-
"""Computes Macro F1 Score in a vectorized way (no loops)."""
128+
"""Computes the Macro-averaged F1 score."""
79129
num_classes = self.num_classes
80130
target = target.long() # Ensure target is a LongTensor
81131
preds = preds.long()
@@ -100,12 +150,12 @@ def _macro_F1(self, target, preds):
100150

101151
def __returnmetric__(self):
102152
"""
103-
Computes and returns the F1 score (Micro or Macro).
153+
Computes and returns the F1 score (Micro or Macro) based on the stored predictions and targets.
104154
105155
Returns
106156
-------
107157
torch.Tensor
108-
The computed F1 score.
158+
The computed F1 score. Returns NaN if no predictions or targets are available.
109159
"""
110160
if not self.y_true or not self.y_pred: # Check if empty
111161
return torch.tensor(np.nan)
@@ -121,6 +171,6 @@ def __returnmetric__(self):
121171
)
122172

123173
def __reset__(self):
124-
"""Resets stored predictions and targets."""
174+
"""Resets the stored predictions and targets for the next batch or epoch."""
125175
self.y_true = []
126176
self.y_pred = []

0 commit comments

Comments
 (0)