33
44from timm import list_models , create_model
55
6+ MAX_FWD_SIZE = 320
7+ MAX_BWD_SIZE = 128
8+ MAX_FWD_FEAT_SIZE = 448
9+
610
711@pytest .mark .timeout (120 )
812@pytest .mark .parametrize ('model_name' , list_models ())
@@ -13,9 +17,9 @@ def test_model_forward(model_name, batch_size):
1317 model .eval ()
1418
1519 input_size = model .default_cfg ['input_size' ]
16- if any ([x > 448 for x in input_size ]):
20+ if any ([x > MAX_FWD_SIZE for x in input_size ]):
1721 # cap forward test at max res 448 * 448 to keep resource down
18- input_size = tuple ([min (x , 448 ) for x in input_size ])
22+ input_size = tuple ([min (x , MAX_FWD_SIZE ) for x in input_size ])
1923 inputs = torch .randn ((batch_size , * input_size ))
2024 outputs = model (inputs )
2125
@@ -33,9 +37,9 @@ def test_model_backward(model_name, batch_size):
3337 model .eval ()
3438
3539 input_size = model .default_cfg ['input_size' ]
36- if any ([x > 128 for x in input_size ]):
40+ if any ([x > MAX_BWD_SIZE for x in input_size ]):
3741 # cap backward test at 128 * 128 to keep resource usage down
38- input_size = tuple ([min (x , 128 ) for x in input_size ])
42+ input_size = tuple ([min (x , MAX_BWD_SIZE ) for x in input_size ])
3943 inputs = torch .randn ((batch_size , * input_size ))
4044 outputs = model (inputs )
4145 outputs .mean ().backward ()
@@ -61,9 +65,9 @@ def test_model_default_cfgs(model_name, batch_size):
6165 pool_size = cfg ['pool_size' ]
6266 input_size = model .default_cfg ['input_size' ]
6367
64- if all ([x <= 448 for x in input_size ]):
68+ if all ([x <= MAX_FWD_FEAT_SIZE for x in input_size ]) and 'efficientnet_l2' not in model_name :
6569 # pool size only checked if default res <= 448 * 448 to keep resource down
66- input_size = tuple ([min (x , 448 ) for x in input_size ])
70+ input_size = tuple ([min (x , MAX_FWD_FEAT_SIZE ) for x in input_size ])
6771 outputs = model .forward_features (torch .randn ((batch_size , * input_size )))
6872 assert outputs .shape [- 1 ] == pool_size [- 1 ] and outputs .shape [- 2 ] == pool_size [- 2 ]
6973 assert any ([k .startswith (classifier ) for k in state_dict .keys ()]), f'{ classifier } not in model params'
0 commit comments