55
66import logging
77import tempfile
8+ from typing import Any , cast , Sequence
89
910import torch
1011from executorch .backends .arm .test .runner_utils import (
1718logger = logging .getLogger (__name__ )
1819
1920
20- def _print_channels (result , reference , channels_close , C , H , W , rtol , atol ):
21+ TensorLike = torch .Tensor | tuple [torch .Tensor , ...]
22+
23+
24+ def _ensure_tensor (value : TensorLike ) -> torch .Tensor :
25+ if isinstance (value , torch .Tensor ):
26+ return value
27+ if value and isinstance (value [0 ], torch .Tensor ):
28+ return value [0 ]
29+ raise TypeError ("Expected a Tensor or a non-empty tuple of Tensors" )
30+
31+
32+ def _print_channels (
33+ result : torch .Tensor ,
34+ reference : torch .Tensor ,
35+ channels_close : Sequence [bool ],
36+ C : int ,
37+ H : int ,
38+ W : int ,
39+ rtol : float ,
40+ atol : float ,
41+ ) -> str :
2142
2243 output_str = ""
44+ exp = "000"
2345 booldata = False
2446 if reference .dtype == torch .bool or result .dtype == torch .bool :
2547 booldata = True
@@ -62,7 +84,15 @@ def _print_channels(result, reference, channels_close, C, H, W, rtol, atol):
6284 return output_str
6385
6486
65- def _print_elements (result , reference , C , H , W , rtol , atol ):
87+ def _print_elements (
88+ result : torch .Tensor ,
89+ reference : torch .Tensor ,
90+ C : int ,
91+ H : int ,
92+ W : int ,
93+ rtol : float ,
94+ atol : float ,
95+ ) -> str :
6696 output_str = ""
6797 for y in range (H ):
6898 res = "["
@@ -92,14 +122,16 @@ def _print_elements(result, reference, C, H, W, rtol, atol):
92122
93123
94124def print_error_diffs (
95- tester ,
96- result : torch .Tensor | tuple ,
97- reference : torch .Tensor | tuple ,
98- quantization_scale = None ,
99- atol = 1e-03 ,
100- rtol = 1e-03 ,
101- qtol = 0 ,
102- ):
125+ tester_or_result : Any ,
126+ result_or_reference : TensorLike ,
127+ reference : TensorLike | None = None ,
128+ # Force remaining args to be keyword-only to keep the two positional call patterns unambiguous.
129+ * ,
130+ quantization_scale : float | None = None ,
131+ atol : float = 1e-03 ,
132+ rtol : float = 1e-03 ,
133+ qtol : float = 0 ,
134+ ) -> None :
103135 """
104136 Prints the error difference between a result tensor and a reference tensor in NCHW format.
105137 Certain formatting rules are applied to clarify errors:
@@ -130,15 +162,16 @@ def print_error_diffs(
130162
131163
132164 """
133-
134- if isinstance (reference , tuple ):
135- reference = reference [0 ]
136- if isinstance (result , tuple ):
137- result = result [0 ]
138-
139- if not result .shape == reference .shape :
165+ if reference is None :
166+ result = _ensure_tensor (cast (TensorLike , tester_or_result ))
167+ reference_tensor = _ensure_tensor (result_or_reference )
168+ else :
169+ result = _ensure_tensor (result_or_reference )
170+ reference_tensor = _ensure_tensor (reference )
171+
172+ if result .shape != reference_tensor .shape :
140173 raise ValueError (
141- f"Output needs to be of same shape: { result .shape } != { reference .shape } "
174+ f"Output needs to be of same shape: { result .shape } != { reference_tensor .shape } "
142175 )
143176 shape = result .shape
144177
@@ -161,29 +194,29 @@ def print_error_diffs(
161194
162195 # Reshape tensors to 4D NCHW format
163196 result = torch .reshape (result , (N , C , H , W ))
164- reference = torch .reshape (reference , (N , C , H , W ))
197+ reference_tensor = torch .reshape (reference_tensor , (N , C , H , W ))
165198
166199 output_str = ""
167200 for n in range (N ):
168201 output_str += f"BATCH { n } \n "
169202 result_batch = result [n , :, :, :]
170- reference_batch = reference [n , :, :, :]
203+ reference_batch = reference_tensor [n , :, :, :]
171204
172205 is_close = torch .allclose (result_batch , reference_batch , rtol , atol )
173206 if is_close :
174207 output_str += ".\n "
175208 else :
176- channels_close = [None ] * C
209+ channels_close : list [ bool ] = [False ] * C
177210 for c in range (C ):
178211 result_hw = result [n , c , :, :]
179- reference_hw = reference [n , c , :, :]
212+ reference_hw = reference_tensor [n , c , :, :]
180213
181214 channels_close [c ] = torch .allclose (result_hw , reference_hw , rtol , atol )
182215
183216 if any (channels_close ) or len (channels_close ) == 1 :
184217 output_str += _print_channels (
185218 result [n , :, :, :],
186- reference [n , :, :, :],
219+ reference_tensor [n , :, :, :],
187220 channels_close ,
188221 C ,
189222 H ,
@@ -193,17 +226,23 @@ def print_error_diffs(
193226 )
194227 else :
195228 output_str += _print_elements (
196- result [n , :, :, :], reference [n , :, :, :], C , H , W , rtol , atol
229+ result [n , :, :, :],
230+ reference_tensor [n , :, :, :],
231+ C ,
232+ H ,
233+ W ,
234+ rtol ,
235+ atol ,
197236 )
198237 if reference_batch .dtype == torch .bool or result_batch .dtype == torch .bool :
199238 mismatches = (reference_batch != result_batch ).sum ().item ()
200239 total = reference_batch .numel ()
201240 output_str += f"(BOOLEAN tensor) { mismatches } / { total } elements differ ({ mismatches / total :.2%} )\n "
202241
203242 # Only compute numeric error metrics if tensor is not boolean
204- if reference .dtype != torch .bool and result .dtype != torch .bool :
205- reference_range = torch .max (reference ) - torch .min (reference )
206- diff = torch .abs (reference - result ).flatten ()
243+ if reference_tensor .dtype != torch .bool and result .dtype != torch .bool :
244+ reference_range = torch .max (reference_tensor ) - torch .min (reference_tensor )
245+ diff = torch .abs (reference_tensor - result ).flatten ()
207246 diff = diff [diff .nonzero ()]
208247 if not len (diff ) == 0 :
209248 diff_percent = diff / reference_range
@@ -230,14 +269,14 @@ def print_error_diffs(
230269
231270
232271def dump_error_output (
233- tester ,
234- reference_output ,
235- stage_output ,
236- quantization_scale = None ,
237- atol = 1e-03 ,
238- rtol = 1e-03 ,
239- qtol = 0 ,
240- ):
272+ tester : Any ,
273+ reference_output : TensorLike ,
274+ stage_output : TensorLike ,
275+ quantization_scale : float | None = None ,
276+ atol : float = 1e-03 ,
277+ rtol : float = 1e-03 ,
278+ qtol : float = 0 ,
279+ ) -> None :
241280 """
242281 Prints Quantization info and error tolerances, and saves the differing tensors to disc.
243282 """
0 commit comments