
Commit 686e443

Fixed and passed all my related functions

1 parent dd5c6c6 · commit 686e443

9 files changed: 82 additions (+), 45 deletions (−)


CollaborativeCoding/dataloaders/download.py

Lines changed: 22 additions & 1 deletion
@@ -8,6 +8,8 @@
 
 import h5py as h5
 import numpy as np
+from scipy.io import loadmat
+from torchvision.datasets import SVHN
 
 from .datasources import MNIST_SOURCE, USPS_SOURCE
 
@@ -84,7 +86,26 @@ def _get_labels(path: Path) -> np.ndarray:
         return train_labels, test_labels
 
     def svhn(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
-        raise NotImplementedError("SVHN download not implemented yet")
+        def download_svhn(path, train: bool = True):
+            SVHN(root=path, split="train" if train else "test", download=True)
+
+        parent_path = data_dir / "SVHN"
+
+        if not parent_path.exists():
+            parent_path.mkdir(parents=True)
+
+        train_data = parent_path / "train_32x32.mat"
+        test_data = parent_path / "test_32x32.mat"
+
+        if not train_data.exists():
+            download_svhn(parent_path, train=True)
+        if not test_data.exists():
+            download_svhn(parent_path, train=False)
+
+        train_labels = loadmat(train_data)["y"]
+        test_labels = loadmat(test_data)["y"]
+
+        return train_labels, test_labels
 
     def usps(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
         """

CollaborativeCoding/dataloaders/svhn.py

Lines changed: 4 additions & 5 deletions
@@ -1,4 +1,5 @@
 import os
+from pathlib import Path
 
 import h5py
 import numpy as np
@@ -11,10 +12,10 @@
 class SVHNDataset(Dataset):
     def __init__(
         self,
-        data_path: str,
+        data_path: Path,
+        sample_ids: list,
         train: bool,
         transform=None,
-        download: bool = True,
         nr_channels=3,
     ):
         """
@@ -31,11 +32,9 @@ def __init__(
         super().__init__()
 
         self.data_path = data_path
+        self.indexes = sample_ids
         self.split = "train" if train else "test"
 
-        if download:
-            self._download_data(data_path)
-
         self.nr_channels = nr_channels
         self.transforms = transform
 
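With the download logic moved into the shared downloader, the dataset is now built from an explicit list of sample ids. A rough usage sketch under that assumption (the path and ids below are illustrative, and the import path simply mirrors the file location):

```python
from pathlib import Path

from CollaborativeCoding.dataloaders.svhn import SVHNDataset

# Illustrative values: in practice load_data() derives sample_ids from its
# train/val/test split and points data_path at the downloaded SVHN folder.
data_path = Path("data/SVHN")
sample_ids = list(range(1000))

train_set = SVHNDataset(
    data_path=data_path,
    sample_ids=sample_ids,
    train=True,
    transform=None,
    nr_channels=3,
)
```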
CollaborativeCoding/load_data.py

Lines changed: 3 additions & 0 deletions
@@ -86,20 +86,23 @@ def load_data(dataset: str, *args, **kwargs) -> tuple:
         sample_ids=train_samples,
         train=True,
         transform=transform,
+        nr_channels=kwargs.get("nr_channels"),
     )
 
     val = dataset(
         data_path=data_dir,
         sample_ids=val_samples,
         train=True,
         transform=transform,
+        nr_channels=kwargs.get("nr_channels"),
     )
 
     test = dataset(
         data_path=data_dir,
         sample_ids=test_samples,
         train=False,
         transform=transform,
+        nr_channels=kwargs.get("nr_channels"),
     )
 
     return train, val, test
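One thing to keep in mind with this change: `kwargs.get("nr_channels")` returns `None` when the caller omits the argument, which then overrides the datasets' own `nr_channels=3` default. If that default should survive, a fallback value would be needed; a sketch of such a variant (not part of the commit):

```python
# Hypothetical variant: fall back to 3 channels when the caller does not
# pass nr_channels, instead of forwarding None to the dataset.
nr_channels = kwargs.get("nr_channels", 3)

train = dataset(
    data_path=data_dir,
    sample_ids=train_samples,
    train=True,
    transform=transform,
    nr_channels=nr_channels,
)
```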

CollaborativeCoding/load_metric.py

Lines changed: 2 additions & 2 deletions
@@ -82,7 +82,7 @@ def __call__(self, y_true, y_pred):
         for key in self.metrics:
             self.metrics[key](y_true, y_pred)
 
-    def __getmetrics__(self, str_prefix: str = None):
+    def getmetrics(self, str_prefix: str = None):
         return_metrics = {}
         for key in self.metrics:
             if str_prefix is not None:
@@ -91,6 +91,6 @@ def __getmetrics__(self, str_prefix: str = None):
             return_metrics[key] = self.metrics[key].__returnmetric__()
         return return_metrics
 
-    def __resetmetrics__(self):
+    def resetmetric(self):
         for key in self.metrics:
             self.metrics[key].__reset__()
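The renames drop the dunder-style names in favour of plain methods. A minimal sketch of the wrapper's public surface as used in `main.py` and the tests (the constructor signature is an assumption based on `tests/test_metrics.py`, not shown in this diff):

```python
import torch as th

from CollaborativeCoding.load_metric import MetricWrapper  # assumed import path

# Assumed construction, mirroring the test parametrization.
metrics = MetricWrapper("entropy", num_classes=5, macro_averaging=False)

y_true = th.randint(0, 5, (6,))
logits = th.rand(6, 5)

metrics(y_true, logits)                            # accumulate one batch
scores = metrics.getmetrics(str_prefix="Train ")   # e.g. {"Train entropy": tensor(...)}
metrics.resetmetric()                              # clear stored values
```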

CollaborativeCoding/metrics/EntropyPred.py

Lines changed: 7 additions & 11 deletions
@@ -5,7 +5,7 @@
 
 
 class EntropyPrediction(nn.Module):
-    def __init__(self, averages: str = "mean"):
+    def __init__(self, num_classes, macro_averaging=None):
         """
         Initializes the EntropyPrediction module, which calculates the Shannon Entropy
         of predicted logits and aggregates the results based on the specified method.
@@ -17,11 +17,8 @@ def __init__(self, averages: str = "mean"):
         """
         super().__init__()
 
-        assert averages in ["mean", "sum", "none"], (
-            "averages must be 'mean', 'sum', or 'none'"
-        )
-        self.averages = averages
         self.stored_entropy_values = []
+        self.num_classes = num_classes
 
     def __call__(self, y_true: th.Tensor, y_logits: th.Tensor):
         """
@@ -36,6 +33,10 @@ def __call__(self, y_true: th.Tensor, y_logits: th.Tensor):
         """
 
         assert len(y_logits.size()) == 2, f"y_logits shape: {y_logits.size()}"
+        assert y_logits.size(-1) == self.num_classes, (
+            f"y_logit class length: {y_logits.size(-1)}, expected: {self.num_classes}"
+        )
+
         y_pred = nn.Softmax(dim=1)(y_logits)
         print(f"y_pred: {y_pred}")
         entropy_values = entropy(y_pred, axis=1)
@@ -50,13 +51,8 @@ def __call__(self, y_true: th.Tensor, y_logits: th.Tensor):
 
     def __returnmetric__(self):
         stored_entropy_values = th.from_numpy(np.asarray(self.stored_entropy_values))
+        stored_entropy_values = th.mean(stored_entropy_values)
 
-        if self.averages == "mean":
-            stored_entropy_values = th.mean(stored_entropy_values)
-        elif self.averages == "sum":
-            stored_entropy_values = th.sum(stored_entropy_values)
-        elif self.averages == "none":
-            pass
         return stored_entropy_values
 
     def __reset__(self):
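With the `averages` option removed, `__returnmetric__` always reports the mean Shannon entropy over every stored batch. A self-contained sketch of that computation (assuming the module's `entropy` comes from `scipy.stats`, which is not shown in this hunk):

```python
import numpy as np
import torch as th
import torch.nn as nn
from scipy.stats import entropy  # assumed source of the entropy() used above

logits = th.rand(6, 5)                 # one batch: 6 samples, 5 classes
probs = nn.Softmax(dim=1)(logits)      # rows sum to 1

# Per-sample Shannon entropy in nats, then the mean over the batch,
# mirroring what the simplified __returnmetric__ now returns.
per_sample = entropy(probs.numpy(), axis=1)
mean_entropy = th.from_numpy(np.asarray(per_sample)).mean()
print(mean_entropy)
```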

main.py

Lines changed: 8 additions & 7 deletions
@@ -50,6 +50,7 @@ def main():
         data_dir=args.datafolder,
         transform=transform,
         val_size=args.val_size,
+        nr_channels=args.nr_channels,
     )
 
     train_metrics = MetricWrapper(
@@ -121,7 +122,7 @@ def main():
             train_metrics(y, logits)
 
            break
-        print(train_metrics.accumulate())
+        print(train_metrics.getmetrics())
        print("Dry run completed successfully.")
        exit()
 
@@ -169,11 +170,11 @@ def main():
                "Train loss": np.mean(trainingloss),
                "Validation loss": np.mean(valloss),
            }
-            | train_metrics.__getmetrics__(str_prefix="Train ")
-            | val_metrics.__getmetrics__(str_prefix="Validation ")
+            | train_metrics.getmetrics(str_prefix="Train ")
+            | val_metrics.getmetrics(str_prefix="Validation ")
        )
-        train_metrics.__resetmetrics__()
-        val_metrics.__resetmetrics__()
+        train_metrics.resetmetric()
+        val_metrics.resetmetric()
 
    testloss = []
    model.eval()
@@ -189,9 +190,9 @@ def main():
 
    wandb.log(
        {"Epoch": 1, "Test loss": np.mean(testloss)}
-        | test_metrics.__getmetrics__(str_prefix="Test ")
+        | test_metrics.getmetrics(str_prefix="Test ")
    )
-    test_metrics.__resetmetrics__()
+    test_metrics.resetmetric()
 
 
 if __name__ == "__main__":
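The logging calls rely on the `dict` union operator from PEP 584 (Python 3.9+): the loss dictionary is merged with the prefixed metric dictionaries returned by `getmetrics`. A tiny illustration with made-up values:

```python
# Made-up numbers, just to show how the wandb.log payload is assembled.
losses = {"Epoch": 1, "Train loss": 0.42, "Validation loss": 0.47}
train_scores = {"Train entropy": 1.37}        # illustrative getmetrics() output
val_scores = {"Validation entropy": 1.41}

payload = losses | train_scores | val_scores
# {'Epoch': 1, 'Train loss': 0.42, 'Validation loss': 0.47,
#  'Train entropy': 1.37, 'Validation entropy': 1.41}
```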

tests/test_metrics.py

Lines changed: 19 additions & 13 deletions
@@ -23,7 +23,7 @@
         ("accuracy", randint(2, 10), True),
         ("precision", randint(2, 10), False),
         ("precision", randint(2, 10), True),
-        ("EntropyPrediction", randint(2, 10), False),
+        ("entropy", randint(2, 10), False),
     ],
 )
 def test_metric_wrapper(metric, num_classes, macro_averaging):
@@ -40,9 +40,9 @@ def test_metric_wrapper(metric, num_classes, macro_averaging):
     )
 
     metrics(y_true, logits)
-    score = metrics.accumulate()
-    metrics.reset()
-    empty_score = metrics.accumulate()
+    score = metrics.getmetrics()
+    metrics.resetmetric()
+    empty_score = metrics.getmetrics()
 
     assert isinstance(score, dict), "Expected a dictionary output."
     assert metric in score, f"Expected {metric} metric in the output."
@@ -151,16 +151,22 @@ def test_accuracy():
 def test_entropypred():
     import torch as th
 
-    pred_logits = th.rand(6, 5)
     true_lab = th.rand(6, 5)
 
-    metric = EntropyPrediction(averages="mean")
-    metric2 = EntropyPrediction(averages="sum")
+    metric = EntropyPrediction(num_classes=5)
 
-    # Test for averaging metric consistency
+    # Test if the metric stores multiple values
+    pred_logits = th.rand(6, 5)
     metric(true_lab, pred_logits)
-    metric2(true_lab, pred_logits)
-    assert (
-        th.abs(th.sum(6 * metric.__returnmetric__() - metric2.__returnmetric__()))
-        < 1e-5
-    )
+
+    pred_logits = th.rand(6, 5)
+    metric(true_lab, pred_logits)
+
+    pred_logits = th.rand(6, 5)
+    metric(true_lab, pred_logits)
+
+    assert type(metric.__returnmetric__()) == th.Tensor
+
+    # Test that an error is raised when num_classes != class dimension length
+    with pytest.raises(AssertionError):
+        metric(true_lab, th.rand(6, 6))

tests/test_models.py

Lines changed: 13 additions & 5 deletions
@@ -1,7 +1,13 @@
 import pytest
 import torch
 
-from CollaborativeCoding.models import ChristianModel, JanModel, MagnusModel
+from CollaborativeCoding.models import (
+    ChristianModel,
+    JanModel,
+    JohanModel,
+    MagnusModel,
+    SolveigModel,
+)
 
 
 @pytest.mark.parametrize(
@@ -49,15 +55,17 @@ def test_solveig_model(image_shape, num_classes):
     assert y.shape == (n, num_classes), f"Shape: {y.shape}"
 
 
-@pytest.mark.parametrize("image_shape", [(3, 28, 28)])
-def test_magnus_model(image_shape):
+@pytest.mark.parametrize(
+    "image_shape, num_classes", [((3, 28, 28), 10), ((1, 16, 16), 10)]
+)
+def test_magnus_model(image_shape, num_classes):
     import torch as th
 
     n, c, h, w = 5, *image_shape
-    model = MagnusModel([h, w], 10, c)
+    model = MagnusModel([h, w], num_classes, c)
 
     x = th.rand((n, c, h, w))
     with th.no_grad():
         y = model(x)
 
-    assert y.shape == (n, 10), f"Shape: {y.shape}"
+    assert y.shape == (n, num_classes), f"Shape: {y.shape}"

tests/test_wrappers.py

Lines changed: 4 additions & 1 deletion
@@ -1,3 +1,5 @@
+from pathlib import Path
+
 from CollaborativeCoding import load_data, load_metric, load_model
 
 
@@ -18,6 +20,7 @@ def test_load_model():
     ]
 
     for name in modelnames:
+        print(name)
         model = load_model(name, image_shape=image_shape, num_classes=num_classes)
 
         with th.no_grad():
@@ -51,7 +54,7 @@ def test_load_data():
     with TemporaryDirectory() as tmppath:
         for name in dataset_names:
             dataset = load_data(
-                name, train=False, data_path=tmppath, download=True, transform=trans
+                name, train=False, data_dir=Path(tmppath), transform=trans
            )
 
            im, _ = dataset.__getitem__(0)
