Skip to content

Commit c4f6027

Browse files
authored
Merge pull request #111 from SFI-Visual-Intelligence/solveig
Fixed errors in USPS 7-9 and F1
2 parents 3da72bd + 2696958 commit c4f6027

File tree

2 files changed

+69
-152
lines changed

2 files changed

+69
-152
lines changed

CollaborativeCoding/dataloaders/uspsh5_7_9.py

Lines changed: 1 addition & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,8 @@
22

33
import h5py
44
import numpy as np
5-
import torch
65
from PIL import Image
76
from torch.utils.data import Dataset
8-
from torchvision import transforms
97

108

119
class USPSH5_Digit_7_9_Dataset(Dataset):
@@ -55,6 +53,7 @@ def __init__(
5553
self.h5_path = data_path / self.filename
5654
self.sample_ids = sample_ids
5755
self.nr_channels = nr_channels
56+
self.num_classes = 3
5857

5958
# Load the dataset from the HDF5 file
6059
with h5py.File(self.filepath, "r") as hf:
@@ -104,33 +103,3 @@ def __getitem__(self, id):
104103

105104
return image, label
106105

107-
108-
def main():
109-
# Example Usage:
110-
transform = transforms.Compose(
111-
[
112-
transforms.Resize((16, 16)), # Ensure images are 16x16
113-
transforms.ToTensor(),
114-
transforms.Normalize((0.5,), (0.5,)), # Normalize to [-1, 1]
115-
]
116-
)
117-
indices = np.array([7, 8, 9])
118-
# Load the dataset
119-
dataset = USPSH5_Digit_7_9_Dataset(
120-
data_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git",
121-
sample_ids=indices,
122-
train=False,
123-
transform=transform,
124-
)
125-
data_loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)
126-
batch = next(iter(data_loader)) # grab a batch from the dataloader
127-
img, label = batch
128-
print(img.shape)
129-
print(label.shape)
130-
131-
# Check dataset size
132-
print(f"Dataset size: {len(dataset)}")
133-
134-
135-
if __name__ == "__main__":
136-
main()

CollaborativeCoding/metrics/F1.py

Lines changed: 68 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -6,166 +6,114 @@
66
class F1Score(nn.Module):
77
"""
88
F1 Score implementation with support for both macro and micro averaging.
9-
109
This class computes the F1 score during training using either macro or micro averaging.
11-
The F1 score is calculated based on the true positives (TP), false positives (FP),
12-
and false negatives (FN) for each class.
13-
1410
Parameters
1511
----------
1612
num_classes : int
1713
The number of classes in the classification task.
1814
19-
macro_averaging : bool, optional, default=False
15+
macro_averaging : bool, default=False
2016
If True, computes the macro-averaged F1 score. If False, computes the micro-averaged F1 score.
21-
22-
Attributes
23-
----------
24-
num_classes : int
25-
The number of classes in the classification task.
26-
27-
tp : torch.Tensor
28-
Tensor storing the count of True Positives (TP) for each class.
29-
30-
fp : torch.Tensor
31-
Tensor storing the count of False Positives (FP) for each class.
32-
33-
fn : torch.Tensor
34-
Tensor storing the count of False Negatives (FN) for each class.
35-
36-
macro_averaging : bool
37-
A flag indicating whether to compute the macro-averaged F1 score or not.
3817
"""
3918

4019
def __init__(self, num_classes, macro_averaging=False):
41-
"""
42-
Initializes the F1Score object, setting up the necessary state variables.
43-
44-
Parameters
45-
----------
46-
num_classes : int
47-
The number of classes in the classification task.
48-
49-
macro_averaging : bool, optional, default=False
50-
If True, computes the macro-averaged F1 score. If False, computes the micro-averaged F1 score.
51-
"""
5220
super().__init__()
53-
5421
self.num_classes = num_classes
5522
self.macro_averaging = macro_averaging
5623
self.y_true = []
5724
self.y_pred = []
58-
# Initialize variables for True Positives (TP), False Positives (FP), and False Negatives (FN)
59-
self.tp = torch.zeros(num_classes)
60-
self.fp = torch.zeros(num_classes)
61-
self.fn = torch.zeros(num_classes)
6225

