Skip to content

Commit 6fb5296

Browse files
authored
Merge pull request #77 from SFI-Visual-Intelligence/mag-branch
Changed markdown header
2 parents 4174cd4 + 686e443 commit 6fb5296

File tree

10 files changed

+109
-68
lines changed

10 files changed

+109
-68
lines changed

CollaborativeCoding/dataloaders/download.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
import h5py as h5
1010
import numpy as np
11+
from scipy.io import loadmat
12+
from torchvision.datasets import SVHN
1113

1214
from .datasources import MNIST_SOURCE, USPS_SOURCE
1315

@@ -84,7 +86,26 @@ def _get_labels(path: Path) -> np.ndarray:
8486
return train_labels, test_labels
8587

8688
def svhn(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
87-
raise NotImplementedError("SVHN download not implemented yet")
89+
def download_svhn(path, train: bool = True):
90+
SVHN()
91+
92+
parent_path = data_dir / "SVHN"
93+
94+
if not parent_path.exists():
95+
parent_path.mkdir(parents=True)
96+
97+
train_data = parent_path / "train_32x32.mat"
98+
test_data = parent_path / "test_32x32.mat"
99+
100+
if not train_data.exists():
101+
download_svhn(parent_path, train=True)
102+
if not test_data.exists():
103+
download_svhn(parent_path, train=False)
104+
105+
train_labels = loadmat(train_data)["y"]
106+
test_labels = loadmat(test_data)["y"]
107+
108+
return train_labels, test_labels
88109

89110
def usps(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
90111
"""

CollaborativeCoding/dataloaders/svhn.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from pathlib import Path
23

34
import h5py
45
import numpy as np
@@ -11,10 +12,10 @@
1112
class SVHNDataset(Dataset):
1213
def __init__(
1314
self,
14-
data_path: str,
15+
data_path: Path,
16+
sample_ids: list,
1517
train: bool,
1618
transform=None,
17-
download: bool = True,
1819
nr_channels=3,
1920
):
2021
"""
@@ -31,11 +32,9 @@ def __init__(
3132
super().__init__()
3233

3334
self.data_path = data_path
35+
self.indexes = sample_ids
3436
self.split = "train" if train else "test"
3537

36-
if download:
37-
self._download_data(data_path)
38-
3938
self.nr_channels = nr_channels
4039
self.transforms = transform
4140

CollaborativeCoding/load_data.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,20 +86,23 @@ def load_data(dataset: str, *args, **kwargs) -> tuple:
8686
sample_ids=train_samples,
8787
train=True,
8888
transform=transform,
89+
nr_channels=kwargs.get("nr_channels"),
8990
)
9091

9192
val = dataset(
9293
data_path=data_dir,
9394
sample_ids=val_samples,
9495
train=True,
9596
transform=transform,
97+
nr_channels=kwargs.get("nr_channels"),
9698
)
9799

98100
test = dataset(
99101
data_path=data_dir,
100102
sample_ids=test_samples,
101103
train=False,
102104
transform=transform,
105+
nr_channels=kwargs.get("nr_channels"),
103106
)
104107

105108
return train, val, test

CollaborativeCoding/load_metric.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def __call__(self, y_true, y_pred):
8282
for key in self.metrics:
8383
self.metrics[key](y_true, y_pred)
8484

85-
def __getmetrics__(self, str_prefix: str = None):
85+
def getmetrics(self, str_prefix: str = None):
8686
return_metrics = {}
8787
for key in self.metrics:
8888
if str_prefix is not None:
@@ -91,6 +91,6 @@ def __getmetrics__(self, str_prefix: str = None):
9191
return_metrics[key] = self.metrics[key].__returnmetric__()
9292
return return_metrics
9393

94-
def __resetmetrics__(self):
94+
def resetmetric(self):
9595
for key in self.metrics:
9696
self.metrics[key].__reset__()

CollaborativeCoding/metrics/EntropyPred.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
class EntropyPrediction(nn.Module):
8-
def __init__(self, averages: str = "mean"):
8+
def __init__(self, num_classes, macro_averaging=None):
99
"""
1010
Initializes the EntropyPrediction module, which calculates the Shannon Entropy
1111
of predicted logits and aggregates the results based on the specified method.
@@ -17,11 +17,8 @@ def __init__(self, averages: str = "mean"):
1717
"""
1818
super().__init__()
1919

20-
assert averages in ["mean", "sum", "none"], (
21-
"averages must be 'mean', 'sum', or 'none'"
22-
)
23-
self.averages = averages
2420
self.stored_entropy_values = []
21+
self.num_classes = num_classes
2522

2623
def __call__(self, y_true: th.Tensor, y_logits: th.Tensor):
2724
"""
@@ -36,6 +33,10 @@ def __call__(self, y_true: th.Tensor, y_logits: th.Tensor):
3633
"""
3734

3835
assert len(y_logits.size()) == 2, f"y_logits shape: {y_logits.size()}"
36+
assert y_logits.size(-1) == self.num_classes, (
37+
f"y_logit class length: {y_logits.size(-1)}, expected: {self.num_classes}"
38+
)
39+
3940
y_pred = nn.Softmax(dim=1)(y_logits)
4041
print(f"y_pred: {y_pred}")
4142
entropy_values = entropy(y_pred, axis=1)
@@ -50,13 +51,8 @@ def __call__(self, y_true: th.Tensor, y_logits: th.Tensor):
5051

5152
def __returnmetric__(self):
5253
stored_entropy_values = th.from_numpy(np.asarray(self.stored_entropy_values))
54+
stored_entropy_values = th.mean(stored_entropy_values)
5355

54-
if self.averages == "mean":
55-
stored_entropy_values = th.mean(stored_entropy_values)
56-
elif self.averages == "sum":
57-
stored_entropy_values = th.sum(stored_entropy_values)
58-
elif self.averages == "none":
59-
pass
6056
return stored_entropy_values
6157

6258
def __reset__(self):

doc/Magnus_page.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
Magnus Individual Task
22
======================
33

4-
# Magnus Størdal Individual Task
5-
64
## Task overview
75
In addition to the overall task, I was tasked with implementing a three-layer linear network, a dataset loader for the SVHN dataset, and an entropy metric.
86

main.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def main():
5050
data_dir=args.datafolder,
5151
transform=transform,
5252
val_size=args.val_size,
53+
nr_channels=args.nr_channels,
5354
)
5455

5556
train_metrics = MetricWrapper(
@@ -121,7 +122,7 @@ def main():
121122
train_metrics(y, logits)
122123

123124
break
124-
print(train_metrics.accumulate())
125+
print(train_metrics.getmetrics())
125126
print("Dry run completed successfully.")
126127
exit()
127128

@@ -169,11 +170,11 @@ def main():
169170
"Train loss": np.mean(trainingloss),
170171
"Validation loss": np.mean(valloss),
171172
}
172-
| train_metrics.__getmetrics__(str_prefix="Train ")
173-
| val_metrics.__getmetrics__(str_prefix="Validation ")
173+
| train_metrics.getmetric(str_prefix="Train ")
174+
| val_metrics.getmetric(str_prefix="Validation ")
174175
)
175-
train_metrics.__resetmetrics__()
176-
val_metrics.__resetmetrics__()
176+
train_metrics.resetmetric()
177+
val_metrics.resetmetric()
177178

