
Commit cc46a5e

Merge branch 'main' of github.com:SFI-Visual-Intelligence/Collaborative-Coding-Exam into johan/devbranch

2 parents: ee9f548 + 97d1cec

File tree

12 files changed: 481 additions, 117 deletions

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -12,6 +12,9 @@ doc/autoapi
 
 *.DS_Store
 
+# Jan
+job.yaml
+sync.sh
 
 #Magnus specific
 job*

CollaborativeCoding/dataloaders/download.py

Lines changed: 6 additions & 2 deletions
@@ -16,12 +16,12 @@
 
 class Downloader:
     """
-    Class to download and load the USPS dataset.
+    Class used to verify availability and potentially download implemented datasets.
 
     Methods
     -------
     mnist(data_dir: Path) -> tuple[np.ndarray, np.ndarray]
-        Download the MNIST dataset and save it as an HDF5 file to `data_dir`.
+        Checks the availability of mnist dataset. If not present downloads it into MNIST folder in `data_dir`.
     svhn(data_dir: Path) -> tuple[np.ndarray, np.ndarray]
         Download the SVHN dataset and save it as an HDF5 file to `data_dir`.
     usps(data_dir: Path) -> tuple[np.ndarray, np.ndarray]
@@ -42,6 +42,10 @@ class Downloader:
     """
 
     def mnist(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
+        """
+        Check the availability of mnist dataset. If not present downloads it into MNIST folder in `data_dir`.
+        """
+
         def _chech_is_downloaded(path: Path) -> bool:
             path = path / "MNIST"
             if path.exists():
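The updated docstrings describe a check-before-download flow for MNIST. Below is a minimal sketch of that pattern, using torchvision's MNIST downloader as a stand-in; the helper name and storage layout are illustrative assumptions, not the repository's implementation (which the older docstring suggests stores data as HDF5).

# Minimal sketch (not the committed Downloader) of the check-then-download
# pattern described above: fetch MNIST only if `data_dir / "MNIST"` is missing.
from pathlib import Path

from torchvision.datasets import MNIST  # stand-in downloader for illustration


def ensure_mnist(data_dir: Path) -> Path:
    """Return the MNIST folder, downloading the dataset first if it is absent."""
    mnist_dir = data_dir / "MNIST"
    if not mnist_dir.exists():
        # torchvision skips re-downloading files that already exist,
        # so repeating this call is harmless.
        MNIST(root=data_dir, train=True, download=True)
        MNIST(root=data_dir, train=False, download=True)
    return mnist_dir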

CollaborativeCoding/dataloaders/mnist_0_3.py

Lines changed: 6 additions & 6 deletions
@@ -10,20 +10,20 @@
 class MNISTDataset0_3(Dataset):
     """
     A custom Dataset class for loading a subset of the MNIST dataset containing digits 0 to 3.
-    Parameters
+
+    Args
     ----------
     data_path : Path
-        The root directory where the MNIST data is stored.
+        The root directory where the MNIST folder with data is stored.
     sample_ids : list
         A list of indices specifying which samples to load.
     train : bool, optional
         If True, load training data, otherwise load test data. Default is False.
     transform : callable, optional
         A function/transform to apply to the images. Default is None.
+
     Attributes
     ----------
-    data_path : Path
-        The root directory where the MNIST data is stored.
     mnist_path : Path
         The directory where the MNIST dataset is located within the root directory.
     idx : list
@@ -40,6 +40,7 @@ class MNISTDataset0_3(Dataset):
         The path to the label file (train or test) based on the `train` flag.
     length : int
         The number of samples in the dataset.
+
     Methods
     -------
     __len__()
@@ -58,8 +59,7 @@ def __init__(
     ):
         super().__init__()
 
-        self.data_path = data_path
-        self.mnist_path = self.data_path / "MNIST"
+        self.mnist_path = data_path / "MNIST"
         self.idx = sample_ids
         self.train = train
         self.transform = transform
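A hedged usage sketch for the dataset class documented in this hunk; the keyword arguments follow the docstring above, while the data path, sample indices, and transform are placeholders.

# Illustrative usage of MNISTDataset0_3 (paths and indices are placeholders).
from pathlib import Path

from torchvision import transforms

from CollaborativeCoding.dataloaders.mnist_0_3 import MNISTDataset0_3

dataset = MNISTDataset0_3(
    data_path=Path("data"),           # expects data/MNIST to exist
    sample_ids=list(range(100)),      # indices of the samples to load
    train=True,
    transform=transforms.ToTensor(),
)
print(len(dataset))                   # number of selected samples
image, label = dataset[0]             # label is a digit in {0, 1, 2, 3}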

CollaborativeCoding/dataloaders/uspsh5_7_9.py

Lines changed: 52 additions & 16 deletions
@@ -8,26 +8,45 @@
 
 class USPSH5_Digit_7_9_Dataset(Dataset):
     """
-    Custom USPS dataset class that loads images with digits 7-9 from an .h5 file.
+    This class loads a subset of the USPS dataset, specifically images of digits 7, 8, and 9, from an HDF5 file.
+    It allows for applying transformations to the images and provides methods to retrieve images and their corresponding labels.
 
     Parameters
     ----------
-    h5_path : str
-        Path to the USPS `.h5` file.
+    data_path : str or Path
+        Path to the directory containing the USPS `.h5` file. This file should contain the data in the "train" or "test" group.
+
+    sample_ids : list of int
+        A list of sample indices to be used from the dataset. This allows for filtering or selecting a subset of the full dataset.
+
+    train : bool, optional, default=False
+        If `True`, the dataset is loaded in training mode (using the "train" group). If `False`, the dataset is loaded in test mode (using the "test" group).
 
     transform : callable, optional, default=None
-        A transform function to apply on images. If None, no transformation is applied.
+        A transformation function to apply to each image. If `None`, no transformation is applied. Typically used for data augmentation or normalization.
+
+    nr_channels : int, optional, default=1
+        The number of channels in the image. USPS images are typically grayscale, so this should generally be set to 1. This parameter allows for potential future flexibility.
 
     Attributes
     ----------
     images : numpy.ndarray
-        The filtered images corresponding to digits 7-9.
+        Array of images corresponding to digits 7, 8, and 9 from the USPS dataset. The images are loaded from the HDF5 file and filtered based on the labels.
 
     labels : numpy.ndarray
-        The filtered labels corresponding to digits 7-9.
+        Array of labels corresponding to the images. Only labels of digits 7, 8, and 9 are retained, and they are mapped to 0, 1, and 2 for classification tasks.
 
     transform : callable, optional
-        A transform function to apply to the images.
+        A transformation function to apply to the images. This is passed as an argument during initialization.
+
+    label_shift : function
+        A function to shift the labels for classification purposes. It maps the original labels (7, 8, 9) to (0, 1, 2).
+
+    label_restore : function
+        A function to restore the original labels (7, 8, 9) from the shifted labels (0, 1, 2).
+
+    num_classes : int
+        The number of unique labels in the dataset, which is 3 (for digits 7, 8, and 9).
     """
 
     def __init__(
@@ -36,14 +55,25 @@ def __init__(
         super().__init__()
         """
         Initializes the USPS dataset by loading images and labels from the given `.h5` file.
-
+
+        The dataset is filtered to only include images of digits 7, 8, and 9, which are mapped to labels 0, 1, and 2 respectively for classification purposes.
+
         Parameters
         ----------
-        h5_path : str
-            Path to the USPS `.h5` file.
-
+        data_path : str or Path
+            Path to the directory containing the USPS `.h5` file.
+
+        sample_ids : list of int
+            List of sample indices to load from the dataset.
+
+        train : bool, optional, default=False
+            If `True`, loads the training data from the HDF5 file. If `False`, loads the test data.
+
         transform : callable, optional, default=None
-            A transform function to apply on images.
+            A function to apply transformations to the images. If None, no transformation is applied.
+
+        nr_channels : int, optional, default=1
+            The number of channels in the image. Defaults to 1 for grayscale images.
         """
         self.filename = "usps.h5"
         path = data_path if isinstance(data_path, Path) else Path(data_path)
@@ -72,27 +102,33 @@ def __len__(self):
         """
         Returns the total number of samples in the dataset.
 
+        This method is required for PyTorch's Dataset class, as it allows PyTorch to determine the size of the dataset.
+
         Returns
         -------
         int
-            The number of images in the dataset.
+            The number of images in the dataset (after filtering for digits 7, 8, and 9).
         """
+
         return len(self.images)
 
     def __getitem__(self, id):
         """
         Returns a sample from the dataset given an index.
 
+        This method is required for PyTorch's Dataset class, as it allows indexing into the dataset to retrieve specific samples.
+
        Parameters
        ----------
        idx : int
-            The index of the sample to retrieve.
+            The index of the sample to retrieve from the dataset.
 
        Returns
        -------
        tuple
-            - image (PIL Image): The image at the specified index.
-            - label (int): The label corresponding to the image.
+            A tuple containing:
+            - image (PIL Image): The image at the specified index.
+            - label (int): The label corresponding to the image, shifted to be in the range [0, 2] for classification.
        """
        # Convert to PIL Image (USPS images are typically grayscale 16x16)
        image = Image.fromarray(self.images[id].astype(np.uint8), mode="L")
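The new `label_shift` and `label_restore` attributes map the USPS digits 7, 8, 9 to contiguous class indices 0, 1, 2 and back. An illustrative sketch of that mapping (not the committed implementation):

# Digits 7, 8, 9 are shifted to class indices 0, 1, 2 for training,
# and can be restored to the original digit labels afterwards.
def label_shift(label: int) -> int:
    return label - 7      # 7, 8, 9 -> 0, 1, 2


def label_restore(label: int) -> int:
    return label + 7      # 0, 1, 2 -> 7, 8, 9


assert [label_shift(y) for y in (7, 8, 9)] == [0, 1, 2]
assert [label_restore(label_shift(y)) for y in (7, 8, 9)] == [7, 8, 9]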

CollaborativeCoding/load_data.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ def filter_labels(samples: list, wanted_labels: list) -> list:
 
 def load_data(dataset: str, *args, **kwargs) -> tuple:
     """
-    load the dataset based on the dataset name.
+    Load the dataset based on the dataset name.
 
     Args
     ----

CollaborativeCoding/load_metric.py

Lines changed: 8 additions & 2 deletions
@@ -9,26 +9,32 @@ class MetricWrapper(nn.Module):
     This class allows you to compute several metrics simultaneously on given
     true and predicted labels. It supports a variety of common metrics and
     provides methods to accumulate results and reset the state.
+
     Args
     ----
     num_classes : int
         The number of classes in the classification task.
     metrics : list[str]
         A list of metric names to be evaluated.
+    macro_averaging : bool
+        Whether to compute macro-averaged metrics for multi-class classification.
+
     Attributes
     ----------
     metrics : dict
         A dictionary mapping metric names to their corresponding functions.
     num_classes : int
         The number of classes for the classification task.
+
     Methods
     -------
     __call__(y_true, y_pred)
-        Computes the specified metrics on the provided true and predicted labels.
+        Passes the true and predicted labels to the metric functions.
     getmetrics(str_prefix: str = None)
-        Retrieves the computed metrics, optionally prefixed with a string.
+        Retrieves the dictionary of computed metrics, optionally all keys can be prefixed with a string.
     resetmetric()
         Resets the state of all metric computations.
+
     Examples
     --------
     >>> from CollaborativeCoding import MetricWrapperProposed
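A hedged usage sketch for `MetricWrapper`, following the Args and Methods sections above; the metric names, module path, and exact call signatures are assumptions made for illustration, not the repository's confirmed API.

# Illustrative use of MetricWrapper; metric names ("f1", "accuracy") are assumed.
import torch

from CollaborativeCoding.load_metric import MetricWrapper

metrics = MetricWrapper(num_classes=4, metrics=["f1", "accuracy"], macro_averaging=True)

y_true = torch.tensor([0, 1, 2, 3])
logits = torch.randn(4, 4)                    # one row of logits per sample

metrics(y_true, logits)                       # accumulate one batch
print(metrics.getmetrics(str_prefix="val_"))  # e.g. {"val_f1": ..., "val_accuracy": ...}
metrics.resetmetric()                         # clear state before the next epoch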

CollaborativeCoding/metrics/F1.py

Lines changed: 63 additions & 13 deletions
@@ -5,18 +5,67 @@
 
 class F1Score(nn.Module):
     """
-    F1 Score implementation with support for both macro and micro averaging.
-    This class computes the F1 score during training using either macro or micro averaging.
+    Computes the F1 score for classification tasks with support for both macro and micro averaging.
+
+    This class allows you to compute the F1 score during training or evaluation. You can select between two methods of averaging:
+    - **Micro Averaging**: Computes the F1 score globally, treating each individual prediction as equally important.
+    - **Macro Averaging**: Computes the F1 score for each class individually and then averages the scores.
+
     Parameters
     ----------
     num_classes : int
         The number of classes in the classification task.
 
-    macro_averaging : bool, default=False
-        If True, computes the macro-averaged F1 score. If False, computes the micro-averaged F1 score.
+    macro_averaging : bool, optional, default=False
+        If True, computes the macro-averaged F1 score. If False, computes the micro-averaged F1 score. Default is micro averaging.
+
+    Attributes
+    ----------
+    num_classes : int
+        The number of classes in the classification task.
+
+    macro_averaging : bool
+        A flag to determine whether to compute the macro-averaged or micro-averaged F1 score.
+
+    y_true : list
+        A list to store true labels for the current batch.
+
+    y_pred : list
+        A list to store predicted labels for the current batch.
+
+    Methods
+    -------
+    forward(target, preds)
+        Stores predictions and true labels for computing the F1 score during training or evaluation.
+
+    compute_f1()
+        Computes and returns the F1 score based on the stored predictions and true labels.
+
+    _micro_F1(target, preds)
+        Computes the micro-averaged F1 score based on the global true positive, false positive, and false negative counts.
+
+    _macro_F1(target, preds)
+        Computes the macro-averaged F1 score by calculating the F1 score per class and then averaging across all classes.
+
+    __returnmetric__()
+        Computes and returns the F1 score (Micro or Macro) as specified.
+
+    __reset__()
+        Resets the stored predictions and true labels, preparing for the next batch or epoch.
     """
 
     def __init__(self, num_classes, macro_averaging=False):
+        """
+        Initializes the F1Score object with the number of classes and averaging mode.
+
+        Parameters
+        ----------
+        num_classes : int
+            The number of classes in the classification task.
+
+        macro_averaging : bool, optional, default=False
+            If True, compute the macro-averaged F1 score. If False, compute the micro-averaged F1 score.
+        """
         super().__init__()
         self.num_classes = num_classes
         self.macro_averaging = macro_averaging
@@ -25,14 +74,15 @@ def __init__(self, num_classes, macro_averaging=False):
 
     def forward(self, target, preds):
         """
-        Stores predictions and targets for computing the F1 score.
+        Stores the true labels and predictions to compute the F1 score.
 
         Parameters
         ----------
-        preds : torch.Tensor
-            Predicted logits (shape: [batch_size, num_classes]).
         target : torch.Tensor
             True labels (shape: [batch_size]).
+
+        preds : torch.Tensor
+            Predicted logits (shape: [batch_size, num_classes]).
         """
         preds = torch.argmax(preds, dim=-1)  # Convert logits to class indices
         self.y_true.append(target.detach())
@@ -47,7 +97,7 @@ def compute_f1(self):
         Returns
         -------
         torch.Tensor
-            The computed F1 score.
+            The computed F1 score. Returns NaN if no predictions or targets are available.
         """
         if not self.y_true or not self.y_pred:  # Check if empty
             return torch.tensor(np.nan)
@@ -63,7 +113,7 @@
         )
 
     def _micro_F1(self, target, preds):
-        """Computes Micro F1 Score (global TP, FP, FN)."""
+        """Computes the Micro-averaged F1 score (global TP, FP, FN)."""
         tp = torch.sum(preds == target).float()
         fp = torch.sum(preds != target).float()
         fn = fp  # Since all errors are either FP or FN
@@ -75,7 +125,7 @@ def _micro_F1(self, target, preds):
         return f1
 
     def _macro_F1(self, target, preds):
-        """Computes Macro F1 Score in a vectorized way (no loops)."""
+        """Computes the Macro-averaged F1 score."""
         num_classes = self.num_classes
         target = target.long()  # Ensure target is a LongTensor
         preds = preds.long()
@@ -100,12 +150,12 @@
 
     def __returnmetric__(self):
         """
-        Computes and returns the F1 score (Micro or Macro).
+        Computes and returns the F1 score (Micro or Macro) based on the stored predictions and targets.
 
         Returns
        -------
        torch.Tensor
-            The computed F1 score.
+            The computed F1 score. Returns NaN if no predictions or targets are available.
        """
        if not self.y_true or not self.y_pred:  # Check if empty
            return torch.tensor(np.nan)
@@ -121,6 +171,6 @@ def __returnmetric__(self):
        )
 
    def __reset__(self):
-        """Resets stored predictions and targets."""
+        """Resets the stored predictions and targets for the next batch or epoch."""
        self.y_true = []
        self.y_pred = []
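The class docstring distinguishes micro and macro averaging. A self-contained toy illustration (not the committed implementation) of how the two differ on a small multiclass batch:

# Micro vs. macro F1 on a toy batch of 6 predictions over 3 classes.
import torch

target = torch.tensor([0, 0, 1, 2, 2, 2])
preds = torch.tensor([0, 1, 1, 2, 2, 0])
num_classes = 3

# Micro averaging pools TP/FP/FN over all samples. In single-label
# multiclass classification every error counts as both a FP and a FN,
# so micro-F1 reduces to plain accuracy.
micro_f1 = (preds == target).float().mean()

# Macro averaging computes F1 per class, then takes the unweighted mean.
per_class_f1 = []
for c in range(num_classes):
    tp = ((preds == c) & (target == c)).sum().float()
    fp = ((preds == c) & (target != c)).sum().float()
    fn = ((preds != c) & (target == c)).sum().float()
    per_class_f1.append(2 * tp / (2 * tp + fp + fn + 1e-8))
macro_f1 = torch.stack(per_class_f1).mean()

print(round(micro_f1.item(), 4), round(macro_f1.item(), 4))  # ~0.6667 vs ~0.6556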