63-
def _micro_F1(self, target, preds):
26+
27+
def forward(self, target, preds):
6428
"""
65-
Compute the Micro F1 score by aggregating TP, FP, and FN across all classes.
29+
Stores predictions and targets for computing the F1 score.
6630
67-
Micro F1 score is calculated globally by considering all predictions together, regardless of class.
31+
Parameters
32+
----------
33+
preds : torch.Tensor
34+
Predicted logits (shape: [batch_size, num_classes]).
35+
target : torch.Tensor
36+
True labels (shape: [batch_size]).
37+
"""
38+
preds = torch.argmax(preds, dim=-1) # Convert logits to class indices
39+
self.y_true.append(target.detach())
40+
if preds.dim() == 0: # Scalar (e.g., argmax over a single unbatched logit vector)
41+
preds = preds.unsqueeze(0) # Add batch dimension
42+
self.y_pred.append(preds.detach())
43+
44+
def compute_f1(self):
45+
"""
46+
Computes the F1 score (Micro or Macro).
6847
6948
Returns
7049
-------
7150
torch.Tensor
72-
The micro-averaged F1 score.
51+
The computed F1 score.
7352
"""
74-
for i in range(self.num_classes):
75-
self.tp[i] += torch.sum((preds == i) & (target == i)).float()
76-
self.fp[i] += torch.sum((preds == i) & (target != i)).float()
77-
self.fn[i] += torch.sum((preds != i) & (target == i)).float()
53+
if not self.y_true or not self.y_pred: # Check if empty
54+
return torch.tensor(np.nan)
7855

79-
tp = torch.sum(self.tp)
80-
fp = torch.sum(self.fp)
81-
fn = torch.sum(self.fn)
56+
# Convert lists to tensors
57+
y_true = torch.cat(self.y_true)
58+
y_pred = torch.cat(self.y_pred)
59+
60+
return self._macro_F1(y_true, y_pred) if self.macro_averaging else self._micro_F1(y_true, y_pred)
61+
62+
def _micro_F1(self, target, preds):
63+
"""Computes Micro F1 Score (global TP, FP, FN)."""
64+
tp = torch.sum(preds == target).float()
65+
fp = torch.sum(preds != target).float()
66+
fn = fp # In single-label classification each error counts once as an FP (for the predicted class) and once as an FN (for the true class)
8267

83-
precision = tp / (tp + fp + 1e-8) # Avoid division by zero
84-
recall = tp / (tp + fn + 1e-8) # Avoid division by zero
68+
precision = tp / (tp + fp + 1e-8)
69+
recall = tp / (tp + fn + 1e-8)
70+
f1 = 2 * (precision * recall) / (precision + recall + 1e-8)
8571

86-
f1 = (
87-
2 * precision * recall / (precision + recall + 1e-8)
88-
) # Avoid division by zero
8972
return f1
9073

