Skip to content

Commit 9b6b054

Browse files
committed
Remove each model test and instead test all in load_model
1 parent 690e0e8 commit 9b6b054

File tree

1 file changed

+17
-60
lines changed

1 file changed

+17
-60
lines changed

tests/test_models.py

Lines changed: 17 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,71 +1,28 @@
11
import pytest
22
import torch
33

4-
from CollaborativeCoding.models import (
5-
ChristianModel,
6-
JanModel,
7-
JohanModel,
8-
MagnusModel,
9-
SolveigModel,
10-
)
11-
12-
13-
@pytest.mark.parametrize(
14-
"image_shape, num_classes",
15-
[((1, 16, 16), 6), ((3, 16, 16), 6)],
16-
)
17-
def test_christian_model(image_shape, num_classes):
18-
n, c, h, w = 5, *image_shape
19-
20-
model = ChristianModel(image_shape, num_classes)
21-
22-
x = torch.randn(n, c, h, w)
23-
y = model(x)
24-
25-
assert y.shape == (n, num_classes), f"Shape: {y.shape}"
26-
27-
28-
@pytest.mark.parametrize(
29-
"image_shape, num_classes",
30-
[((1, 28, 28), 4), ((3, 16, 16), 10)],
31-
)
32-
def test_jan_model(image_shape, num_classes):
33-
n, c, h, w = 5, *image_shape
34-
35-
model = JanModel(image_shape, num_classes)
36-
37-
x = torch.randn(n, c, h, w)
38-
y = model(x)
39-
40-
assert y.shape == (n, num_classes), f"Shape: {y.shape}"
41-
42-
43-
@pytest.mark.parametrize(
44-
"image_shape, num_classes",
45-
[((3, 16, 16), 3), ((3, 16, 16), 7)],
46-
)
47-
def test_solveig_model(image_shape, num_classes):
48-
n, c, h, w = 5, *image_shape
49-
50-
model = SolveigModel(image_shape, num_classes)
51-
52-
x = torch.randn(n, c, h, w)
53-
y = model(x)
54-
55-
assert y.shape == (n, num_classes), f"Shape: {y.shape}"
4+
from CollaborativeCoding import load_model
565

576

587
@pytest.mark.parametrize(
59-
"image_shape, num_classes", [((3, 28, 28), 10), ((1, 16, 16), 10)]
8+
"model_name",
9+
[
10+
"magnusmodel",
11+
"christianmodel",
12+
"janmodel",
13+
"johanmodel",
14+
"solveigmodel",
15+
],
6016
)
61-
def test_magnus_model(image_shape, num_classes):
62-
import torch as th
17+
@pytest.mark.parametrize("image_shape", [(i, 28, 28) for i in [1, 3]])
18+
@pytest.mark.parametrize("num_classes", [3, 6, 10])
19+
def test_load_model(model_name, image_shape, num_classes):
20+
model = load_model(model_name, image_shape, num_classes)
6321

6422
n, c, h, w = 5, *image_shape
65-
model = MagnusModel([h, w], num_classes, c)
6623

67-
x = th.rand((n, c, h, w))
68-
with th.no_grad():
69-
y = model(x)
24+
dummy_img = torch.randn(n, c, h, w)
25+
with torch.no_grad():
26+
y = model(dummy_img)
7027

71-
assert y.shape == (n, num_classes), f"Shape: {y.shape}"
28+
assert y.shape == (n, num_classes), f"Shape: {y.shape} != {(n, num_classes)}"

0 commit comments

Comments
 (0)