Skip to content

Commit dd5c6c6

Browse files
committed
Added test_load_model
1 parent 5b1d98b commit dd5c6c6

File tree

2 files changed

+33
-27
lines changed

2 files changed

+33
-27
lines changed

tests/test_metrics.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33
import pytest
44

55
from CollaborativeCoding.load_metric import MetricWrapper
6-
from CollaborativeCoding.metrics import Accuracy, F1Score, Precision, Recall
6+
from CollaborativeCoding.metrics import (
7+
Accuracy,
8+
EntropyPrediction,
9+
F1Score,
10+
Precision,
11+
Recall,
12+
)
713

814

915
@pytest.mark.parametrize(
@@ -17,7 +23,7 @@
1723
("accuracy", randint(2, 10), True),
1824
("precision", randint(2, 10), False),
1925
("precision", randint(2, 10), True),
20-
# TODO: Add test for EntropyPrediction
26+
("EntropyPrediction", randint(2, 10), False),
2127
],
2228
)
2329
def test_metric_wrapper(metric, num_classes, macro_averaging):

tests/test_wrappers.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,30 @@
11
from CollaborativeCoding import load_data, load_metric, load_model
22

3-
# def test_load_model():
4-
# import torch as th
5-
6-
# image_shape = (1, 16, 16)
7-
# num_classes = 4
8-
9-
# dummy_img = th.rand((1, *image_shape))
10-
11-
# modelnames = [
12-
# "magnusmodel",
13-
# "christianmodel",
14-
# "janmodel",
15-
# "solveigmodel",
16-
# "johanmodel",
17-
# ]
18-
19-
# for name in modelnames:
20-
# print(name)
21-
# model = load_model(name, image_shape=image_shape, num_classes=num_classes)
22-
23-
# with th.no_grad():
24-
# output = model(dummy_img)
25-
# assert output.size() == (1, 4), (
26-
# f"Model {name} returned image of size {output}. Expected (1,4)"
27-
# )
3+
4+
def test_load_model():
5+
import torch as th
6+
7+
image_shape = (1, 16, 16)
8+
num_classes = 4
9+
10+
dummy_img = th.rand((1, *image_shape))
11+
12+
modelnames = [
13+
"magnusmodel",
14+
"christianmodel",
15+
"janmodel",
16+
"solveigmodel",
17+
"johanmodel",
18+
]
19+
20+
for name in modelnames:
21+
model = load_model(name, image_shape=image_shape, num_classes=num_classes)
22+
23+
with th.no_grad():
24+
output = model(dummy_img)
25+
assert output.size() == (1, 4), (
26+
f"Model {name} returned image of size {output}. Expected (1,4)"
27+
)
2828

2929

3030
def test_load_data():

0 commit comments

Comments
 (0)