9174
def _macro_F1(self, target, preds):
92-
"""
93-
Compute the Macro F1 score by calculating the F1 score per class and averaging.
94-
95-
Macro F1 score is calculated as the average of per-class F1 scores. This approach treats all classes equally,
96-
regardless of their frequency.
75+
"""Computes Macro F1 Score in a vectorized way (no loops)."""
76+
num_classes = self.num_classes
77+
target = target.long() # Ensure target is a LongTensor
78+
preds = preds.long()
79+
# Create one-hot encodings of the true and predicted labels
80+
target_one_hot = torch.nn.functional.one_hot(target, num_classes=num_classes)
81+
preds_one_hot = torch.nn.functional.one_hot(preds, num_classes=num_classes)
9782

98-
Returns
99-
-------
100-
torch.Tensor
101-
The macro-averaged F1 score.
102-
"""
103-
# Calculate True Positives (TP), False Positives (FP), and False Negatives (FN) per class
104-
for i in range(self.num_classes):
105-
self.tp[i] += torch.sum((preds == i) & (target == i)).float()
106-
self.fp[i] += torch.sum((preds == i) & (target != i)).float()
107-
self.fn[i] += torch.sum((preds != i) & (target == i)).float()
108-
109-
precision_per_class = self.tp / (
110-
self.tp + self.fp + 1e-8
111-
) # Avoid division by zero
112-
recall_per_class = self.tp / (
113-
self.tp + self.fn + 1e-8
114-
) # Avoid division by zero
115-
f1_per_class = (
116-
2
117-
* precision_per_class
118-
* recall_per_class
119-
/ (precision_per_class + recall_per_class + 1e-8)
120-
) # Avoid division by zero
121-
122-
# Take the average of F1 scores across all classes
123-
f1_score = torch.mean(f1_per_class)
124-
return f1_score
125-
126-
def forward(self, preds, target):
127-
"""
83+
# Compute TP, FP, FN for each class
84+
tp = torch.sum(target_one_hot * preds_one_hot, dim=0).float()
85+
fp = torch.sum(preds_one_hot * (1 - target_one_hot), dim=0).float()
86+
fn = torch.sum(target_one_hot * (1 - preds_one_hot), dim=0).float()
12887

129-
Update the True Positives, False Positives, and False Negatives, and compute the F1 score.
88+
# Compute precision and recall per class
89+
precision = tp / (tp + fp + 1e-8)
90+
recall = tp / (tp + fn + 1e-8)
13091

131-
This method computes the F1 score based on the predictions and true labels. It can compute either the
132-
macro-averaged or micro-averaged F1 score, depending on the `macro_averaging` flag.
92+
# Compute per-class F1 score
93+
f1_per_class = 2 * (precision * recall) / (precision + recall + 1e-8)
13394

134-
Parameters
135-
----------
136-
preds : torch.Tensor
137-
Predicted logits or class indices (shape: [batch_size, num_classes]).
138-
These logits are typically the output of a softmax or sigmoid activation.
95+
# Compute Macro F1 (mean over all classes)
96+
return torch.mean(f1_per_class)
13997

140-
target : torch.Tensor
141-
True labels (shape: [batch_size]), where each element is an integer representing the true class.
98+
def __returnmetric__(self):
99+
"""
100+
Computes and returns the F1 score (Micro or Macro).
142101
143102
Returns
144103
-------
145104
torch.Tensor
146-
The computed F1 score (either micro or macro, based on `macro_averaging`).
105+
The computed F1 score.
147106
"""
148-
preds = torch.argmax(preds, dim=-1)
149-
self.y_true.append(target)
150-
self.y_pred.append(preds)
107+
if not self.y_true or not self.y_pred: # Check if empty
108+
return torch.tensor(np.nan)
151109

152-
def __returnmetric__(self):
153-
if self.y_true == [] or self.y_pred == []:
154-
return np.nan
155-
if isinstance(self.y_true, list):
156-
if len(self.y_true) == 1:
157-
self.y_true = self.y_true[0]
158-
self.y_pred = self.y_pred[0]
159-
else:
160-
self.y_true = torch.cat(self.y_true)
161-
self.y_pred = torch.cat(self.y_pred)
162-
return (
163-
self._micro_F1(self.y_true, self.y_pred)
164-
if not self.macro_averaging
165-
else self._macro_F1(self.y_true, self.y_pred)
166-
)
110+
# Convert lists to tensors
111+
y_true = torch.cat([t.unsqueeze(0) if t.dim() == 0 else t for t in self.y_true])
112+
y_pred = torch.cat([t.unsqueeze(0) if t.dim() == 0 else t for t in self.y_pred])
113+
114+
return self._macro_F1(y_true, y_pred) if self.macro_averaging else self._micro_F1(y_true, y_pred)
167115

168116
def __reset__(self):
117+
"""Resets stored predictions and targets."""
169118
self.y_true = []
170-
self.y_pred = []
171-
return None
119+
self.y_pred = []

0 commit comments

Comments
 (0)