Skip to content

Commit 365d389

Browse files
authored
Merge pull request #73 from SFI-Visual-Intelligence/solveig-branch
Solveig branch
2 parents d060ff5 + 6e0c345 commit 365d389

File tree

6 files changed

+27
-16
lines changed

6 files changed

+27
-16
lines changed

main.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,7 @@ def main():
3030

3131
device = args.device
3232

33-
3433
if "usps" in args.dataset.lower():
35-
3634
transform = transforms.Compose(
3735
[
3836
transforms.Resize((28, 28)),
@@ -47,7 +45,6 @@ def main():
4745
data_dir=args.datafolder,
4846
transform=transform,
4947
val_size=args.val_size,
50-
5148
)
5249

5350
train_metrics = MetricWrapper(
@@ -129,7 +126,6 @@ def main():
129126
project=args.run_name,
130127
tags=[args.modelname, args.dataset],
131128
config=args,
132-
133129
)
134130
wandb.watch(model)
135131

tests/test_models.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22
import torch
33

4-
from utils.models import ChristianModel, JanModel, MagnusModel
4+
from utils.models import ChristianModel, JanModel, MagnusModel, SolveigModel
55

66

77
@pytest.mark.parametrize(
@@ -34,6 +34,21 @@ def test_jan_model(image_shape, num_classes):
3434
assert y.shape == (n, num_classes), f"Shape: {y.shape}"
3535

3636

37+
@pytest.mark.parametrize(
38+
"image_shape, num_classes",
39+
[((3, 16, 16), 3), ((3, 16, 16), 7)],
40+
)
41+
def test_solveig_model(image_shape, num_classes):
42+
n, c, h, w = 5, *image_shape
43+
44+
model = SolveigModel(image_shape, num_classes)
45+
46+
x = torch.randn(n, c, h, w)
47+
y = model(x)
48+
49+
assert y.shape == (n, num_classes), f"Shape: {y.shape}"
50+
51+
3752
@pytest.mark.parametrize("image_shape", [(3, 28, 28)])
3853
def test_magnus_model(image_shape):
3954
import torch as th

utils/arg_parser.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ def get_args():
3333
help="Whether model should be saved or not.",
3434
)
3535

36-
3736
# Data/Model specific values
3837
parser.add_argument(
3938
"--modelname",
@@ -83,7 +82,6 @@ def get_args():
8382
"--macro_averaging",
8483
action="store_true",
8584
help="If the flag is included, the metrics will be calculated using macro averaging.",
86-
8785
)
8886

8987
# Training specific values

utils/dataloaders/svhn.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import os
22

3-
43
import h5py
54
import numpy as np
65
from PIL import Image
@@ -95,7 +94,6 @@ def __getitem__(self, index):
9594
img = Image.fromarray(h5f["images"][index])
9695

9796
if self.nr_channels == 1:
98-
9997
img = img.convert("L")
10098
if self.transforms is not None:
10199
img = self.transforms(img)

utils/dataloaders/uspsh5_7_9.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from pathlib import Path
2+
13
import h5py
24
import numpy as np
35
import torch
@@ -30,7 +32,7 @@ class USPSH5_Digit_7_9_Dataset(Dataset):
3032
A transform function to apply to the images.
3133
"""
3234

33-
def __init__(self, h5_path, mode, transform=None):
35+
def __init__(self, data_path, train=False, transform=None):
3436
super().__init__()
3537
"""
3638
Initializes the USPS dataset by loading images and labels from the given `.h5` file.
@@ -43,12 +45,13 @@ def __init__(self, h5_path, mode, transform=None):
4345
transform : callable, optional, default=None
4446
A transform function to apply on images.
4547
"""
46-
48+
self.filename = "usps.h5"
49+
path = data_path if isinstance(data_path, Path) else Path(data_path)
50+
self.filepath = path / self.filename
4751
self.transform = transform
48-
self.mode = mode
49-
self.h5_path = h5_path
52+
self.mode = "train" if train else "test"
5053
# Load the dataset from the HDF5 file
51-
with h5py.File(self.h5_path, "r") as hf:
54+
with h5py.File(self.filepath, "r") as hf:
5255
images = hf[self.mode]["data"][:]
5356
labels = hf[self.mode]["target"][:]
5457

@@ -105,8 +108,8 @@ def main():
105108

106109
# Load the dataset
107110
dataset = USPSH5_Digit_7_9_Dataset(
108-
h5_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git/usps.h5",
109-
mode="train",
111+
data_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git",
112+
train=False,
110113
transform=transform,
111114
)
112115
data_loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)

utils/metrics/F1.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def _macro_F1(self):
112112

113113
def forward(self, preds, target):
114114
"""
115+
115116
Update the True Positives, False Positives, and False Negatives, and compute the F1 score.
116117
117118
This method computes the F1 score based on the predictions and true labels. It can compute either the

0 commit comments

Comments
 (0)