178179
testloss = []
179180
model.eval()
@@ -189,9 +190,9 @@ def main():
189190

190191
wandb.log(
191192
{"Epoch": 1, "Test loss": np.mean(testloss)}
192-
| test_metrics.__getmetrics__(str_prefix="Test ")
193+
| test_metrics.getmetric(str_prefix="Test ")
193194
)
194-
test_metrics.__resetmetrics__()
195+
test_metrics.resetmetric()
195196

196197

197198
if __name__ == "__main__":

tests/test_metrics.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33
import pytest
44

55
from CollaborativeCoding.load_metric import MetricWrapper
6-
from CollaborativeCoding.metrics import Accuracy, F1Score, Precision, Recall
6+
from CollaborativeCoding.metrics import (
7+
Accuracy,
8+
EntropyPrediction,
9+
F1Score,
10+
Precision,
11+
Recall,
12+
)
713

814

915
@pytest.mark.parametrize(
@@ -17,7 +23,7 @@
1723
("accuracy", randint(2, 10), True),
1824
("precision", randint(2, 10), False),
1925
("precision", randint(2, 10), True),
20-
# TODO: Add test for EntropyPrediction
26+
("entropy", randint(2, 10), False),
2127
],
2228
)
2329
def test_metric_wrapper(metric, num_classes, macro_averaging):
@@ -34,9 +40,9 @@ def test_metric_wrapper(metric, num_classes, macro_averaging):
3440
)
3541

3642
metrics(y_true, logits)
37-
score = metrics.accumulate()
38-
metrics.reset()
39-
empty_score = metrics.accumulate()
43+
score = metrics.getmetrics()
44+
metrics.resetmetric()
45+
empty_score = metrics.getmetrics()
4046

4147
assert isinstance(score, dict), "Expected a dictionary output."
4248
assert metric in score, f"Expected {metric} metric in the output."
@@ -145,16 +151,22 @@ def test_accuracy():
145151
def test_entropypred():
146152
import torch as th
147153

