Skip to content

Commit bfb895c

Browse files
committed
Committing before merging christian branch
1 parent 331734d commit bfb895c

File tree

6 files changed

+59
-14
lines changed

6 files changed

+59
-14
lines changed

.gitignore

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
__pycache__/
22
.ipynb_checkpoints/
33
Data/*
4+
data/*
45
Results/*
56
Experiments/*
67
_build/
@@ -14,9 +15,7 @@ doc/autoapi
1415

1516
#Magnus specific
1617
job*
17-
env2/*
18-
ruffian.sh
19-
localtest.sh
18+
local*
2019

2120
# Johanthings
2221
formatting.x

CollaborativeCoding/dataloaders/download.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def _get_labels(path: Path) -> np.ndarray:
8787

8888
def svhn(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
8989
def download_svhn(path, train: bool = True):
90-
SVHN()
90+
SVHN(path, split="train" if train else "test", download=True)
9191

9292
parent_path = data_dir / "SVHN"
9393

CollaborativeCoding/load_data.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def load_data(dataset: str, *args, **kwargs) -> tuple:
6464
case "svhn":
6565
dataset = SVHNDataset
6666
train_labels, test_labels = downloader.svhn(data_dir=data_dir)
67-
labels = np.arange(10)
67+
labels = np.unique(train_labels)
6868
case "mnist_4-9":
6969
dataset = MNISTDataset4_9
7070
train_labels, test_labels = downloader.mnist(data_dir=data_dir)
@@ -78,6 +78,10 @@ def load_data(dataset: str, *args, **kwargs) -> tuple:
7878
train_indices = np.arange(len(train_labels))
7979
test_indices = np.arange(len(test_labels))
8080

81+
print(train_indices.shape)
82+
print(np.asarray(train_labels).shape)
83+
print(labels.shape)
84+
8185
# Filter the labels to only get indices of the wanted labels
8286
train_samples = train_indices[np.isin(train_labels, labels)]
8387
test_samples = test_indices[np.isin(test_labels, labels)]

CollaborativeCoding/load_metric.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,28 @@ def getmetrics(self, str_prefix: str = None):
9494
def resetmetric(self):
9595
for key in self.metrics:
9696
self.metrics[key].__reset__()
97+
98+
99+
if __name__ == "__main__":
100+
import torch as th
101+
102+
metrics = ["entropy", "f1", "recall", "precision", "accuracy"]
103+
104+
class_sizes = [3, 6, 10]
105+
for class_size in class_sizes:
106+
y_true = th.rand((5, class_size)).argmax(dim=1)
107+
y_pred = th.rand((5, class_size))
108+
109+
metricwrapper = MetricWrapper(
110+
metric,
111+
num_classes=class_size,
112+
macro_averaging=True if class_size % 2 == 0 else False,
113+
)
114+
115+
metricwrapper(y_true, y_pred)
116+
metric = metricwrapper.getmetrics()
117+
assert metric is not None
118+
119+
metricwrapper.resetmetric()
120+
metric2 = metricwrapper.getmetrics()
121+
assert metric != metric2

tests/test_dataloaders.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from CollaborativeCoding.dataloaders import (
1010
MNISTDataset0_3,
11+
SVHNDataset,
1112
USPSDataset0_6,
1213
USPSH5_Digit_7_9_Dataset,
1314
)
@@ -20,6 +21,7 @@
2021
("usps_0-6", USPSDataset0_6),
2122
("usps_7-9", USPSH5_Digit_7_9_Dataset),
2223
("mnist_0-3", MNISTDataset0_3),
24+
("svhn", SVHNDataset),
2325
# TODO: Add more datasets here
2426
],
2527
)

tests/test_wrappers.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from pathlib import Path
22

3-
from CollaborativeCoding import load_data, load_metric, load_model
3+
from CollaborativeCoding import MetricWrapper, load_data, load_model
44

55

66
def test_load_model():
@@ -36,13 +36,7 @@ def test_load_data():
3636
import torch as th
3737
from torchvision import transforms
3838

39-
dataset_names = [
40-
"usps_0-6",
41-
"mnist_0-3",
42-
"usps_7-9",
43-
"svhn",
44-
# 'mnist_4-9' #Uncomment when implemented
45-
]
39+
dataset_names = ["usps_0-6", "mnist_0-3", "usps_7-9", "svhn", "mnist_4-9"]
4640

4741
trans = transforms.Compose(
4842
[
@@ -64,4 +58,25 @@ def test_load_data():
6458

6559

6660
def test_load_metric():
67-
pass
61+
import torch as th
62+
63+
metrics = ("entropy", "f1", "recall", "precision", "accuracy")
64+
65+
class_sizes = [3, 6, 10]
66+
for class_size in class_sizes:
67+
y_true = th.rand((5, class_size)).argmax(dim=1)
68+
y_pred = th.rand((5, class_size))
69+
70+
metricwrapper = MetricWrapper(
71+
*metrics,
72+
num_classes=class_size,
73+
macro_averaging=True if class_size % 2 == 0 else False,
74+
)
75+
76+
metricwrapper(y_true, y_pred)
77+
metric = metricwrapper.getmetrics()
78+
assert metric is not None
79+
80+
metricwrapper.resetmetric()
81+
metric2 = metricwrapper.getmetrics()
82+
assert metric != metric2

0 commit comments

Comments
 (0)