Skip to content

Commit d10b60d

Browse files
committed
Arm backend: Fix mypy warnings in test/tester
Signed-off-by: [email protected] Change-Id: Ia282b3edab43ebea758eba6898e536652837a732
1 parent 964515c commit d10b60d

File tree

3 files changed

+283
-168
lines changed

3 files changed

+283
-168
lines changed

backends/arm/test/tester/analyze_output_utils.py

Lines changed: 74 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import logging
77
import tempfile
8+
from typing import Any, cast, Sequence
89

910
import torch
1011
from executorch.backends.arm.test.runner_utils import (
@@ -17,9 +18,30 @@
1718
logger = logging.getLogger(__name__)
1819

1920

20-
def _print_channels(result, reference, channels_close, C, H, W, rtol, atol):
21+
TensorLike = torch.Tensor | tuple[torch.Tensor, ...]
22+
23+
24+
def _ensure_tensor(value: TensorLike) -> torch.Tensor:
25+
if isinstance(value, torch.Tensor):
26+
return value
27+
if value and isinstance(value[0], torch.Tensor):
28+
return value[0]
29+
raise TypeError("Expected a Tensor or a non-empty tuple of Tensors")
30+
31+
32+
def _print_channels(
33+
result: torch.Tensor,
34+
reference: torch.Tensor,
35+
channels_close: Sequence[bool],
36+
C: int,
37+
H: int,
38+
W: int,
39+
rtol: float,
40+
atol: float,
41+
) -> str:
2142

2243
output_str = ""
44+
exp = "000"
2345
booldata = False
2446
if reference.dtype == torch.bool or result.dtype == torch.bool:
2547
booldata = True
@@ -62,7 +84,15 @@ def _print_channels(result, reference, channels_close, C, H, W, rtol, atol):
6284
return output_str
6385

6486

65-
def _print_elements(result, reference, C, H, W, rtol, atol):
87+
def _print_elements(
88+
result: torch.Tensor,
89+
reference: torch.Tensor,
90+
C: int,
91+
H: int,
92+
W: int,
93+
rtol: float,
94+
atol: float,
95+
) -> str:
6696
output_str = ""
6797
for y in range(H):
6898
res = "["
@@ -92,14 +122,16 @@ def _print_elements(result, reference, C, H, W, rtol, atol):
92122

93123

94124
def print_error_diffs(
95-
tester,
96-
result: torch.Tensor | tuple,
97-
reference: torch.Tensor | tuple,
98-
quantization_scale=None,
99-
atol=1e-03,
100-
rtol=1e-03,
101-
qtol=0,
102-
):
125+
tester_or_result: Any,
126+
result_or_reference: TensorLike,
127+
reference: TensorLike | None = None,
128+
# Force remaining args to be keyword-only to keep the two positional call patterns unambiguous.
129+
*,
130+
quantization_scale: float | None = None,
131+
atol: float = 1e-03,
132+
rtol: float = 1e-03,
133+
qtol: float = 0,
134+
) -> None:
103135
"""
104136
Prints the error difference between a result tensor and a reference tensor in NCHW format.
105137
Certain formatting rules are applied to clarify errors:
@@ -130,15 +162,16 @@ def print_error_diffs(
130162
131163
132164
"""
133-
134-
if isinstance(reference, tuple):
135-
reference = reference[0]
136-
if isinstance(result, tuple):
137-
result = result[0]
138-
139-
if not result.shape == reference.shape:
165+
if reference is None:
166+
result = _ensure_tensor(cast(TensorLike, tester_or_result))
167+
reference_tensor = _ensure_tensor(result_or_reference)
168+
else:
169+
result = _ensure_tensor(result_or_reference)
170+
reference_tensor = _ensure_tensor(reference)
171+
172+
if result.shape != reference_tensor.shape:
140173
raise ValueError(
141-
f"Output needs to be of same shape: {result.shape} != {reference.shape}"
174+
f"Output needs to be of same shape: {result.shape} != {reference_tensor.shape}"
142175
)
143176
shape = result.shape
144177

@@ -161,29 +194,29 @@ def print_error_diffs(
161194

162195
# Reshape tensors to 4D NCHW format
163196
result = torch.reshape(result, (N, C, H, W))
164-
reference = torch.reshape(reference, (N, C, H, W))
197+
reference_tensor = torch.reshape(reference_tensor, (N, C, H, W))
165198

166199
output_str = ""
167200
for n in range(N):
168201
output_str += f"BATCH {n}\n"
169202
result_batch = result[n, :, :, :]
170-
reference_batch = reference[n, :, :, :]
203+
reference_batch = reference_tensor[n, :, :, :]
171204

172205
is_close = torch.allclose(result_batch, reference_batch, rtol, atol)
173206
if is_close:
174207
output_str += ".\n"
175208
else:
176-
channels_close = [None] * C
209+
channels_close: list[bool] = [False] * C
177210
for c in range(C):
178211
result_hw = result[n, c, :, :]
179-
reference_hw = reference[n, c, :, :]
212+
reference_hw = reference_tensor[n, c, :, :]
180213

181214
channels_close[c] = torch.allclose(result_hw, reference_hw, rtol, atol)
182215

183216
if any(channels_close) or len(channels_close) == 1:
184217
output_str += _print_channels(
185218
result[n, :, :, :],
186-
reference[n, :, :, :],
219+
reference_tensor[n, :, :, :],
187220
channels_close,
188221
C,
189222
H,
@@ -193,17 +226,23 @@ def print_error_diffs(
193226
)
194227
else:
195228
output_str += _print_elements(
196-
result[n, :, :, :], reference[n, :, :, :], C, H, W, rtol, atol
229+
result[n, :, :, :],
230+
reference_tensor[n, :, :, :],
231+
C,
232+
H,
233+
W,
234+
rtol,
235+
atol,
197236
)
198237
if reference_batch.dtype == torch.bool or result_batch.dtype == torch.bool:
199238
mismatches = (reference_batch != result_batch).sum().item()
200239
total = reference_batch.numel()
201240
output_str += f"(BOOLEAN tensor) {mismatches} / {total} elements differ ({mismatches / total:.2%})\n"
202241

203242
# Only compute numeric error metrics if tensor is not boolean
204-
if reference.dtype != torch.bool and result.dtype != torch.bool:
205-
reference_range = torch.max(reference) - torch.min(reference)
206-
diff = torch.abs(reference - result).flatten()
243+
if reference_tensor.dtype != torch.bool and result.dtype != torch.bool:
244+
reference_range = torch.max(reference_tensor) - torch.min(reference_tensor)
245+
diff = torch.abs(reference_tensor - result).flatten()
207246
diff = diff[diff.nonzero()]
208247
if not len(diff) == 0:
209248
diff_percent = diff / reference_range
@@ -230,14 +269,14 @@ def print_error_diffs(
230269

231270

232271
def dump_error_output(
233-
tester,
234-
reference_output,
235-
stage_output,
236-
quantization_scale=None,
237-
atol=1e-03,
238-
rtol=1e-03,
239-
qtol=0,
240-
):
272+
tester: Any,
273+
reference_output: TensorLike,
274+
stage_output: TensorLike,
275+
quantization_scale: float | None = None,
276+
atol: float = 1e-03,
277+
rtol: float = 1e-03,
278+
qtol: float = 0,
279+
) -> None:
241280
"""
242281
Prints Quantization info and error tolerances, and saves the differing tensors to disc.
243282
"""

0 commit comments

Comments
 (0)