Skip to content

Commit 504fef4

Browse files
authored
Merge pull request #609 from rwightman/cait
CaiT transformer model
2 parents 936e9b3 + d45e50b commit 504fef4

File tree

3 files changed

+405
-2
lines changed

3 files changed

+405
-2
lines changed

tests/test_models.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
torch._C._jit_set_profiling_mode(False)
1616

1717
# transformer models don't support many of the spatial / feature based model functionalities
18-
NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'mixer_*']
18+
NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', 'mixer_*']
1919
NUM_NON_STD = len(NON_STD_FILTERS)
2020

2121
# exclude models that cause specific test failures
@@ -43,7 +43,9 @@ def test_model_forward(model_name, batch_size):
4343

4444
input_size = model.default_cfg['input_size']
4545
if any([x > MAX_FWD_SIZE for x in input_size]):
46-
# cap forward test at max res 448 * 448 to keep resource down
46+
if is_model_default_key(model_name, 'fixed_input_size'):
47+
pytest.skip("Fixed input size model > limit.")
48+
# cap forward test at max res 384 * 384 to keep resource down
4749
input_size = tuple([min(x, MAX_FWD_SIZE) for x in input_size])
4850
inputs = torch.randn((batch_size, *input_size))
4951
outputs = model(inputs)

timm/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .byoanet import *
22
from .byobnet import *
3+
from .cait import *
34
from .coat import *
45
from .cspnet import *
56
from .densenet import *

0 commit comments

Comments
 (0)