We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 8da43e0 commit 69d725cCopy full SHA for 69d725c
tests/__init__.py
tests/test_inference.py
@@ -0,0 +1,19 @@
1
+import pytest
2
+import torch
3
+
4
+from timm import list_models, create_model
5
6
7
@pytest.mark.timeout(60)
@pytest.mark.parametrize('model_name', list_models())
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward(model_name, batch_size):
    """Run a single forward pass with each model registered in timm.

    Parametrized over every model name returned by ``list_models()`` and a
    batch size of 1. Each case builds the model without pretrained weights,
    feeds it random input of the model's default input size, and asserts
    that the output has the expected batch dimension and contains no NaNs.
    """
    model = create_model(model_name, pretrained=False)
    model.eval()  # inference mode for dropout/batchnorm layers

    # NOTE(review): assumes default_cfg['input_size'] is (channels, H, W)
    # per timm convention — confirm for any model that overrides it.
    inputs = torch.randn((batch_size, *model.default_cfg['input_size']))

    # Inference-only check: skip building the autograd graph. Without
    # no_grad() every parametrized case holds activations for backward,
    # wasting memory across the full list of registered models.
    with torch.no_grad():
        outputs = model(inputs)

    assert outputs.shape[0] == batch_size
    assert not torch.isnan(outputs).any(), 'Output included NaNs'
0 commit comments