Skip to content

Commit 1a402ba

Browse files
authored
Merge branch 'main' into Jan-dev
2 parents 87753d5 + 6fb5296 commit 1a402ba

File tree

10 files changed

+96
-66
lines changed

10 files changed

+96
-66
lines changed

CollaborativeCoding/dataloaders/download.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
import h5py as h5
1010
import numpy as np
11+
from scipy.io import loadmat
12+
from torchvision.datasets import SVHN
1113

1214
from .datasources import MNIST_SOURCE, USPS_SOURCE
1315

@@ -84,7 +86,26 @@ def _get_labels(path: Path) -> np.ndarray:
8486
return train_labels, test_labels
8587

8688
def svhn(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
87-
raise NotImplementedError("SVHN download not implemented yet")
89+
def download_svhn(path, train: bool = True):
90+
SVHN()
91+
92+
parent_path = data_dir / "SVHN"
93+
94+
if not parent_path.exists():
95+
parent_path.mkdir(parents=True)
96+
97+
train_data = parent_path / "train_32x32.mat"
98+
test_data = parent_path / "test_32x32.mat"
99+
100+
if not train_data.exists():
101+
download_svhn(parent_path, train=True)
102+
if not test_data.exists():
103+
download_svhn(parent_path, train=False)
104+
105+
train_labels = loadmat(train_data)["y"]
106+
test_labels = loadmat(test_data)["y"]
107+
108+
return train_labels, test_labels
88109

89110
def usps(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
90111
"""

CollaborativeCoding/dataloaders/svhn.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from pathlib import Path
23

34
import h5py
45
import numpy as np
@@ -11,10 +12,10 @@
1112
class SVHNDataset(Dataset):
1213
def __init__(
1314
self,
14-
data_path: str,
15+
data_path: Path,
16+
sample_ids: list,
1517
train: bool,
1618
transform=None,
17-
download: bool = True,
1819
nr_channels=3,
1920
):
2021
"""
@@ -31,11 +32,9 @@ def __init__(
3132
super().__init__()
3233

3334
self.data_path = data_path
35+
self.indexes = sample_ids
3436
self.split = "train" if train else "test"
3537

36-
if download:
37-
self._download_data(data_path)
38-
3938
self.nr_channels = nr_channels
4039
self.transforms = transform
4140

CollaborativeCoding/load_data.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,20 +86,23 @@ def load_data(dataset: str, *args, **kwargs) -> tuple:
8686
sample_ids=train_samples,
8787
train=True,
8888
transform=transform,
89+
nr_channels=kwargs.get("nr_channels"),
8990
)
9091

9192
val = dataset(
9293
data_path=data_dir,
9394
sample_ids=val_samples,
9495
train=True,
9596
transform=transform,
97+
nr_channels=kwargs.get("nr_channels"),
9698
)
9799

98100
test = dataset(
99101
data_path=data_dir,
100102
sample_ids=test_samples,
101103
train=False,
102104
transform=transform,
105+
nr_channels=kwargs.get("nr_channels"),
103106
)
104107

105108
return train, val, test

CollaborativeCoding/load_metric.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def __call__(self, y_true, y_pred):
8282
for key in self.metrics:
8383
self.metrics[key](y_true, y_pred)
8484

85-
def __getmetrics__(self, str_prefix: str = None):
85+
def getmetrics(self, str_prefix: str = None):
8686
return_metrics = {}
8787
for key in self.metrics:
8888
if str_prefix is not None:
@@ -91,6 +91,6 @@ def __getmetrics__(self, str_prefix: str = None):
9191
return_metrics[key] = self.metrics[key].__returnmetric__()
9292
return return_metrics
9393

94-
def __resetmetrics__(self):
94+
def resetmetric(self):
9595
for key in self.metrics:
9696
self.metrics[key].__reset__()

CollaborativeCoding/metrics/EntropyPred.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
class EntropyPrediction(nn.Module):
8-
def __init__(self, averages: str = "mean"):
8+
def __init__(self, num_classes, macro_averaging=None):
99
"""
1010
Initializes the EntropyPrediction module, which calculates the Shannon Entropy
1111
of predicted logits and aggregates the results based on the specified method.
@@ -17,11 +17,8 @@ def __init__(self, averages: str = "mean"):
1717
"""
1818
super().__init__()
1919

20-
assert averages in ["mean", "sum", "none"], (
21-
"averages must be 'mean', 'sum', or 'none'"
22-
)
23-
self.averages = averages
2420
self.stored_entropy_values = []
21+
self.num_classes = num_classes
2522

2623
def __call__(self, y_true: th.Tensor, y_logits: th.Tensor):
2724
"""
@@ -36,6 +33,10 @@ def __call__(self, y_true: th.Tensor, y_logits: th.Tensor):
3633
"""
3734

3835
assert len(y_logits.size()) == 2, f"y_logits shape: {y_logits.size()}"
36+
assert y_logits.size(-1) == self.num_classes, (
37+
f"y_logit class length: {y_logits.size(-1)}, expected: {self.num_classes}"
38+
)
39+
3940
y_pred = nn.Softmax(dim=1)(y_logits)
4041
print(f"y_pred: {y_pred}")
4142
entropy_values = entropy(y_pred, axis=1)
@@ -50,13 +51,8 @@ def __call__(self, y_true: th.Tensor, y_logits: th.Tensor):
5051

5152
def __returnmetric__(self):
5253
stored_entropy_values = th.from_numpy(np.asarray(self.stored_entropy_values))
54+
stored_entropy_values = th.mean(stored_entropy_values)
5355

54-
if self.averages == "mean":
55-
stored_entropy_values = th.mean(stored_entropy_values)
56-
elif self.averages == "sum":
57-
stored_entropy_values = th.sum(stored_entropy_values)
58-
elif self.averages == "none":
59-
pass
6056
return stored_entropy_values
6157

6258
def __reset__(self):

doc/Magnus_page.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
Magnus Individual Task
22
======================
33

4-
# Magnus Størdal Individual Task
5-
64
## Task overview
75
In addition to the overall task, I was tasked to implement a three layer linear network, a dataset loader for the SVHN dataset, and a entropy metric.
86

main.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def main():
5353
data_dir=args.datafolder,
5454
transform=transform,
5555
val_size=args.val_size,
56+
nr_channels=args.nr_channels,
5657
)
5758

5859
train_metrics = MetricWrapper(
@@ -124,7 +125,7 @@ def main():
124125
train_metrics(y, logits)
125126

126127
break
127-
print(train_metrics.__getmetrics__())
128+
print(train_metrics.getmetrics())
128129
print("Dry run completed successfully.")
129130
exit()
130131

@@ -172,11 +173,11 @@ def main():
172173
"Train loss": np.mean(trainingloss),
173174
"Validation loss": np.mean(valloss),
174175
}
175-
| train_metrics.__getmetrics__(str_prefix="Train ")
176-
| val_metrics.__getmetrics__(str_prefix="Validation ")
176+
| train_metrics.getmetric(str_prefix="Train ")
177+
| val_metrics.getmetric(str_prefix="Validation ")
177178
)
178-
train_metrics.__resetmetrics__()
179-
val_metrics.__resetmetrics__()
179+
train_metrics.resetmetric()
180+
val_metrics.resetmetric()
180181

