Skip to content

Commit d911e4a

Browse files
committed
Update tests to accept new model input
1 parent f4e5591 commit d911e4a

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

tests/test_models.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,14 @@
44
from utils.models import ChristianModel
55

66

7-
@pytest.mark.parametrize("in_channels, num_classes", [(1, 6), (3, 6)])
8-
def test_christian_model(in_channels, num_classes):
9-
n, c, h, w = 5, in_channels, 16, 16
7+
@pytest.mark.parametrize(
8+
"image_shape, num_classes",
9+
[((1, 16, 16), 6), ((3, 16, 16), 6)],
10+
)
11+
def test_christian_model(image_shape, num_classes):
12+
n, c, h, w = 5, *image_shape
1013

11-
model = ChristianModel(c, num_classes)
14+
model = ChristianModel(image_shape, num_classes)
1215

1316
x = torch.randn(n, c, h, w)
1417
y = model(x)

0 commit comments

Comments
 (0)