Skip to content

Commit c1915f1

Browse files
committed
Add util function for pretty printing of output diffs
Change-Id: I93416fe3f2c175050b84a16004bf32849bafd1fe
1 parent de74961 commit c1915f1

File tree

2 files changed

+232
-3
lines changed

2 files changed

+232
-3
lines changed
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import logging
7+
8+
import torch
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
def _print_channels(result, reference, channels_close, C, H, W, rtol, atol):
14+
15+
output_str = ""
16+
for c in range(C):
17+
if channels_close[c]:
18+
continue
19+
20+
max_diff = torch.max(torch.abs(reference - result))
21+
exp = f"{max_diff:2e}"[-3:]
22+
output_str += f"channel {c} (e{exp})\n"
23+
24+
for y in range(H):
25+
res = "["
26+
for x in range(W):
27+
if torch.allclose(reference[c, y, x], result[c, y, x], rtol, atol):
28+
res += " . "
29+
else:
30+
diff = (reference[c, y, x] - result[c, y, x]) / 10 ** (int(exp))
31+
res += f"{diff: .2f} "
32+
33+
# Break early for large widths
34+
if x == 16:
35+
res += "..."
36+
break
37+
38+
res += "]\n"
39+
output_str += res
40+
41+
return output_str
42+
43+
44+
def _print_elements(result, reference, C, H, W, rtol, atol):
45+
output_str = ""
46+
for y in range(H):
47+
res = "["
48+
for x in range(W):
49+
result_channels = result[:, y, x]
50+
reference_channels = reference[:, y, x]
51+
52+
n_errors = 0
53+
for a, b in zip(result_channels, reference_channels):
54+
if not torch.allclose(a, b, rtol, atol):
55+
n_errors = n_errors + 1
56+
57+
if n_errors == 0:
58+
res += ". "
59+
else:
60+
res += f"{n_errors} "
61+
62+
# Break early for large widths
63+
if x == 16:
64+
res += "..."
65+
break
66+
67+
res += "]\n"
68+
output_str += res
69+
70+
return output_str
71+
72+
73+
def print_error_diffs(
74+
result: torch.Tensor | tuple,
75+
reference: torch.Tensor | tuple,
76+
quantization_scale=None,
77+
atol=1e-03,
78+
rtol=1e-03,
79+
qtol=0,
80+
):
81+
"""
82+
Prints the error difference between a result tensor and a reference tensor in NCHW format.
83+
Certain formatting rules are applied to clarify errors:
84+
85+
- Batches are only expanded if they contain errors.
86+
-> Shows if errors are related to batch handling
87+
- If errors appear in all channels, only the number of errors in each HW element are printed.
88+
-> Shows if errors are related to HW handling
89+
- If at least one channel is free from errors, or if C==1, errors are printed channel by channel
90+
-> Shows if errors are related to channel handling or single errors such as rounding/quantization errors
91+
92+
Example output of shape (3,3,2,2):
93+
94+
############################ ERROR DIFFERENCE #############################
95+
BATCH 0
96+
.
97+
BATCH 1
98+
[. . ]
99+
[. 3 ]
100+
BATCH 2
101+
channel 1 (e-03)
102+
[ 1.85 . ]
103+
[ . 9.32 ]
104+
105+
MEAN MEDIAN MAX MIN (error as % of reference output range)
106+
60.02% 55.73% 100.17% 19.91%
107+
###########################################################################
108+
109+
110+
"""
111+
112+
if isinstance(reference, tuple):
113+
reference = reference[0]
114+
if isinstance(result, tuple):
115+
result = result[0]
116+
117+
if not result.shape == reference.shape:
118+
raise ValueError("Output needs to be of same shape")
119+
shape = result.shape
120+
121+
match len(shape):
122+
case 4:
123+
N, C, H, W = (shape[0], shape[1], shape[2], shape[3])
124+
case 3:
125+
N, C, H, W = (1, shape[0], shape[1], shape[2])
126+
case 2:
127+
N, C, H, W = (1, 1, shape[0], shape[1])
128+
case 1:
129+
N, C, H, W = (1, 1, 1, shape[0])
130+
case _:
131+
raise ValueError("Invalid tensor rank")
132+
133+
if quantization_scale is not None:
134+
atol += quantization_scale * qtol
135+
136+
# Reshape tensors to 4D NCHW format
137+
result = torch.reshape(result, (N, C, H, W))
138+
reference = torch.reshape(reference, (N, C, H, W))
139+
140+
output_str = ""
141+
for n in range(N):
142+
output_str += f"BATCH {n}\n"
143+
result_batch = result[n, :, :, :]
144+
reference_batch = reference[n, :, :, :]
145+
is_close = torch.allclose(result_batch, reference_batch, rtol, atol)
146+
if is_close:
147+
output_str += ".\n"
148+
else:
149+
channels_close = [None] * C
150+
for c in range(C):
151+
result_hw = result[n, c, :, :]
152+
reference_hw = reference[n, c, :, :]
153+
154+
channels_close[c] = torch.allclose(result_hw, reference_hw, rtol, atol)
155+
156+
if any(channels_close) or len(channels_close) == 1:
157+
output_str += _print_channels(
158+
result[n, :, :, :],
159+
reference[n, :, :, :],
160+
channels_close,
161+
C,
162+
H,
163+
W,
164+
rtol,
165+
atol,
166+
)
167+
else:
168+
output_str += _print_elements(
169+
result[n, :, :, :], reference[n, :, :, :], C, H, W, rtol, atol
170+
)
171+
172+
reference_range = torch.max(reference) - torch.min(reference)
173+
diff = torch.abs(reference - result).flatten()
174+
diff = diff[diff.nonzero()]
175+
if not len(diff) == 0:
176+
diff_percent = diff / reference_range
177+
output_str += "\nMEAN MEDIAN MAX MIN (error as % of reference output range)\n"
178+
output_str += f"{torch.mean(diff_percent):<8.2%} {torch.median(diff_percent):<8.2%} {torch.max(diff_percent):<8.2%} {torch.min(diff_percent):<8.2%}\n"
179+
180+
# Over-engineer separators to match output width
181+
lines = output_str.split("\n")
182+
line_length = [len(line) for line in lines]
183+
longest_line = max(line_length)
184+
title = "# ERROR DIFFERENCE #"
185+
separator_length = max(longest_line, len(title))
186+
187+
pre_title_length = max(0, ((separator_length - len(title)) // 2))
188+
post_title_length = max(0, ((separator_length - len(title) + 1) // 2))
189+
start_separator = (
190+
"\n" + "#" * pre_title_length + title + "#" * post_title_length + "\n"
191+
)
192+
output_str = start_separator + output_str
193+
end_separator = "#" * separator_length + "\n"
194+
output_str += end_separator
195+
196+
logger.info(output_str)
197+
198+
199+
if __name__ == "__main__":
200+
import sys
201+
202+
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
203+
204+
""" This is expected to produce the example output of print_diff"""
205+
torch.manual_seed(0)
206+
a = torch.rand(3, 3, 2, 2) * 0.01
207+
b = a.clone().detach()
208+
logger.info(b)
209+
210+
# Errors in all channels in element (1,1)
211+
a[1, :, 1, 1] = 0
212+
# Errors in (0,0) and (1,1) in channel 1
213+
a[2, 1, 1, 1] = 0
214+
a[2, 1, 0, 0] = 0
215+
216+
print_error_diffs(a, b)

backends/arm/test/tester/arm_tester.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
dbg_tosa_fb_to_json,
3333
RunnerUtil,
3434
)
35+
from executorch.backends.arm.test.tester.analyze_output_utils import print_error_diffs
3536
from executorch.backends.arm.tosa_mapping import extract_tensor_meta
3637

3738
from executorch.backends.xnnpack.test.tester import Tester
@@ -278,6 +279,7 @@ def run_method_and_compare_outputs(
278279
atol=1e-03,
279280
rtol=1e-03,
280281
qtol=0,
282+
callback=print_error_diffs,
281283
):
282284
"""
283285
Compares the run_artifact output of 'stage' with the output of a reference stage.
@@ -365,9 +367,20 @@ def run_method_and_compare_outputs(
365367
):
366368
test_output = self.transpose_data_format(test_output, "NCHW")
367369

368-
self._compare_outputs(
369-
reference_output, test_output, quantization_scale, atol, rtol, qtol
370-
)
370+
try:
371+
self._compare_outputs(
372+
reference_output, test_output, quantization_scale, atol, rtol, qtol
373+
)
374+
except AssertionError as e:
375+
callback(
376+
reference_output,
377+
test_output,
378+
quantization_scale,
379+
atol,
380+
rtol,
381+
qtol,
382+
)
383+
raise e
371384

372385
return self
373386

0 commit comments

Comments
 (0)