181182
testloss = []
182183
model.eval()
@@ -192,9 +193,9 @@ def main():
192193

193194
wandb.log(
194195
{"Epoch": 1, "Test loss": np.mean(testloss)}
195-
| test_metrics.__getmetrics__(str_prefix="Test ")
196+
| test_metrics.getmetric(str_prefix="Test ")
196197
)
197-
test_metrics.__resetmetrics__()
198+
test_metrics.resetmetric()
198199

199200

200201
if __name__ == "__main__":

tests/test_metrics.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
("accuracy", randint(2, 10), True),
2424
("precision", randint(2, 10), False),
2525
("precision", randint(2, 10), True),
26-
# TODO: Add test for EntropyPrediction
26+
("entropy", randint(2, 10), False),
2727
],
2828
)
2929
def test_metric_wrapper(metric, num_classes, macro_averaging):
@@ -40,9 +40,9 @@ def test_metric_wrapper(metric, num_classes, macro_averaging):
4040
)
4141

4242
metrics(y_true, logits)
43-
score = metrics.__getmetrics__()
44-
metrics.__resetmetrics__()
45-
empty_score = metrics.__getmetrics__()
43+
score = metrics.getmetrics()
44+
metrics.resetmetric()
45+
empty_score = metrics.getmetrics()
4646

4747
assert isinstance(score, dict), "Expected a dictionary output."
4848
assert metric in score, f"Expected {metric} metric in the output."
@@ -169,16 +169,22 @@ def test_accuracy():
169169
def test_entropypred():
170170
import torch as th
171171

