268 changes: 268 additions & 0 deletions backends/arm/test/tester/analyze_output_utils.py
@@ -0,0 +1,268 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import tempfile

import torch
from executorch.backends.arm.test.runner_utils import (
_get_input_quantization_params,
_get_output_node,
_get_output_quantization_params,
)

from executorch.backends.xnnpack.test.tester.tester import Export, Quantize

logger = logging.getLogger(__name__)


def _print_channels(result, reference, channels_close, C, H, W, rtol, atol):
    output_str = ""
    for c in range(C):
        if channels_close[c]:
            continue

        # Use the channel's largest absolute error to pick a shared exponent
        # for the printout, e.g. exp == "-03" for errors on the order of 1e-3.
        max_diff = torch.max(torch.abs(reference[c] - result[c]))
        exp = f"{max_diff:2e}"[-3:]
        output_str += f"channel {c} (e{exp})\n"

for y in range(H):
res = "["
for x in range(W):
if torch.allclose(reference[c, y, x], result[c, y, x], rtol, atol):
res += " . "
else:
diff = (reference[c, y, x] - result[c, y, x]) / 10 ** (int(exp))
res += f"{diff: .2f} "

# Break early for large widths
if x == 16:
res += "..."
break

res += "]\n"
output_str += res

return output_str


def _print_elements(result, reference, C, H, W, rtol, atol):
output_str = ""
for y in range(H):
res = "["
for x in range(W):
result_channels = result[:, y, x]
reference_channels = reference[:, y, x]

            # Count how many channels differ at this (y, x) position
            n_errors = 0
            for a, b in zip(result_channels, reference_channels):
                if not torch.allclose(a, b, rtol, atol):
                    n_errors += 1

if n_errors == 0:
res += ". "
else:
res += f"{n_errors} "

# Break early for large widths
if x == 16:
res += "..."
break

res += "]\n"
output_str += res

return output_str


# Note: the (tester, reference, result) parameter order matches the error
# callback protocol used by ArmTester._compare_outputs, and mirrors
# dump_error_output below.
def print_error_diffs(
    tester,
    reference: torch.Tensor | tuple,
    result: torch.Tensor | tuple,
    quantization_scale=None,
    atol=1e-03,
    rtol=1e-03,
    qtol=0,
):
"""
Prints the error difference between a result tensor and a reference tensor in NCHW format.
Certain formatting rules are applied to clarify errors:

    - Batches are only expanded if they contain errors.
      -> Shows whether errors are related to batch handling
    - If errors appear in all channels, only the number of errors in each HW element is printed.
      -> Shows whether errors are related to HW handling
    - If at least one channel is free from errors, or if C==1, errors are printed channel by channel.
      -> Shows whether errors are related to channel handling, or isolates single errors such as rounding/quantization errors

Example output of shape (3,3,2,2):

############################ ERROR DIFFERENCE #############################
BATCH 0
.
BATCH 1
[. . ]
[. 3 ]
BATCH 2
channel 1 (e-03)
[ 1.85 . ]
[ . 9.32 ]

MEAN MEDIAN MAX MIN (error as % of reference output range)
60.02% 55.73% 100.17% 19.91%
    ###########################################################################
    """

if isinstance(reference, tuple):
reference = reference[0]
if isinstance(result, tuple):
result = result[0]

    if result.shape != reference.shape:
        raise ValueError("Result and reference tensors must have the same shape")
    shape = result.shape

match len(shape):
case 4:
N, C, H, W = (shape[0], shape[1], shape[2], shape[3])
case 3:
N, C, H, W = (1, shape[0], shape[1], shape[2])
case 2:
N, C, H, W = (1, 1, shape[0], shape[1])
case 1:
N, C, H, W = (1, 1, 1, shape[0])
case _:
raise ValueError("Invalid tensor rank")

if quantization_scale is not None:
atol += quantization_scale * qtol

# Reshape tensors to 4D NCHW format
result = torch.reshape(result, (N, C, H, W))
reference = torch.reshape(reference, (N, C, H, W))

output_str = ""
for n in range(N):
output_str += f"BATCH {n}\n"
result_batch = result[n, :, :, :]
reference_batch = reference[n, :, :, :]
is_close = torch.allclose(result_batch, reference_batch, rtol, atol)
if is_close:
output_str += ".\n"
else:
channels_close = [None] * C
for c in range(C):
result_hw = result[n, c, :, :]
reference_hw = reference[n, c, :, :]

channels_close[c] = torch.allclose(result_hw, reference_hw, rtol, atol)

if any(channels_close) or len(channels_close) == 1:
output_str += _print_channels(
result[n, :, :, :],
reference[n, :, :, :],
channels_close,
C,
H,
W,
rtol,
atol,
)
else:
output_str += _print_elements(
result[n, :, :, :], reference[n, :, :, :], C, H, W, rtol, atol
)

reference_range = torch.max(reference) - torch.min(reference)
diff = torch.abs(reference - result).flatten()
diff = diff[diff.nonzero()]
    if len(diff) > 0:
diff_percent = diff / reference_range
output_str += "\nMEAN MEDIAN MAX MIN (error as % of reference output range)\n"
output_str += f"{torch.mean(diff_percent):<8.2%} {torch.median(diff_percent):<8.2%} {torch.max(diff_percent):<8.2%} {torch.min(diff_percent):<8.2%}\n"

# Over-engineer separators to match output width
lines = output_str.split("\n")
line_length = [len(line) for line in lines]
longest_line = max(line_length)
title = "# ERROR DIFFERENCE #"
separator_length = max(longest_line, len(title))

pre_title_length = max(0, ((separator_length - len(title)) // 2))
post_title_length = max(0, ((separator_length - len(title) + 1) // 2))
start_separator = (
"\n" + "#" * pre_title_length + title + "#" * post_title_length + "\n"
)
output_str = start_separator + output_str
end_separator = "#" * separator_length + "\n"
output_str += end_separator

logger.error(output_str)


def dump_error_output(
tester,
reference_output,
stage_output,
quantization_scale=None,
atol=1e-03,
rtol=1e-03,
qtol=0,
):
"""
Prints Quantization info and error tolerances, and saves the differing tensors to disc.
"""
    # Print extra TOSA debug info when a comparison has failed
banner = "=" * 40 + "TOSA debug info" + "=" * 40
logger.error(banner)
path_to_tosa_files = tester.runner_util.intermediate_path

if path_to_tosa_files is None:
path_to_tosa_files = tempfile.mkdtemp(prefix="executorch_result_dump_")

export_stage = tester.stages.get(tester.stage_name(Export), None)
quantize_stage = tester.stages.get(tester.stage_name(Quantize), None)
if export_stage is not None and quantize_stage is not None:
output_node = _get_output_node(export_stage.artifact)
qp_input = _get_input_quantization_params(export_stage.artifact)
qp_output = _get_output_quantization_params(export_stage.artifact, output_node)
logger.error(f"Input QuantArgs: {qp_input}")
logger.error(f"Output QuantArgs: {qp_output}")

logger.error(f"{path_to_tosa_files=}")

torch.save(
stage_output,
os.path.join(path_to_tosa_files, "torch_tosa_output.pt"),
)
torch.save(
reference_output,
os.path.join(path_to_tosa_files, "torch_ref_output.pt"),
)
logger.error(f"{atol=}, {rtol=}, {qtol=}")


if __name__ == "__main__":
import sys

logging.basicConfig(stream=sys.stdout, level=logging.INFO)

""" This is expected to produce the example output of print_diff"""
torch.manual_seed(0)
a = torch.rand(3, 3, 2, 2) * 0.01
b = a.clone().detach()
logger.info(b)

# Errors in all channels in element (1,1)
a[1, :, 1, 1] = 0
# Errors in (0,0) and (1,1) in channel 1
a[2, 1, 1, 1] = 0
a[2, 1, 0, 0] = 0

    # print_error_diffs does not use its tester argument, so None suffices here;
    # b is the unmodified reference and a is the result containing the errors.
    print_error_diffs(None, b, a)
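A note on using the dump: the tensors written by dump_error_output can be reloaded later for offline inspection. A minimal sketch, assuming a failed run already produced a dump; the directory name below is hypothetical, so substitute the path_to_tosa_files value printed in the failure log:

# Reload the tensors saved by dump_error_output and inspect the worst error.
import os

import torch

dump_dir = "/tmp/executorch_result_dump_abc123"  # hypothetical path from the log
result = torch.load(os.path.join(dump_dir, "torch_tosa_output.pt"))
reference = torch.load(os.path.join(dump_dir, "torch_ref_output.pt"))

# Largest absolute deviation between the backend result and the reference
print(torch.max(torch.abs(reference - result)))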
61 changes: 24 additions & 37 deletions backends/arm/test/tester/arm_tester.py
@@ -4,7 +4,6 @@
 # LICENSE file in the root directory of this source tree.

 import logging
-import tempfile

 from collections import Counter
 from pprint import pformat
@@ -25,12 +24,10 @@
 )
 from executorch.backends.arm.test.common import get_target_board

-from executorch.backends.arm.test.runner_utils import (
-    _get_input_quantization_params,
-    _get_output_node,
-    _get_output_quantization_params,
-    dbg_tosa_fb_to_json,
-    RunnerUtil,
-)
+from executorch.backends.arm.test.runner_utils import dbg_tosa_fb_to_json, RunnerUtil
+from executorch.backends.arm.test.tester.analyze_output_utils import (
+    dump_error_output,
+    print_error_diffs,
+)
 from executorch.backends.arm.tosa_mapping import extract_tensor_meta

@@ -278,6 +275,7 @@ def run_method_and_compare_outputs(
         atol=1e-03,
         rtol=1e-03,
         qtol=0,
+        error_callbacks=None,
     ):
         """
         Compares the run_artifact output of 'stage' with the output of a reference stage.
@@ -366,7 +364,13 @@ def run_method_and_compare_outputs(
             test_output = self.transpose_data_format(test_output, "NCHW")

             self._compare_outputs(
-                reference_output, test_output, quantization_scale, atol, rtol, qtol
+                reference_output,
+                test_output,
+                quantization_scale,
+                atol,
+                rtol,
+                qtol,
+                error_callbacks,
             )

         return self
@@ -515,42 +519,25 @@ def _compare_outputs(
         atol=1e-03,
         rtol=1e-03,
         qtol=0,
+        error_callbacks=None,
     ):
         try:
             super()._compare_outputs(
                 reference_output, stage_output, quantization_scale, atol, rtol, qtol
             )
         except AssertionError as e:
-            # Capture assertion error and print more info
-            banner = "=" * 40 + "TOSA debug info" + "=" * 40
-            logger.error(banner)
-            path_to_tosa_files = self.runner_util.intermediate_path
-            if path_to_tosa_files is None:
-                path_to_tosa_files = tempfile.mkdtemp(prefix="executorch_result_dump_")
-
-            export_stage = self.stages.get(self.stage_name(tester.Export), None)
-            quantize_stage = self.stages.get(self.stage_name(tester.Quantize), None)
-            if export_stage is not None and quantize_stage is not None:
-                output_node = _get_output_node(export_stage.artifact)
-                qp_input = _get_input_quantization_params(export_stage.artifact)
-                qp_output = _get_output_quantization_params(
-                    export_stage.artifact, output_node
-                )
-                logger.error(f"Input QuantArgs: {qp_input}")
-                logger.error(f"Output QuantArgs: {qp_output}")
-
-            logger.error(f"{path_to_tosa_files=}")
-            import os
-
-            torch.save(
-                stage_output,
-                os.path.join(path_to_tosa_files, "torch_tosa_output.pt"),
-            )
-            torch.save(
-                reference_output,
-                os.path.join(path_to_tosa_files, "torch_ref_output.pt"),
-            )
-            logger.error(f"{atol=}, {rtol=}, {qtol=}")
+            if error_callbacks is None:
+                error_callbacks = [print_error_diffs, dump_error_output]
+            for callback in error_callbacks:
+                # Forward the comparison settings in use to each callback
+                callback(
+                    self,
+                    reference_output,
+                    stage_output,
+                    quantization_scale=quantization_scale,
+                    atol=atol,
+                    rtol=rtol,
+                    qtol=qtol,
+                )
             raise e
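With this change, the failure handling is pluggable: any callable matching the (tester, reference, result, quantization_scale, atol, rtol, qtol) callback signature can be passed via error_callbacks to run_method_and_compare_outputs. A minimal sketch of a custom callback, assuming an ArmTester instance named tester and plain tensor (not tuple) outputs; the names here are illustrative, not part of this PR:

import logging

import torch

# Hypothetical callback: log only the largest absolute error.
def log_max_error(tester, reference, result, quantization_scale=None,
                  atol=1e-03, rtol=1e-03, qtol=0):
    max_err = torch.max(torch.abs(reference - result)).item()
    logging.getLogger(__name__).error(f"max abs error: {max_err:.2e}")

tester.run_method_and_compare_outputs(error_callbacks=[log_max_error])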

