6
6
7
7
import unittest
8
8
9
+ from typing import Any , Callable
10
+
9
11
import torch
10
12
11
13
from executorch .examples .export .utils import _CAPTURE_CONFIG , _EDGE_COMPILE_CONFIG
19
21
20
22
class ExportTest (unittest .TestCase ):
21
23
def _assert_eager_lowered_same_result (
22
- self , eager_model : torch .nn .Module , example_inputs
24
+ self ,
25
+ eager_model : torch .nn .Module ,
26
+ example_inputs ,
27
+ validation_fn : Callable [[Any , Any ], bool ],
23
28
):
29
+ """
30
+ Asserts that the given model has the same result as the eager mode
31
+ lowered model, with example_inputs, validated by validation_fn, which
32
+ takes the eager mode output and ET output, and returns True if they
33
+ match.
34
+ """
24
35
import executorch .exir as exir
25
36
26
37
edge_model = exir .capture (eager_model , example_inputs , _CAPTURE_CONFIG ).to_edge (
@@ -36,38 +47,54 @@ def _assert_eager_lowered_same_result(
36
47
with torch .no_grad ():
37
48
executorch_output = pte_model .run_method ("forward" , example_inputs )
38
49
39
- if isinstance (eager_output , tuple ):
40
- # TODO: Allow validating other items
41
- self .assertTrue (
42
- torch .allclose (
43
- eager_output [0 ], executorch_output [0 ][0 ], rtol = 1e-5 , atol = 1e-5
44
- )
45
- )
46
- else :
47
- self .assertTrue (
48
- torch .allclose (eager_output , executorch_output [0 ], rtol = 1e-5 , atol = 1e-5 )
49
- )
50
+ self .assertTrue (validation_fn (eager_output , executorch_output ))
51
+
52
+ @staticmethod
53
+ def validate_tensor_allclose (eager_output , executorch_output ):
54
+ return torch .allclose (
55
+ eager_output ,
56
+ executorch_output [0 ],
57
+ rtol = 1e-5 ,
58
+ atol = 1e-5 ,
59
+ )
50
60
51
61
def test_mv3_export_to_executorch (self ):
52
62
eager_model , example_inputs = MODEL_NAME_TO_MODEL ["mv3" ]()
53
63
eager_model = eager_model .eval ()
54
64
55
- self ._assert_eager_lowered_same_result (eager_model , example_inputs )
65
+ self ._assert_eager_lowered_same_result (
66
+ eager_model , example_inputs , self .validate_tensor_allclose
67
+ )
56
68
57
69
def test_mv2_export_to_executorch (self ):
58
70
eager_model , example_inputs = MODEL_NAME_TO_MODEL ["mv2" ]()
59
71
eager_model = eager_model .eval ()
60
72
61
- self ._assert_eager_lowered_same_result (eager_model , example_inputs )
73
+ self ._assert_eager_lowered_same_result (
74
+ eager_model , example_inputs , self .validate_tensor_allclose
75
+ )
62
76
63
77
def test_emformer_export_to_executorch (self ):
64
78
eager_model , example_inputs = MODEL_NAME_TO_MODEL ["emformer" ]()
65
79
eager_model = eager_model .eval ()
66
80
67
- self ._assert_eager_lowered_same_result (eager_model , example_inputs )
81
+ validate_emformer_result = (
82
+ lambda eager_output , executorch_output : torch .allclose (
83
+ eager_output [0 ],
84
+ executorch_output [0 ][0 ],
85
+ rtol = 1e-5 ,
86
+ atol = 1e-5 ,
87
+ )
88
+ )
89
+
90
+ self ._assert_eager_lowered_same_result (
91
+ eager_model , example_inputs , validate_emformer_result
92
+ )
68
93
69
94
def test_vit_export_to_executorch (self ):
70
95
eager_model , example_inputs = MODEL_NAME_TO_MODEL ["vit" ]()
71
96
eager_model = eager_model .eval ()
72
97
73
- self ._assert_eager_lowered_same_result (eager_model , example_inputs )
98
+ self ._assert_eager_lowered_same_result (
99
+ eager_model , example_inputs , self .validate_tensor_allclose
100
+ )
0 commit comments