Skip to content

Commit f8d8ac6

Browse files
kirklandsignfacebook-github-bot
authored andcommitted
Clean up output validation helper
Summary: In helper `_assert_eager_lowered_same_result`, we want to allow user to validate different types of output from eager model and ET model. We can use a `validation_fn` to do this, and let user use the proper validation function to check. Reviewed By: guangy10 Differential Revision: D48267539 fbshipit-source-id: abb32c0f5bd26ccefcecd3f8b56767c3ce5789f2
1 parent 4720677 commit f8d8ac6

File tree

1 file changed

+43
-16
lines changed

1 file changed

+43
-16
lines changed

examples/export/test/test_export.py

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
import unittest
88

9+
from typing import Any, Callable
10+
911
import torch
1012

1113
from executorch.examples.export.utils import _CAPTURE_CONFIG, _EDGE_COMPILE_CONFIG
@@ -19,8 +21,17 @@
1921

2022
class ExportTest(unittest.TestCase):
2123
def _assert_eager_lowered_same_result(
22-
self, eager_model: torch.nn.Module, example_inputs
24+
self,
25+
eager_model: torch.nn.Module,
26+
example_inputs,
27+
validation_fn: Callable[[Any, Any], bool],
2328
):
29+
"""
30+
Asserts that the given model has the same result as the eager mode
31+
lowered model, with example_inputs, validated by validation_fn, which
32+
takes the eager mode output and ET output, and returns True if they
33+
match.
34+
"""
2435
import executorch.exir as exir
2536

2637
edge_model = exir.capture(eager_model, example_inputs, _CAPTURE_CONFIG).to_edge(
@@ -36,38 +47,54 @@ def _assert_eager_lowered_same_result(
3647
with torch.no_grad():
3748
executorch_output = pte_model.run_method("forward", example_inputs)
3849

39-
if isinstance(eager_output, tuple):
40-
# TODO: Allow validating other items
41-
self.assertTrue(
42-
torch.allclose(
43-
eager_output[0], executorch_output[0][0], rtol=1e-5, atol=1e-5
44-
)
45-
)
46-
else:
47-
self.assertTrue(
48-
torch.allclose(eager_output, executorch_output[0], rtol=1e-5, atol=1e-5)
49-
)
50+
self.assertTrue(validation_fn(eager_output, executorch_output))
51+
52+
@staticmethod
53+
def validate_tensor_allclose(eager_output, executorch_output):
54+
return torch.allclose(
55+
eager_output,
56+
executorch_output[0],
57+
rtol=1e-5,
58+
atol=1e-5,
59+
)
5060

5161
def test_mv3_export_to_executorch(self):
5262
eager_model, example_inputs = MODEL_NAME_TO_MODEL["mv3"]()
5363
eager_model = eager_model.eval()
5464

55-
self._assert_eager_lowered_same_result(eager_model, example_inputs)
65+
self._assert_eager_lowered_same_result(
66+
eager_model, example_inputs, self.validate_tensor_allclose
67+
)
5668

5769
def test_mv2_export_to_executorch(self):
5870
eager_model, example_inputs = MODEL_NAME_TO_MODEL["mv2"]()
5971
eager_model = eager_model.eval()
6072

61-
self._assert_eager_lowered_same_result(eager_model, example_inputs)
73+
self._assert_eager_lowered_same_result(
74+
eager_model, example_inputs, self.validate_tensor_allclose
75+
)
6276

6377
def test_emformer_export_to_executorch(self):
6478
eager_model, example_inputs = MODEL_NAME_TO_MODEL["emformer"]()
6579
eager_model = eager_model.eval()
6680

67-
self._assert_eager_lowered_same_result(eager_model, example_inputs)
81+
validate_emformer_result = (
82+
lambda eager_output, executorch_output: torch.allclose(
83+
eager_output[0],
84+
executorch_output[0][0],
85+
rtol=1e-5,
86+
atol=1e-5,
87+
)
88+
)
89+
90+
self._assert_eager_lowered_same_result(
91+
eager_model, example_inputs, validate_emformer_result
92+
)
6893

6994
def test_vit_export_to_executorch(self):
7095
eager_model, example_inputs = MODEL_NAME_TO_MODEL["vit"]()
7196
eager_model = eager_model.eval()
7297

73-
self._assert_eager_lowered_same_result(eager_model, example_inputs)
98+
self._assert_eager_lowered_same_result(
99+
eager_model, example_inputs, self.validate_tensor_allclose
100+
)

0 commit comments

Comments
 (0)