Skip to content

Commit 970fe05

Browse files
authored
Merge pull request #22 from SFI-Visual-Intelligence/solveig-branch
Added folders for our test, this resolves issue #17
2 parents 60abd72 + afeae2a commit 970fe05

File tree

6 files changed

+212
-0
lines changed

6 files changed

+212
-0
lines changed

tests/test_createfolders.py

Whitespace-only changes.

tests/test_dataloaders.py

Whitespace-only changes.

tests/test_metrics.py

Whitespace-only changes.

tests/test_models.py

Whitespace-only changes.

utils/dataloaders/uspsh5_7_9.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
from torch.utils.data import Dataset
2+
import numpy as np
3+
import h5py
4+
from torchvision import transforms
5+
from PIL import Image
6+
import torch
7+
8+
9+
class USPSH5_Digit_7_9_Dataset(Dataset):
10+
"""
11+
Custom USPS dataset class that loads images with digits 7-9 from an .h5 file.
12+
13+
Parameters
14+
----------
15+
h5_path : str
16+
Path to the USPS `.h5` file.
17+
18+
transform : callable, optional, default=None
19+
A transform function to apply on images. If None, no transformation is applied.
20+
21+
Attributes
22+
----------
23+
images : numpy.ndarray
24+
The filtered images corresponding to digits 7-9.
25+
26+
labels : numpy.ndarray
27+
The filtered labels corresponding to digits 7-9.
28+
29+
transform : callable, optional
30+
A transform function to apply to the images.
31+
"""
32+
33+
def __init__(self, h5_path, mode, transform=None):
34+
super().__init__()
35+
"""
36+
Initializes the USPS dataset by loading images and labels from the given `.h5` file.
37+
38+
Parameters
39+
----------
40+
h5_path : str
41+
Path to the USPS `.h5` file.
42+
43+
transform : callable, optional, default=None
44+
A transform function to apply on images.
45+
"""
46+
47+
self.transform = transform
48+
self.mode = mode
49+
self.h5_path = h5_path
50+
# Load the dataset from the HDF5 file
51+
with h5py.File(self.h5_path, "r") as hf:
52+
images = hf[self.mode]["data"][:]
53+
labels = hf[self.mode]["target"][:]
54+
55+
# Filter only digits 7, 8, and 9
56+
mask = np.isin(labels, [7, 8, 9])
57+
self.images = images[mask]
58+
self.labels = labels[mask]
59+
60+
def __len__(self):
61+
"""
62+
Returns the total number of samples in the dataset.
63+
64+
Returns
65+
-------
66+
int
67+
The number of images in the dataset.
68+
"""
69+
return len(self.images)
70+
71+
def __getitem__(self, id):
72+
"""
73+
Returns a sample from the dataset given an index.
74+
75+
Parameters
76+
----------
77+
idx : int
78+
The index of the sample to retrieve.
79+
80+
Returns
81+
-------
82+
tuple
83+
- image (PIL Image): The image at the specified index.
84+
- label (int): The label corresponding to the image.
85+
"""
86+
# Convert to PIL Image (USPS images are typically grayscale 16x16)
87+
image = Image.fromarray(self.images[id].astype(np.uint8), mode="L")
88+
label = int(self.labels[id]) # Convert label to integer
89+
90+
if self.transform:
91+
image = self.transform(image)
92+
93+
return image, label
94+
95+
96+
def main():
97+
# Example Usage:
98+
transform = transforms.Compose([
99+
transforms.Resize((16, 16)), # Ensure images are 16x16
100+
transforms.ToTensor(),
101+
transforms.Normalize((0.5,), (0.5,)) # Normalize to [-1, 1]
102+
])
103+
104+
# Load the dataset
105+
dataset = USPSH5_Digit_7_9_Dataset(h5_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git/usps.h5", mode="train", transform=transform)
106+
data_loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)
107+
batch = next(iter(data_loader)) # grab a batch from the dataloader
108+
img, label = batch
109+
print(img.shape)
110+
print(label.shape)
111+
112+
# Check dataset size
113+
print(f"Dataset size: {len(dataset)}")
114+
115+
if __name__ == '__main__':
116+
main()

utils/metrics/F1.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
import torch.nn as nn
2+
import torch
3+
4+
5+
class F1Score(nn.Module):
6+
"""
7+
F1 Score implementation with direct averaging inside the compute method.
8+
9+
Parameters
10+
----------
11+
num_classes : int
12+
Number of classes.
13+
14+
Attributes
15+
----------
16+
num_classes : int
17+
The number of classes.
18+
19+
tp : torch.Tensor
20+
Tensor for True Positives (TP) for each class.
21+
22+
fp : torch.Tensor
23+
Tensor for False Positives (FP) for each class.
24+
25+
fn : torch.Tensor
26+
Tensor for False Negatives (FN) for each class.
27+
"""
28+
def __init__(self, num_classes):
29+
"""
30+
Initializes the F1Score object, setting up the necessary state variables.
31+
32+
Parameters
33+
----------
34+
num_classes : int
35+
The number of classes in the classification task.
36+
37+
"""
38+
39+
super().__init__()
40+
41+
self.num_classes = num_classes
42+
43+
# Initialize variables for True Positives (TP), False Positives (FP), and False Negatives (FN)
44+
self.tp = torch.zeros(num_classes)
45+
self.fp = torch.zeros(num_classes)
46+
self.fn = torch.zeros(num_classes)
47+
48+
def update(self, preds, target):
49+
"""
50+
Update the variables with predictions and true labels.
51+
52+
Parameters
53+
----------
54+
preds : torch.Tensor
55+
Predicted logits (shape: [batch_size, num_classes]).
56+
57+
target : torch.Tensor
58+
True labels (shape: [batch_size]).
59+
"""
60+
preds = torch.argmax(preds, dim=1)
61+
62+
# Calculate True Positives (TP), False Positives (FP), and False Negatives (FN) per class
63+
for i in range(self.num_classes):
64+
self.tp[i] += torch.sum((preds == i) & (target == i)).float()
65+
self.fp[i] += torch.sum((preds == i) & (target != i)).float()
66+
self.fn[i] += torch.sum((preds != i) & (target == i)).float()
67+
68+
def compute(self):
69+
"""
70+
Compute the F1 score.
71+
72+
Returns
73+
-------
74+
torch.Tensor
75+
The computed F1 score.
76+
"""
77+
78+
# Compute F1 score based on the specified averaging method
79+
f1_score = 2 * torch.sum(self.tp) / (2 * torch.sum(self.tp) + torch.sum(self.fp) + torch.sum(self.fn))
80+
81+
return f1_score
82+
83+
84+
def test_f1score():
85+
f1_metric = F1Score(num_classes=3)
86+
preds = torch.tensor([[0.8, 0.1, 0.1],
87+
[0.2, 0.7, 0.1],
88+
[0.2, 0.3, 0.5],
89+
[0.1, 0.2, 0.7]])
90+
91+
target = torch.tensor([0, 1, 0, 2])
92+
93+
f1_metric.update(preds, target)
94+
assert f1_metric.tp.sum().item() > 0, "Expected some true positives."
95+
assert f1_metric.fp.sum().item() > 0, "Expected some false positives."
96+
assert f1_metric.fn.sum().item() > 0, "Expected some false negatives."

0 commit comments

Comments
 (0)