
Commit 7854122

Five more tests to go
2 parents: bfb895c + e9fa533

10 files changed: +99 −84 lines changed


CollaborativeCoding/dataloaders/usps_0_6.py

Lines changed: 11 additions & 2 deletions
@@ -83,6 +83,7 @@ def __init__(
         sample_ids: list,
         train: bool = False,
         transform=None,
+        nr_channels=1,
     ):
         super().__init__()

@@ -91,6 +92,7 @@ def __init__(
         self.transform = transform
         self.mode = "train" if train else "test"
         self.sample_ids = sample_ids
+        self.nr_channels = nr_channels

     def __len__(self):
         return len(self.sample_ids)

@@ -100,11 +102,18 @@ def __getitem__(self, id):

         with h5.File(self.filepath, "r") as f:
             data = f[self.mode]["data"][index].astype(np.uint8)
-            label = f[self.mode]["target"][index]
+            label = int(f[self.mode]["target"][index])

-        data = Image.fromarray(data, mode="L")
+        if self.nr_channels == 1:
+            data = Image.fromarray(data, mode="L")
+        elif self.nr_channels == 3:
+            data = Image.fromarray(data, mode="RGB")
+        else:
+            raise ValueError("Invalid number of channels")

         if self.transform:
             data = self.transform(data)

+        # label = torch.tensor(label).long()
+
         return data, label
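
With the added nr_channels argument, the dataset can now yield either grayscale ("L") or three-channel ("RGB") PIL images, and the label is cast to a plain int. A minimal usage sketch under assumed values (the data path, sample ids, and transform below are placeholders, not taken from the repository):

import numpy as np
from torchvision import transforms

from CollaborativeCoding.dataloaders import USPSDataset0_6

transform = transforms.Compose([transforms.Resize((16, 16)), transforms.ToTensor()])

dataset = USPSDataset0_6(
    data_path="path/to/usps_dir",  # placeholder; must point at the expected .h5 data
    sample_ids=np.arange(10),      # placeholder subset of sample indices
    train=True,
    transform=transform,
    nr_channels=1,                 # 1 -> mode "L", 3 -> mode "RGB", anything else raises ValueError
)
image, label = dataset[0]          # label is a plain int after the int(...) cast added above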

CollaborativeCoding/dataloaders/uspsh5_7_9.py

Lines changed: 5 additions & 2 deletions
@@ -32,7 +32,9 @@ class USPSH5_Digit_7_9_Dataset(Dataset):
         A transform function to apply to the images.
     """

-    def __init__(self, data_path, sample_ids, train=False, transform=None, nr_channels=1):
+    def __init__(
+        self, data_path, sample_ids, train=False, transform=None, nr_channels=1
+    ):
        super().__init__()
        """
        Initializes the USPS dataset by loading images and labels from the given `.h5` file.

@@ -112,7 +114,8 @@ def main():
    indices = np.array([7, 8, 9])
    # Load the dataset
    dataset = USPSH5_Digit_7_9_Dataset(
-        data_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git", sample_ids=indices,
+        data_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git",
+        sample_ids=indices,
        train=False,
        transform=transform,
    )

CollaborativeCoding/load_data.py

Lines changed: 3 additions & 3 deletions
@@ -93,23 +93,23 @@ def load_data(dataset: str, *args, **kwargs) -> tuple:
        sample_ids=train_samples,
        train=True,
        transform=transform,
-        nr_channels=kwargs.get("nr_channels"),
+        nr_channels=kwargs.get("nr_channels", 1),
    )

    val = dataset(
        data_path=data_dir,
        sample_ids=val_samples,
        train=True,
        transform=transform,
-        nr_channels=kwargs.get("nr_channels"),
+        nr_channels=kwargs.get("nr_channels", 1),
    )

    test = dataset(
        data_path=data_dir,
        sample_ids=test_samples,
        train=False,
        transform=transform,
-        nr_channels=kwargs.get("nr_channels"),
+        nr_channels=kwargs.get("nr_channels", 1),
    )

    return train, val, test
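
The change from kwargs.get("nr_channels") to kwargs.get("nr_channels", 1) matters when the caller omits the argument: without a default, dict.get returns None, which would later trip the channel check in the datasets, whereas the new default falls back to single-channel grayscale. A quick illustration of the standard dict.get behaviour:

kwargs = {}                        # caller did not pass nr_channels
kwargs.get("nr_channels")          # -> None (old behaviour), rejected by the channel check
kwargs.get("nr_channels", 1)       # -> 1 (new behaviour), single-channel grayscale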

CollaborativeCoding/metrics/F1.py

Lines changed: 5 additions & 3 deletions
@@ -159,11 +159,13 @@ def __returnmetric__(self):
        else:
            self.y_true = torch.cat(self.y_true)
            self.y_pred = torch.cat(self.y_pred)
-            return self._micro_F1(self.y_true, self.y_pred) if not self.macro_averaging else self._macro_F1(self.y_true, self.y_pred)
+            return (
+                self._micro_F1(self.y_true, self.y_pred)
+                if not self.macro_averaging
+                else self._macro_F1(self.y_true, self.y_pred)
+            )

    def __reset__(self):
        self.y_true = []
        self.y_pred = []
        return None
-
-
CollaborativeCoding/metrics/recall.py

Lines changed: 33 additions & 9 deletions
@@ -1,3 +1,4 @@
+import numpy as np
 import torch
 import torch.nn as nn

@@ -57,26 +58,49 @@ def __init__(self, num_classes, macro_averaging=False):
        self.num_classes = num_classes
        self.macro_averaging = macro_averaging

+        self.__y_true = []
+        self.__y_pred = []
+
    def forward(self, true, logits):
        pred = logits.argmax(dim=-1)
        y_true = one_hot_encode(true, self.num_classes)
        y_pred = one_hot_encode(pred, self.num_classes)

+        self.__y_true.append(y_true)
+        self.__y_pred.append(y_pred)
+
+    def compute(self, y_true, y_pred):
        if self.macro_averaging:
-            recall = 0
-            for i in range(self.num_classes):
-                tp = (y_true[:, i] * y_pred[:, i]).sum()
-                fn = torch.sum(~y_pred[y_true[:, i].bool()].bool())
-                recall += tp / (tp + fn)
-            recall /= self.num_classes
-        else:
-            recall = self.__compute(y_true, y_pred)
+            return self.__compute_macro_averaging(y_true, y_pred)
+
+        return self.__compute_micro_averaging(y_true, y_pred)
+
+    def __compute_macro_averaging(self, y_true, y_pred):
+        recall = 0
+        for i in range(self.num_classes):
+            tp = (y_true[:, i] * y_pred[:, i]).sum()
+            fn = torch.sum(~y_pred[y_true[:, i].bool()].bool())
+            recall += tp / (tp + fn)
+        recall /= self.num_classes

        return recall

-    def __compute(self, y_true, y_pred):
+    def __compute_micro_averaging(self, y_true, y_pred):
        true_positives = (y_true * y_pred).sum()
        false_negatives = torch.sum(~y_pred[y_true.bool()].bool())

        recall = true_positives / (true_positives + false_negatives)
        return recall
+
+    def __returnmetric__(self):
+        if len(self.__y_true) == 0 and len(self.__y_pred) == 0:
+            return np.nan
+
+        y_true = torch.cat(self.__y_true, dim=0)
+        y_pred = torch.cat(self.__y_pred, dim=0)
+
+        return self.compute(y_true, y_pred)
+
+    def __reset__(self):
+        self.__y_true = []
+        self.__y_pred = []
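
Recall now follows the same accumulate-then-report pattern as the F1 metric: forward() only stores one-hot targets and predictions, and the score is computed once __returnmetric__() is asked for it. A minimal sketch of the intended cycle, with made-up batch tensors and an assumed import path:

import torch

from CollaborativeCoding.metrics import Recall  # assumed import path

recall = Recall(num_classes=7, macro_averaging=True)

for _ in range(3):                       # e.g. three validation batches
    y_true = torch.randint(0, 7, (32,))  # placeholder labels
    logits = torch.randn(32, 7)          # placeholder model outputs
    recall(y_true, logits)               # forward() only appends one-hot encodings

score = recall.__returnmetric__()        # concatenates the buffers and computes recall
recall.__reset__()                       # clear the buffers before the next epoch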

CollaborativeCoding/models/solveig_model.py

Lines changed: 18 additions & 18 deletions
@@ -4,24 +4,24 @@

 def find_fc_input_shape(image_shape, model):
     """
-Find the shape of the input to the fully connected layer after passing through the convolutional layers.
-
-Code inspired by @Seilmast (https://github.com/SFI-Visual-Intelligence/Collaborative-Coding-Exam/issues/67#issuecomment-2651212254)
-
-Args
-----
-image_shape : tuple(int, int, int)
-    Shape of the input image (C, H, W), where C is the number of channels,
-    H is the height, and W is the width of the image.
-model : nn.Module
-    The CNN model containing the convolutional layers, whose output size is used to
-    determine the number of input features for the fully connected layer.
-
-Returns
--------
-int
-    The number of elements in the input to the fully connected layer.
-"""
+    Find the shape of the input to the fully connected layer after passing through the convolutional layers.
+
+    Code inspired by @Seilmast (https://github.com/SFI-Visual-Intelligence/Collaborative-Coding-Exam/issues/67#issuecomment-2651212254)
+
+    Args
+    ----
+    image_shape : tuple(int, int, int)
+        Shape of the input image (C, H, W), where C is the number of channels,
+        H is the height, and W is the width of the image.
+    model : nn.Module
+        The CNN model containing the convolutional layers, whose output size is used to
+        determine the number of input features for the fully connected layer.
+
+    Returns
+    -------
+    int
+        The number of elements in the input to the fully connected layer.
+    """

    dummy_img = torch.randn(1, *image_shape)
    with torch.no_grad():
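
The docstring re-indentation does not change behaviour: find_fc_input_shape still pushes a dummy batch through the model's convolutional layers under torch.no_grad() and reports how many input features the first fully connected layer needs. A rough standalone sketch of that idea, with a hypothetical conv stack standing in for the real model (counting via numel() is an assumption, not quoted from the file):

import torch
import torch.nn as nn

# Hypothetical convolutional front-end; the actual model in solveig_model.py may differ.
conv_layers = nn.Sequential(nn.Conv2d(1, 8, kernel_size=3), nn.ReLU(), nn.MaxPool2d(2))

dummy_img = torch.randn(1, 1, 16, 16)  # one image with (C, H, W) = (1, 16, 16)
with torch.no_grad():
    out = conv_layers(dummy_img)

fc_input_features = out.numel()        # 8 * 7 * 7 = 392 elements feed the fully connected layer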

main.py

Lines changed: 1 addition & 2 deletions
@@ -1,11 +1,11 @@
 import numpy as np
 import torch as th
 import torch.nn as nn
+import wandb
 from torch.utils.data import DataLoader
 from torchvision import transforms
 from tqdm import tqdm

-import wandb
 from CollaborativeCoding import (
     MetricWrapper,
     createfolders,

@@ -17,7 +17,6 @@
 # from wandb_api import WANDB_API


-
 def main():
     """


tests/test_dataloaders.py

Lines changed: 0 additions & 40 deletions
@@ -3,7 +3,6 @@
 import numpy as np
 import pytest
 import torch
-from PIL import Image
 from torchvision import transforms

 from CollaborativeCoding.dataloaders import (

@@ -38,42 +37,3 @@ def test_load_data(data_name, expected):
    assert isinstance(
        dataset[0][1], (int, torch.Tensor, np.ndarray)
    )  # Should probably restrict this to only int or one-hot encoded tensor or array for consistency.
-
-
-def test_uspsdataset0_6():
-    from tempfile import TemporaryDirectory
-
-    import h5py
-    import numpy as np
-    from torchvision import transforms
-
-    # Create a temporary directory (deleted after the test)
-    with TemporaryDirectory() as tempdir:
-        tempdir = Path(tempdir)
-
-        tf = tempdir / "usps.h5"
-
-        # Create a h5 file
-        with h5py.File(tf, "w") as f:
-            targets = np.array([6, 5, 4, 3, 2, 1, 0, 0, 0, 0])
-            indices = np.arange(len(targets))
-            # Populate the file with data
-            f["train/data"] = np.random.rand(10, 16 * 16)
-            f["train/target"] = targets
-
-        trans = transforms.Compose(
-            [
-                transforms.Resize((16, 16)),
-                transforms.ToTensor(),
-            ]
-        )
-        dataset = USPSDataset0_6(
-            data_path=tempdir,
-            sample_ids=indices,
-            train=True,
-            transform=trans,
-        )
-        assert len(dataset) == 10
-        data, target = dataset[0]
-        assert data.shape == (1, 16, 16)
-        assert target == 6

tests/test_metrics.py

Lines changed: 11 additions & 4 deletions
@@ -59,8 +59,11 @@ def test_recall():
    recall_micro = Recall(7)
    recall_macro = Recall(7, macro_averaging=True)

-    recall_micro_score = recall_micro(y_true, logits)
-    recall_macro_score = recall_macro(y_true, logits)
+    recall_micro(y_true, logits)
+    recall_macro(y_true, logits)
+
+    recall_micro_score = recall_micro.__returnmetric__()
+    recall_macro_score = recall_macro.__returnmetric__()

    assert isinstance(recall_micro_score, torch.Tensor), "Expected a tensor output."
    assert isinstance(recall_macro_score, torch.Tensor), "Expected a tensor output."

@@ -88,8 +91,12 @@ def test_f1score():
    macro_f1_score = f1_macro.__returnmetric__()

    # Check if outputs are tensors
-    assert isinstance(micro_f1_score, torch.Tensor), "Micro F1 score should be a tensor."
-    assert isinstance(macro_f1_score, torch.Tensor), "Macro F1 score should be a tensor."
+    assert isinstance(micro_f1_score, torch.Tensor), (
+        "Micro F1 score should be a tensor."
+    )
+    assert isinstance(macro_f1_score, torch.Tensor), (
+        "Macro F1 score should be a tensor."
+    )

    # Check that F1 scores are between 0 and 1
    assert 0 <= micro_f1_score.item() <= 1, "Micro F1 score should be between 0 and 1."

tests/test_wrappers.py

Lines changed: 12 additions & 1 deletion
@@ -1,4 +1,9 @@
 from pathlib import Path
+from tempfile import TemporaryDirectory
+
+import pytest
+import torch
+from torchvision import transforms

 from CollaborativeCoding import MetricWrapper, load_data, load_model

@@ -36,7 +41,13 @@ def test_load_data():
    import torch as th
    from torchvision import transforms

-    dataset_names = ["usps_0-6", "mnist_0-3", "usps_7-9", "svhn", "mnist_4-9"]
+    dataset_names = [
+        "usps_0-6",
+        "mnist_0-3",
+        "usps_7-9",
+        "svhn",
+        # 'mnist_4-9' #Uncomment when implemented
+    ]

    trans = transforms.Compose(
        [
