Skip to content

Commit 0066e87

Browse files
committed
Move callbacks into _compare_outputs
This is to make run_method_and_compare_outputs less complex since lintrunner was complaining. Additionally moves out previous info dumps in compare_output into a new callback function to handle all error handling in the same way. Change-Id: If7c7cfd515d4870c018a34785dd80be15d4fbcef
1 parent 12771e3 commit 0066e87

File tree

2 files changed

+79
-53
lines changed

2 files changed

+79
-53
lines changed

backends/arm/test/tester/analyze_output_utils.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,16 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import logging
7+
import tempfile
78

89
import torch
10+
from executorch.backends.arm.test.runner_utils import (
11+
_get_input_quantization_params,
12+
_get_output_node,
13+
_get_output_quantization_params,
14+
)
15+
16+
from executorch.backends.xnnpack.test.tester.tester import Export, Quantize
917

1018
logger = logging.getLogger(__name__)
1119

@@ -71,6 +79,7 @@ def _print_elements(result, reference, C, H, W, rtol, atol):
7179

7280

7381
def print_error_diffs(
82+
tester,
7483
result: torch.Tensor | tuple,
7584
reference: torch.Tensor | tuple,
7685
quantization_scale=None,
@@ -193,7 +202,50 @@ def print_error_diffs(
193202
end_separator = "#" * separator_length + "\n"
194203
output_str += end_separator
195204

196-
logger.info(output_str)
205+
logger.error(output_str)
206+
207+
208+
def dump_error_output(
209+
tester,
210+
reference_output,
211+
stage_output,
212+
quantization_scale=None,
213+
atol=1e-03,
214+
rtol=1e-03,
215+
qtol=0,
216+
):
217+
"""
218+
Prints Quantization info and error tolerances, and saves the differing tensors to disc.
219+
"""
220+
# Capture assertion error and print more info
221+
banner = "=" * 40 + "TOSA debug info" + "=" * 40
222+
logger.error(banner)
223+
path_to_tosa_files = tester.runner_util.intermediate_path
224+
225+
if path_to_tosa_files is None:
226+
path_to_tosa_files = tempfile.mkdtemp(prefix="executorch_result_dump_")
227+
228+
export_stage = tester.stages.get(tester.stage_name(Export), None)
229+
quantize_stage = tester.stages.get(tester.stage_name(Quantize), None)
230+
if export_stage is not None and quantize_stage is not None:
231+
output_node = _get_output_node(export_stage.artifact)
232+
qp_input = _get_input_quantization_params(export_stage.artifact)
233+
qp_output = _get_output_quantization_params(export_stage.artifact, output_node)
234+
logger.error(f"Input QuantArgs: {qp_input}")
235+
logger.error(f"Output QuantArgs: {qp_output}")
236+
237+
logger.error(f"{path_to_tosa_files=}")
238+
import os
239+
240+
torch.save(
241+
stage_output,
242+
os.path.join(path_to_tosa_files, "torch_tosa_output.pt"),
243+
)
244+
torch.save(
245+
reference_output,
246+
os.path.join(path_to_tosa_files, "torch_ref_output.pt"),
247+
)
248+
logger.error(f"{atol=}, {rtol=}, {qtol=}")
197249

198250

199251
if __name__ == "__main__":

backends/arm/test/tester/arm_tester.py

Lines changed: 26 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import logging
7-
import tempfile
87

98
from collections import Counter
109
from pprint import pformat
@@ -25,14 +24,11 @@
2524
)
2625
from executorch.backends.arm.test.common import get_target_board
2726

28-
from executorch.backends.arm.test.runner_utils import (
29-
_get_input_quantization_params,
30-
_get_output_node,
31-
_get_output_quantization_params,
32-
dbg_tosa_fb_to_json,
33-
RunnerUtil,
27+
from executorch.backends.arm.test.runner_utils import dbg_tosa_fb_to_json, RunnerUtil
28+
from executorch.backends.arm.test.tester.analyze_output_utils import (
29+
dump_error_output,
30+
print_error_diffs,
3431
)
35-
from executorch.backends.arm.test.tester.analyze_output_utils import print_error_diffs
3632
from executorch.backends.arm.tosa_mapping import extract_tensor_meta
3733

3834
from executorch.backends.xnnpack.test.tester import Tester
@@ -279,7 +275,7 @@ def run_method_and_compare_outputs(
279275
atol=1e-03,
280276
rtol=1e-03,
281277
qtol=0,
282-
callback=print_error_diffs,
278+
error_callbacks=None,
283279
):
284280
"""
285281
Compares the run_artifact output of 'stage' with the output of a reference stage.
@@ -367,20 +363,15 @@ def run_method_and_compare_outputs(
367363
):
368364
test_output = self.transpose_data_format(test_output, "NCHW")
369365

370-
try:
371-
self._compare_outputs(
372-
reference_output, test_output, quantization_scale, atol, rtol, qtol
373-
)
374-
except AssertionError as e:
375-
callback(
376-
reference_output,
377-
test_output,
378-
quantization_scale,
379-
atol,
380-
rtol,
381-
qtol,
382-
)
383-
raise e
366+
self._compare_outputs(
367+
reference_output,
368+
test_output,
369+
quantization_scale,
370+
atol,
371+
rtol,
372+
qtol,
373+
error_callbacks,
374+
)
384375

385376
return self
386377

@@ -528,42 +519,25 @@ def _compare_outputs(
528519
atol=1e-03,
529520
rtol=1e-03,
530521
qtol=0,
522+
error_callbacks=None,
531523
):
532524
try:
533525
super()._compare_outputs(
534526
reference_output, stage_output, quantization_scale, atol, rtol, qtol
535527
)
536528
except AssertionError as e:
537-
# Capture assertion error and print more info
538-
banner = "=" * 40 + "TOSA debug info" + "=" * 40
539-
logger.error(banner)
540-
path_to_tosa_files = self.runner_util.intermediate_path
541-
if path_to_tosa_files is None:
542-
path_to_tosa_files = tempfile.mkdtemp(prefix="executorch_result_dump_")
543-
544-
export_stage = self.stages.get(self.stage_name(tester.Export), None)
545-
quantize_stage = self.stages.get(self.stage_name(tester.Quantize), None)
546-
if export_stage is not None and quantize_stage is not None:
547-
output_node = _get_output_node(export_stage.artifact)
548-
qp_input = _get_input_quantization_params(export_stage.artifact)
549-
qp_output = _get_output_quantization_params(
550-
export_stage.artifact, output_node
529+
if error_callbacks is None:
530+
error_callbacks = [print_error_diffs, dump_error_output]
531+
for callback in error_callbacks:
532+
callback(
533+
self,
534+
reference_output,
535+
stage_output,
536+
quantization_scale=None,
537+
atol=1e-03,
538+
rtol=1e-03,
539+
qtol=0,
551540
)
552-
logger.error(f"Input QuantArgs: {qp_input}")
553-
logger.error(f"Output QuantArgs: {qp_output}")
554-
555-
logger.error(f"{path_to_tosa_files=}")
556-
import os
557-
558-
torch.save(
559-
stage_output,
560-
os.path.join(path_to_tosa_files, "torch_tosa_output.pt"),
561-
)
562-
torch.save(
563-
reference_output,
564-
os.path.join(path_to_tosa_files, "torch_ref_output.pt"),
565-
)
566-
logger.error(f"{atol=}, {rtol=}, {qtol=}")
567541
raise e
568542

569543

0 commit comments

Comments
 (0)