172-
pred_logits = th.rand(6, 5)
173172
true_lab = th.rand(6, 5)
174173

175-
metric = EntropyPrediction(averages="mean")
176-
metric2 = EntropyPrediction(averages="sum")
174+
metric = EntropyPrediction(num_classes=5)
177175

178-
# Test for averaging metric consistency
176+
# Test if the metric stores multiple values
177+
pred_logits = th.rand(6, 5)
179178
metric(true_lab, pred_logits)
180-
metric2(true_lab, pred_logits)
181-
assert (
182-
th.abs(th.sum(6 * metric.__returnmetric__() - metric2.__returnmetric__()))
183-
< 1e-5
184-
)
179+
180+
pred_logits = th.rand(6, 5)
181+
metric(true_lab, pred_logits)
182+
183+
pred_logits = th.rand(6, 5)
184+
metric(true_lab, pred_logits)
185+
186+
assert type(metric.__returnmetric__()) == th.Tensor
187+
188+
# Test than an error is raised with num_class != class dimension length
189+
with pytest.raises(AssertionError):
190+
metric(true_lab, th.rand(6, 6))

tests/test_models.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from CollaborativeCoding.models import (
55
ChristianModel,
66
JanModel,
7+
JohanModel,
78
MagnusModel,
89
SolveigModel,
910
)
@@ -54,15 +55,17 @@ def test_solveig_model(image_shape, num_classes):
5455
assert y.shape == (n, num_classes), f"Shape: {y.shape}"
5556

5657

57-
@pytest.mark.parametrize("image_shape", [(3, 28, 28)])
58-
def test_magnus_model(image_shape):
58+
@pytest.mark.parametrize(
59+
"image_shape, num_classes", [((3, 28, 28), 10), ((1, 16, 16), 10)]
60+
)
61+
def test_magnus_model(image_shape, num_classes):
5962
import torch as th
6063

6164
n, c, h, w = 5, *image_shape
62-
model = MagnusModel([h, w], 10, c)
65+
model = MagnusModel([h, w], num_classes, c)
6366

6467
x = th.rand((n, c, h, w))
6568
with th.no_grad():
6669
y = model(x)
6770

68-
assert y.shape == (n, 10), f"Shape: {y.shape}"
71+
assert y.shape == (n, num_classes), f"Shape: {y.shape}"

tests/test_wrappers.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,33 @@
1+
from pathlib import Path
2+
13
from CollaborativeCoding import load_data, load_metric, load_model
24

3-
# def test_load_model():
4-
# import torch as th
55

6-
# image_shape = (1, 16, 16)
7-
# num_classes = 4
6+
def test_load_model():
7+
import torch as th
88

9-
# dummy_img = th.rand((1, *image_shape))
9+
image_shape = (1, 16, 16)
10+
num_classes = 4
1011

11-
# modelnames = [
12-
# "magnusmodel",
13-
# "christianmodel",
14-
# "janmodel",
15-
# "solveigmodel",
16-
# "johanmodel",
17-
# ]
12+
dummy_img = th.rand((1, *image_shape))
1813

19-
# for name in modelnames:
20-
# print(name)
21-
# model = load_model(name, image_shape=image_shape, num_classes=num_classes)
14+
modelnames = [
15+
"magnusmodel",
16+
"christianmodel",
17+
"janmodel",
18+
"solveigmodel",
19+
"johanmodel",
20+
]
2221

23-
# with th.no_grad():
24-
# output = model(dummy_img)
25-
# assert output.size() == (1, 4), (
26-
# f"Model {name} returned image of size {output}. Expected (1,4)"
27-
# )
22+
for name in modelnames:
23+
print(name)
24+
model = load_model(name, image_shape=image_shape, num_classes=num_classes)
25+
26+
with th.no_grad():
27+
output = model(dummy_img)
28+
assert output.size() == (1, 4), (
29+
f"Model {name} returned image of size {output}. Expected (1,4)"
30+
)
2831

2932

3033
def test_load_data():
@@ -51,7 +54,7 @@ def test_load_data():
5154
with TemporaryDirectory() as tmppath:
5255
for name in dataset_names:
5356
dataset = load_data(
54-
name, train=False, data_path=tmppath, download=True, transform=trans
57+
name, train=False, data_dir=Path(tmppath), transform=trans
5558
)
5659

5760
im, _ = dataset.__getitem__(0)

0 commit comments

Comments
 (0)