148-
pred_logits = th.rand(6, 5)
149154
true_lab = th.rand(6, 5)
150155

151-
metric = EntropyPrediction(averages="mean")
152-
metric2 = EntropyPrediction(averages="sum")
156+
metric = EntropyPrediction(num_classes=5)
153157

154-
# Test for averaging metric consistency
158+
# Test if the metric stores multiple values
159+
pred_logits = th.rand(6, 5)
155160
metric(true_lab, pred_logits)
156-
metric2(true_lab, pred_logits)
157-
assert (
158-
th.abs(th.sum(6 * metric.__returnmetric__() - metric2.__returnmetric__()))
159-
< 1e-5
160-
)
161+
162+
pred_logits = th.rand(6, 5)
163+
metric(true_lab, pred_logits)
164+
165+
pred_logits = th.rand(6, 5)
166+
metric(true_lab, pred_logits)
167+
168+
assert type(metric.__returnmetric__()) == th.Tensor
169+
170+
# Test that an error is raised when num_classes != class dimension length
171+
with pytest.raises(AssertionError):
172+
metric(true_lab, th.rand(6, 6))

tests/test_models.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
import pytest
22
import torch
33

4-
from CollaborativeCoding.models import ChristianModel, JanModel, MagnusModel
4+
from CollaborativeCoding.models import (
5+
ChristianModel,
6+
JanModel,
7+
JohanModel,
8+
MagnusModel,
9+
SolveigModel,
10+
)
511

612

713
@pytest.mark.parametrize(
@@ -49,15 +55,17 @@ def test_solveig_model(image_shape, num_classes):
4955
assert y.shape == (n, num_classes), f"Shape: {y.shape}"
5056

5157

52-
@pytest.mark.parametrize("image_shape", [(3, 28, 28)])
53-
def test_magnus_model(image_shape):
58+
@pytest.mark.parametrize(
59+
"image_shape, num_classes", [((3, 28, 28), 10), ((1, 16, 16), 10)]
60+
)
61+
def test_magnus_model(image_shape, num_classes):
5462
import torch as th
5563

5664
n, c, h, w = 5, *image_shape
57-
model = MagnusModel([h, w], 10, c)
65+
model = MagnusModel([h, w], num_classes, c)
5866

5967
x = th.rand((n, c, h, w))
6068
with th.no_grad():
6169
y = model(x)
6270

63-
assert y.shape == (n, 10), f"Shape: {y.shape}"
71+
assert y.shape == (n, num_classes), f"Shape: {y.shape}"

tests/test_wrappers.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,33 @@
1+
from pathlib import Path
2+
13
from CollaborativeCoding import load_data, load_metric, load_model
24

3-
# def test_load_model():
4-
# import torch as th
55

6-
# image_shape = (1, 16, 16)
7-
# num_classes = 4
6+
def test_load_model():
7+
import torch as th
88

9-
# dummy_img = th.rand((1, *image_shape))
9+
image_shape = (1, 16, 16)
10+
num_classes = 4
1011

11-
# modelnames = [
12-
# "magnusmodel",
13-
# "christianmodel",
14-
# "janmodel",
15-
# "solveigmodel",
16-
# "johanmodel",
17-
# ]
12+
dummy_img = th.rand((1, *image_shape))
1813

19-
# for name in modelnames:
20-
# print(name)
21-
# model = load_model(name, image_shape=image_shape, num_classes=num_classes)
14+
modelnames = [
15+
"magnusmodel",
16+
"christianmodel",
17+
"janmodel",
18+
"solveigmodel",
19+
"johanmodel",
20+
]
2221

23-
# with th.no_grad():
24-
# output = model(dummy_img)
25-
# assert output.size() == (1, 4), (
26-
# f"Model {name} returned image of size {output}. Expected (1,4)"
27-
# )
22+
for name in modelnames:
23+
print(name)
24+
model = load_model(name, image_shape=image_shape, num_classes=num_classes)
25+
26+
with th.no_grad():
27+
output = model(dummy_img)
28+
assert output.size() == (1, 4), (
29+
f"Model {name} returned image of size {output}. Expected (1,4)"
30+
)
2831

2932

3033
def test_load_data():
@@ -51,7 +54,7 @@ def test_load_data():
5154
with TemporaryDirectory() as tmppath:
5255
for name in dataset_names:
5356
dataset = load_data(
54-
name, train=False, data_path=tmppath, download=True, transform=trans
57+
name, train=False, data_dir=Path(tmppath), transform=trans
5558
)
5659

5760
im, _ = dataset.__getitem__(0)

0 commit comments

Comments
 (0)