Skip to content

Commit 24e40ae

Browse files
committed
Added test for load_model
1 parent 4481ef8 commit 24e40ae

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

tests/test_wrappers.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from utils import load_data, load_metric, load_model
2+
3+
4+
def test_load_model():
5+
import torch as th
6+
7+
image_shape = (1, 28, 28)
8+
num_classes = 4
9+
10+
dummy_img = th.rand((1, *image_shape))
11+
12+
modelnames = [
13+
"magnusmodel",
14+
"christianmodel",
15+
"janmodel",
16+
"solveigmodel",
17+
"johanmodel",
18+
]
19+
20+
for name in modelnames:
21+
model = load_model(name, image_shape=image_shape, num_classes=num_classes)
22+
23+
with th.no_grad():
24+
output = model(dummy_img)
25+
assert output.size() == (1, 4), (
26+
f"Model {name} returned image of size {output}. Expected (1,4)"
27+
)
28+
29+
30+
def test_load_data():
31+
pass

0 commit comments

Comments
 (0)