
Commit 20329f2

Bring down test resolutions to see if we can at least do a fwd on the L2 models
1 parent 3873ea7 commit 20329f2

File tree

1 file changed: +10 -6 lines


tests/test_models.py

Lines changed: 10 additions & 6 deletions
@@ -3,6 +3,10 @@
 
 from timm import list_models, create_model
 
+MAX_FWD_SIZE = 320
+MAX_BWD_SIZE = 128
+MAX_FWD_FEAT_SIZE = 448
+
 
 @pytest.mark.timeout(120)
 @pytest.mark.parametrize('model_name', list_models())
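The three module-level caps bound the spatial resolution used in each test: 320 for the plain forward test, 128 for the backward test (which also has to hold activations for the backward pass), and 448 for the feature/pool-size check. As a minimal standalone sketch of the clamping pattern the hunks below apply (the oversized (3, 800, 800) default size is made up for illustration):

# Sketch only: how an oversized default_cfg input_size gets clamped to a cap.
MAX_FWD_SIZE = 320

input_size = (3, 800, 800)  # hypothetical model default, larger than the cap
if any([x > MAX_FWD_SIZE for x in input_size]):
    # keep the channel dim, clamp the spatial dims to the cap
    input_size = tuple([min(x, MAX_FWD_SIZE) for x in input_size])
print(input_size)  # -> (3, 320, 320)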
@@ -13,9 +17,9 @@ def test_model_forward(model_name, batch_size):
     model.eval()
 
     input_size = model.default_cfg['input_size']
-    if any([x > 448 for x in input_size]):
+    if any([x > MAX_FWD_SIZE for x in input_size]):
         # cap forward test at max res 448 * 448 to keep resource down
-        input_size = tuple([min(x, 448) for x in input_size])
+        input_size = tuple([min(x, MAX_FWD_SIZE) for x in input_size])
     inputs = torch.randn((batch_size, *input_size))
     outputs = model(inputs)
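Per the commit message, lowering the forward cap is about letting a forward pass on the very large EfficientNet-L2 variants fit in the test runner's memory. A rough standalone sketch of that forward pass at the capped 320 x 320 resolution; the model name 'tf_efficientnet_l2_ns' is an assumption about how the L2 model is registered, and no_grad() is added here only to keep memory down, it is not part of the test:

import torch
from timm import create_model

# Sketch: forward-only check on an assumed L2 model name at the capped resolution.
model = create_model('tf_efficientnet_l2_ns', pretrained=False)
model.eval()
with torch.no_grad():  # no activations retained, mirrors a forward-only check
    inputs = torch.randn((1, 3, 320, 320))
    outputs = model(inputs)
print(outputs.shape)  # expected: torch.Size([1, 1000])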

@@ -33,9 +37,9 @@ def test_model_backward(model_name, batch_size):
     model.eval()
 
     input_size = model.default_cfg['input_size']
-    if any([x > 128 for x in input_size]):
+    if any([x > MAX_BWD_SIZE for x in input_size]):
         # cap backward test at 128 * 128 to keep resource usage down
-        input_size = tuple([min(x, 128) for x in input_size])
+        input_size = tuple([min(x, MAX_BWD_SIZE) for x in input_size])
     inputs = torch.randn((batch_size, *input_size))
     outputs = model(inputs)
     outputs.mean().backward()
@@ -61,9 +65,9 @@ def test_model_default_cfgs(model_name, batch_size):
     pool_size = cfg['pool_size']
     input_size = model.default_cfg['input_size']
 
-    if all([x <= 448 for x in input_size]):
+    if all([x <= MAX_FWD_FEAT_SIZE for x in input_size]) and 'efficientnet_l2' not in model_name:
         # pool size only checked if default res <= 448 * 448 to keep resource down
-        input_size = tuple([min(x, 448) for x in input_size])
+        input_size = tuple([min(x, MAX_FWD_FEAT_SIZE) for x in input_size])
         outputs = model.forward_features(torch.randn((batch_size, *input_size)))
         assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
     assert any([k.startswith(classifier) for k in state_dict.keys()]), f'{classifier} not in model params'
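To exercise only these tests for the L2 variants, something like the following should work, since the parametrized test ids include the model name; the 'efficientnet_l2' filter string is an assumption about how the L2 models are registered:

import pytest

# Run only the forward test for models whose registered name contains 'efficientnet_l2'.
pytest.main(['tests/test_models.py', '-k', 'test_model_forward and efficientnet_l2'])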
