Skip to content

Commit 71957eb

Browse files
committed
Merge branch 'main' of github.com:SFI-Visual-Intelligence/Collaborative-Coding-Exam into johan/devbranch
2 parents e3b103f + 4350664 commit 71957eb

File tree

13 files changed

+389
-12
lines changed

13 files changed

+389
-12
lines changed

main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def main():
6666
"--modelname",
6767
type=str,
6868
default="MagnusModel",
69-
choices=["MagnusModel"],
69+
choices=["MagnusModel", "ChristianModel"],
7070
help="Model which to be trained on",
7171
)
7272
parser.add_argument(
@@ -196,7 +196,7 @@ def main():
196196
model.eval()
197197
with th.no_grad():
198198
for x, y in valiloader:
199-
x = x.to(device)
199+
x, y = x.to(device), y.to(device)
200200
pred = model.forward(x)
201201
loss = criterion(y, pred)
202202
evalloss.append(loss.item())

tests/test_createfolders.py

Whitespace-only changes.

tests/test_dataloaders.py

Whitespace-only changes.

tests/test_metrics.py

Whitespace-only changes.

tests/test_models.py

Whitespace-only changes.

utils/dataloaders/usps_0_6.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,12 @@ def __getitem__(self, idx):
106106

107107
data = data.reshape(16, 16)
108108

109+
# one hot encode the target
110+
target = np.eye(self.num_classes, dtype=np.float32)[target]
111+
112+
# Add channel dimension
113+
data = np.expand_dims(data, axis=0)
114+
109115
if self.transform:
110116
data = self.transform(data)
111117

utils/dataloaders/uspsh5_7_9.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
from torch.utils.data import Dataset
2+
import numpy as np
3+
import h5py
4+
from torchvision import transforms
5+
from PIL import Image
6+
import torch
7+
8+
9+
class USPSH5_Digit_7_9_Dataset(Dataset):
10+
"""
11+
Custom USPS dataset class that loads images with digits 7-9 from an .h5 file.
12+
13+
Parameters
14+
----------
15+
h5_path : str
16+
Path to the USPS `.h5` file.
17+
18+
transform : callable, optional, default=None
19+
A transform function to apply on images. If None, no transformation is applied.
20+
21+
Attributes
22+
----------
23+
images : numpy.ndarray
24+
The filtered images corresponding to digits 7-9.
25+
26+
labels : numpy.ndarray
27+
The filtered labels corresponding to digits 7-9.
28+
29+
transform : callable, optional
30+
A transform function to apply to the images.
31+
"""
32+
33+
def __init__(self, h5_path, mode, transform=None):
34+
super().__init__()
35+
"""
36+
Initializes the USPS dataset by loading images and labels from the given `.h5` file.
37+
38+
Parameters
39+
----------
40+
h5_path : str
41+
Path to the USPS `.h5` file.
42+
43+
transform : callable, optional, default=None
44+
A transform function to apply on images.
45+
"""
46+
47+
self.transform = transform
48+
self.mode = mode
49+
self.h5_path = h5_path
50+
# Load the dataset from the HDF5 file
51+
with h5py.File(self.h5_path, "r") as hf:
52+
images = hf[self.mode]["data"][:]
53+
labels = hf[self.mode]["target"][:]
54+
55+
# Filter only digits 7, 8, and 9
56+
mask = np.isin(labels, [7, 8, 9])
57+
self.images = images[mask]
58+
self.labels = labels[mask]
59+
60+
def __len__(self):
61+
"""
62+
Returns the total number of samples in the dataset.
63+
64+
Returns
65+
-------
66+
int
67+
The number of images in the dataset.
68+
"""
69+
return len(self.images)
70+
71+
def __getitem__(self, id):
72+
"""
73+
Returns a sample from the dataset given an index.
74+
75+
Parameters
76+
----------
77+
idx : int
78+
The index of the sample to retrieve.
79+
80+
Returns
81+
-------
82+
tuple
83+
- image (PIL Image): The image at the specified index.
84+
- label (int): The label corresponding to the image.
85+
"""
86+
# Convert to PIL Image (USPS images are typically grayscale 16x16)
87+
image = Image.fromarray(self.images[id].astype(np.uint8), mode="L")
88+
label = int(self.labels[id]) # Convert label to integer
89+
90+
if self.transform:
91+
image = self.transform(image)
92+
93+
return image, label
94+
95+
96+
def main():
97+
# Example Usage:
98+
transform = transforms.Compose([
99+
transforms.Resize((16, 16)), # Ensure images are 16x16
100+
transforms.ToTensor(),
101+
transforms.Normalize((0.5,), (0.5,)) # Normalize to [-1, 1]
102+
])
103+
104+
# Load the dataset
105+
dataset = USPSH5_Digit_7_9_Dataset(h5_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git/usps.h5", mode="train", transform=transform)
106+
data_loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)
107+
batch = next(iter(data_loader)) # grab a batch from the dataloader
108+
img, label = batch
109+
print(img.shape)
110+
print(label.shape)
111+
112+
# Check dataset size
113+
print(f"Dataset size: {len(dataset)}")
114+
115+
if __name__ == '__main__':
116+
main()

utils/load_model.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import torch.nn as nn
22

3-
from .models import MagnusModel
3+
from .models import ChristianModel, MagnusModel
44

55

6-
def load_model(modelname: str) -> nn.Module:
7-
if modelname == "MagnusModel":
8-
return MagnusModel()
9-
else:
10-
raise ValueError(
11-
f"Model: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling"
12-
)
6+
def load_model(modelname: str, *args, **kwargs) -> nn.Module:
7+
match modelname.lower():
8+
case "magnusmodel":
9+
return MagnusModel(*args, **kwargs)
10+
case "christianmodel":
11+
return ChristianModel(*args, **kwargs)
12+
case _:
13+
raise ValueError(
14+
f"Model: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling"
15+
)

utils/metrics/F1.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
import torch.nn as nn
2+
import torch
3+
4+
5+
class F1Score(nn.Module):
6+
"""
7+
F1 Score implementation with direct averaging inside the compute method.
8+
9+
Parameters
10+
----------
11+
num_classes : int
12+
Number of classes.
13+
14+
Attributes
15+
----------
16+
num_classes : int
17+
The number of classes.
18+
19+
tp : torch.Tensor
20+
Tensor for True Positives (TP) for each class.
21+
22+
fp : torch.Tensor
23+
Tensor for False Positives (FP) for each class.
24+
25+
fn : torch.Tensor
26+
Tensor for False Negatives (FN) for each class.
27+
"""
28+
def __init__(self, num_classes):
29+
"""
30+
Initializes the F1Score object, setting up the necessary state variables.
31+
32+
Parameters
33+
----------
34+
num_classes : int
35+
The number of classes in the classification task.
36+
37+
"""
38+
39+
super().__init__()
40+
41+
self.num_classes = num_classes
42+
43+
# Initialize variables for True Positives (TP), False Positives (FP), and False Negatives (FN)
44+
self.tp = torch.zeros(num_classes)
45+
self.fp = torch.zeros(num_classes)
46+
self.fn = torch.zeros(num_classes)
47+
48+
def update(self, preds, target):
49+
"""
50+
Update the variables with predictions and true labels.
51+
52+
Parameters
53+
----------
54+
preds : torch.Tensor
55+
Predicted logits (shape: [batch_size, num_classes]).
56+
57+
target : torch.Tensor
58+
True labels (shape: [batch_size]).
59+
"""
60+
preds = torch.argmax(preds, dim=1)
61+
62+
# Calculate True Positives (TP), False Positives (FP), and False Negatives (FN) per class
63+
for i in range(self.num_classes):
64+
self.tp[i] += torch.sum((preds == i) & (target == i)).float()
65+
self.fp[i] += torch.sum((preds == i) & (target != i)).float()
66+
self.fn[i] += torch.sum((preds != i) & (target == i)).float()
67+
68+
def compute(self):
69+
"""
70+
Compute the F1 score.
71+
72+
Returns
73+
-------
74+
torch.Tensor
75+
The computed F1 score.
76+
"""
77+
78+
# Compute F1 score based on the specified averaging method
79+
f1_score = 2 * torch.sum(self.tp) / (2 * torch.sum(self.tp) + torch.sum(self.fp) + torch.sum(self.fn))
80+
81+
return f1_score
82+
83+
84+
def test_f1score():
85+
f1_metric = F1Score(num_classes=3)
86+
preds = torch.tensor([[0.8, 0.1, 0.1],
87+
[0.2, 0.7, 0.1],
88+
[0.2, 0.3, 0.5],
89+
[0.1, 0.2, 0.7]])
90+
91+
target = torch.tensor([0, 1, 0, 2])
92+
93+
f1_metric.update(preds, target)
94+
assert f1_metric.tp.sum().item() > 0, "Expected some true positives."
95+
assert f1_metric.fp.sum().item() > 0, "Expected some false positives."
96+
assert f1_metric.fn.sum().item() > 0, "Expected some false negatives."

utils/metrics/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1-
__all__ = ["EntropyPrediction"]
1+
__all__ = ["EntropyPrediction", "Recall"]
22

33
from .EntropyPred import EntropyPrediction
4+
from .recall import Recall

0 commit comments

Comments
 (0)