Skip to content

Commit 587a42f

Browse files
authored
Merge pull request #87 from SFI-Visual-Intelligence/christian/update-dataloader-recall
Christian/update dataloader recall + a bit of the structure
2 parents 78181b9 + 9683fe8 commit 587a42f

File tree

2 files changed

+8
-20
lines changed

2 files changed

+8
-20
lines changed

CollaborativeCoding/load_metric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def resetmetric(self):
107107
y_pred = th.rand((5, class_size))
108108

109109
metricwrapper = MetricWrapper(
110-
metric,
110+
*metrics,
111111
num_classes=class_size,
112112
macro_averaging=True if class_size % 2 == 0 else False,
113113
)

tests/test_wrappers.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
from pathlib import Path
2-
from tempfile import TemporaryDirectory
32

43
import pytest
5-
import torch
6-
from torchvision import transforms
4+
import torch as th
75

86
from CollaborativeCoding import MetricWrapper, load_data, load_model
97

108

119
def test_load_model():
12-
import torch as th
13-
1410
image_shape = (1, 16, 16)
1511
num_classes = 4
1612

@@ -36,17 +32,14 @@ def test_load_model():
3632

3733

3834
def test_load_data():
39-
from tempfile import TemporaryDirectory
40-
41-
import torch as th
4235
from torchvision import transforms
4336

4437
dataset_names = [
4538
"usps_0-6",
4639
"mnist_0-3",
4740
"usps_7-9",
4841
"svhn",
49-
"mnist_4-9", # Uncomment when implemented
42+
"mnist_4-9",
5043
]
5144

5245
trans = transforms.Compose(
@@ -56,21 +49,16 @@ def test_load_data():
5649
]
5750
)
5851

59-
with TemporaryDirectory() as tmppath:
60-
for name in dataset_names:
61-
dataset, _, _ = load_data(
62-
name, train=False, data_dir=Path(tmppath), transform=trans
63-
)
52+
for name in dataset_names:
53+
dataset = load_data(name, train=False, data_dir=Path.cwd() / "Data", transform=trans)
6454

65-
im, _ = dataset.__getitem__(0)
55+
im, _ = dataset.__getitem__(0)
6656

67-
assert dataset.__len__() != 0
68-
assert type(im) == th.Tensor and len(im.size()) == 3
57+
assert dataset.__len__() != 0
58+
assert type(im) is th.Tensor and len(im.size()) == 3
6959

7060

7161
def test_load_metric():
72-
import torch as th
73-
7462
metrics = ("entropy", "f1", "recall", "precision", "accuracy")
7563

7664
class_sizes = [3, 6, 10]

0 commit comments

Comments
 (0)