1+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2+ # All rights reserved.
3+ #
4+ # This source code is licensed under the BSD-style license found in the
5+ # LICENSE file in the root directory of this source tree.
6+
7+ # pyre-unsafe
8+
9+ import torch
10+ import torchvision
11+ import unittest
12+
13+ from executorch .backends .test .suite .models import model_test_params , model_test_cls , run_model_test
14+ from torch .export import Dim
15+ from typing import Callable
16+
17+ #
18+ # This file contains model integration tests for supported torchvision models.
19+ #
20+
21+ @model_test_cls
22+ class TorchVision (unittest .TestCase ):
23+ def _test_cv_model (
24+ self ,
25+ model : torch .nn .Module ,
26+ dtype : torch .dtype ,
27+ use_dynamic_shapes : bool ,
28+ tester_factory : Callable ,
29+ ):
30+ # Test a CV model that follows the standard conventions.
31+ inputs = (
32+ torch .randn (1 , 3 , 224 , 224 , dtype = dtype ),
33+ )
34+
35+ dynamic_shapes = (
36+ {
37+ 2 : Dim ("height" , min = 1 , max = 16 )* 16 ,
38+ 3 : Dim ("width" , min = 1 , max = 16 )* 16 ,
39+ },
40+ ) if use_dynamic_shapes else None
41+
42+ run_model_test (model , inputs , dtype , dynamic_shapes , tester_factory )
43+
44+
45+ def test_alexnet (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
46+ model = torchvision .models .alexnet ()
47+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
48+
49+
50+ def test_convnext_small (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
51+ model = torchvision .models .convnext_small ()
52+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
53+
54+
55+ def test_densenet161 (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
56+ model = torchvision .models .densenet161 ()
57+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
58+
59+
60+ def test_efficientnet_b4 (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
61+ model = torchvision .models .efficientnet_b4 ()
62+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
63+
64+
65+ def test_efficientnet_v2_s (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
66+ model = torchvision .models .efficientnet_v2_s ()
67+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
68+
69+
70+ def test_googlenet (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
71+ model = torchvision .models .googlenet ()
72+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
73+
74+
75+ def test_inception_v3 (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
76+ model = torchvision .models .inception_v3 ()
77+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
78+
79+
80+ @model_test_params (supports_dynamic_shapes = False )
81+ def test_maxvit_t (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
82+ model = torchvision .models .maxvit_t ()
83+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
84+
85+
86+ def test_mnasnet1_0 (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
87+ model = torchvision .models .mnasnet1_0 ()
88+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
89+
90+
91+ def test_mobilenet_v2 (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
92+ model = torchvision .models .mobilenet_v2 ()
93+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
94+
95+
96+ def test_mobilenet_v3_small (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
97+ model = torchvision .models .mobilenet_v3_small ()
98+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
99+
100+
101+ def test_regnet_y_1_6gf (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
102+ model = torchvision .models .regnet_y_1_6gf ()
103+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
104+
105+
106+ def test_resnet50 (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
107+ model = torchvision .models .resnet50 ()
108+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
109+
110+
111+ def test_resnext50_32x4d (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
112+ model = torchvision .models .resnext50_32x4d ()
113+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
114+
115+
116+ def test_shufflenet_v2_x1_0 (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
117+ model = torchvision .models .shufflenet_v2_x1_0 ()
118+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
119+
120+
121+ def test_squeezenet1_1 (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
122+ model = torchvision .models .squeezenet1_1 ()
123+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
124+
125+
126+ def test_swin_v2_t (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
127+ model = torchvision .models .swin_v2_t ()
128+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
129+
130+
131+ def test_vgg11 (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
132+ model = torchvision .models .vgg11 ()
133+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
134+
135+
136+ @model_test_params (supports_dynamic_shapes = False )
137+ def test_vit_b_16 (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
138+ model = torchvision .models .vit_b_16 ()
139+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
140+
141+
142+ def test_wide_resnet50_2 (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
143+ model = torchvision .models .wide_resnet50_2 ()
144+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
145+
0 commit comments