1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+
9
+ import torch
10
+ import torchvision
11
+ import unittest
12
+
13
+ from executorch .backends .test .suite .models import model_test_params , model_test_cls , run_model_test
14
+ from torch .export import Dim
15
+ from typing import Callable
16
+
17
+ #
18
+ # This file contains model integration tests for supported torchvision models.
19
+ #
20
+
21
+ @model_test_cls
22
+ class TorchVision (unittest .TestCase ):
23
+ def _test_cv_model (
24
+ self ,
25
+ model : torch .nn .Module ,
26
+ dtype : torch .dtype ,
27
+ use_dynamic_shapes : bool ,
28
+ tester_factory : Callable ,
29
+ ):
30
+ # Test a CV model that follows the standard conventions.
31
+ inputs = (
32
+ torch .randn (1 , 3 , 224 , 224 , dtype = dtype ),
33
+ )
34
+
35
+ dynamic_shapes = (
36
+ {
37
+ 2 : Dim ("height" , min = 1 , max = 16 )* 16 ,
38
+ 3 : Dim ("width" , min = 1 , max = 16 )* 16 ,
39
+ },
40
+ ) if use_dynamic_shapes else None
41
+
42
+ run_model_test (model , inputs , dtype , dynamic_shapes , tester_factory )
43
+
44
+
45
+ def test_alexnet (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
46
+ model = torchvision .models .alexnet ()
47
+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
48
+
49
+
50
+ def test_convnext_small (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
51
+ model = torchvision .models .convnext_small ()
52
+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
53
+
54
+
55
+ def test_densenet161 (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
56
+ model = torchvision .models .densenet161 ()
57
+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
58
+
59
+
60
+ def test_efficientnet_b4 (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
61
+ model = torchvision .models .efficientnet_b4 ()
62
+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
63
+
64
+
65
+ def test_efficientnet_v2_s (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
66
+ model = torchvision .models .efficientnet_v2_s ()
67
+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
68
+
69
+
70
+ def test_googlenet (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
71
+ model = torchvision .models .googlenet ()
72
+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
73
+
74
+
75
+ def test_inception_v3 (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
76
+ model = torchvision .models .inception_v3 ()
77
+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
78
+
79
+
80
+ @model_test_params (supports_dynamic_shapes = False )
81
+ def test_maxvit_t (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
82
+ model = torchvision .models .maxvit_t ()
83
+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
84
+
85
+
86
+ def test_mnasnet1_0 (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
87
+ model = torchvision .models .mnasnet1_0 ()
88
+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
89
+
90
+
91
+ def test_mobilenet_v2 (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
92
+ model = torchvision .models .mobilenet_v2 ()
93
+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
94
+
95
+
96
+ def test_mobilenet_v3_small (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
97
+ model = torchvision .models .mobilenet_v3_small ()
98
+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
99
+
100
+
101
+ def test_regnet_y_1_6gf (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
102
+ model = torchvision .models .regnet_y_1_6gf ()
103
+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
104
+
105
+
106
+ def test_resnet50 (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
107
+ model = torchvision .models .resnet50 ()
108
+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
109
+
110
+
111
+ def test_resnext50_32x4d (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
112
+ model = torchvision .models .resnext50_32x4d ()
113
+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
114
+
115
+
116
+ def test_shufflenet_v2_x1_0 (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
117
+ model = torchvision .models .shufflenet_v2_x1_0 ()
118
+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
119
+
120
+
121
+ def test_squeezenet1_1 (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
122
+ model = torchvision .models .squeezenet1_1 ()
123
+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
124
+
125
+
126
+ def test_swin_v2_t (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
127
+ model = torchvision .models .swin_v2_t ()
128
+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
129
+
130
+
131
+ def test_vgg11 (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
132
+ model = torchvision .models .vgg11 ()
133
+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
134
+
135
+
136
+ @model_test_params (supports_dynamic_shapes = False )
137
+ def test_vit_b_16 (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
138
+ model = torchvision .models .vit_b_16 ()
139
+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
140
+
141
+
142
+ def test_wide_resnet50_2 (self , dtype : torch .dtype , use_dynamic_shapes : bool , tester_factory : Callable ):
143
+ model = torchvision .models .wide_resnet50_2 ()
144
+ self ._test_cv_model (model , dtype , use_dynamic_shapes , tester_factory )
145
+
0 commit comments