|
1 | 1 | import pytest |
2 | 2 | import torch |
3 | 3 |
|
4 | | -from CollaborativeCoding.models import ( |
5 | | - ChristianModel, |
6 | | - JanModel, |
7 | | - JohanModel, |
8 | | - MagnusModel, |
9 | | - SolveigModel, |
10 | | -) |
11 | | - |
12 | | - |
13 | | -@pytest.mark.parametrize( |
14 | | - "image_shape, num_classes", |
15 | | - [((1, 16, 16), 6), ((3, 16, 16), 6)], |
16 | | -) |
17 | | -def test_christian_model(image_shape, num_classes): |
18 | | - n, c, h, w = 5, *image_shape |
19 | | - |
20 | | - model = ChristianModel(image_shape, num_classes) |
21 | | - |
22 | | - x = torch.randn(n, c, h, w) |
23 | | - y = model(x) |
24 | | - |
25 | | - assert y.shape == (n, num_classes), f"Shape: {y.shape}" |
26 | | - |
27 | | - |
28 | | -@pytest.mark.parametrize( |
29 | | - "image_shape, num_classes", |
30 | | - [((1, 28, 28), 4), ((3, 16, 16), 10)], |
31 | | -) |
32 | | -def test_jan_model(image_shape, num_classes): |
33 | | - n, c, h, w = 5, *image_shape |
34 | | - |
35 | | - model = JanModel(image_shape, num_classes) |
36 | | - |
37 | | - x = torch.randn(n, c, h, w) |
38 | | - y = model(x) |
39 | | - |
40 | | - assert y.shape == (n, num_classes), f"Shape: {y.shape}" |
41 | | - |
42 | | - |
43 | | -@pytest.mark.parametrize( |
44 | | - "image_shape, num_classes", |
45 | | - [((3, 16, 16), 3), ((3, 16, 16), 7)], |
46 | | -) |
47 | | -def test_solveig_model(image_shape, num_classes): |
48 | | - n, c, h, w = 5, *image_shape |
49 | | - |
50 | | - model = SolveigModel(image_shape, num_classes) |
51 | | - |
52 | | - x = torch.randn(n, c, h, w) |
53 | | - y = model(x) |
54 | | - |
55 | | - assert y.shape == (n, num_classes), f"Shape: {y.shape}" |
| 4 | +from CollaborativeCoding import load_model |
56 | 5 |
|
57 | 6 |
|
58 | 7 | @pytest.mark.parametrize( |
59 | | - "image_shape, num_classes", [((3, 28, 28), 10), ((1, 16, 16), 10)] |
| 8 | + "model_name", |
| 9 | + [ |
| 10 | + "magnusmodel", |
| 11 | + "christianmodel", |
| 12 | + "janmodel", |
| 13 | + "johanmodel", |
| 14 | + "solveigmodel", |
| 15 | + ], |
60 | 16 | ) |
61 | | -def test_magnus_model(image_shape, num_classes): |
62 | | - import torch as th |
| 17 | +@pytest.mark.parametrize("image_shape", [(i, 28, 28) for i in [1, 3]]) |
| 18 | +@pytest.mark.parametrize("num_classes", [3, 6, 10]) |
| 19 | +def test_load_model(model_name, image_shape, num_classes): |
| 20 | + model = load_model(model_name, image_shape, num_classes) |
63 | 21 |
|
64 | 22 | n, c, h, w = 5, *image_shape |
65 | | - model = MagnusModel([h, w], num_classes, c) |
66 | 23 |
|
67 | | - x = th.rand((n, c, h, w)) |
68 | | - with th.no_grad(): |
69 | | - y = model(x) |
| 24 | + dummy_img = torch.randn(n, c, h, w) |
| 25 | + with torch.no_grad(): |
| 26 | + y = model(dummy_img) |
70 | 27 |
|
71 | | - assert y.shape == (n, num_classes), f"Shape: {y.shape}" |
| 28 | + assert y.shape == (n, num_classes), f"Shape: {y.shape} != {(n, num_classes)}" |
0 commit comments