44
55import torch
66
7+ from executorch .backends .test .harness .error_statistics import ErrorStatistics
78from executorch .backends .test .harness .stages import (
89 Export ,
910 Partition ,
@@ -182,10 +183,10 @@ def _post(self, stage):
182183 assert stage_type in self .stages
183184 self .stages [stage_type ] = stage
184185
185- def _run_stage (self , stage_instance , inputs = None ):
186+ def _run_stage (self , stage_instance , inputs = None , * args , ** kwargs ):
186187 assert isinstance (stage_instance , Stage )
187188 prev_stage_artifact = self ._pre (stage_instance )
188- stage_instance .run (prev_stage_artifact , inputs = inputs )
189+ stage_instance .run (prev_stage_artifact , inputs = inputs , * args , ** kwargs ) # noqa
189190 self ._post (stage_instance )
190191 return self
191192
@@ -212,11 +213,14 @@ def to_edge(self, to_edge_stage: Optional[ToEdge] = None):
212213 return res
213214
214215 def to_edge_transform_and_lower (
215- self , to_edge_and_transform_stage : Optional [ToEdgeTransformAndLower ] = None
216+ self ,
217+ to_edge_and_transform_stage : Optional [ToEdgeTransformAndLower ] = None ,
218+ generate_etrecord : bool = False ,
216219 ):
217220 return self ._run_stage (
218221 to_edge_and_transform_stage
219- or self ._get_default_stage (StageType .TO_EDGE_TRANSFORM_AND_LOWER )
222+ or self ._get_default_stage (StageType .TO_EDGE_TRANSFORM_AND_LOWER ),
223+ generate_etrecord = generate_etrecord ,
220224 )
221225
222226 def run_passes (self , run_passes_stage : Optional [RunPasses ] = None ):
@@ -302,20 +306,15 @@ def run_method_and_compare_outputs(
302306 atol = 1e-03 ,
303307 rtol = 1e-03 ,
304308 qtol = 0 ,
309+ statistics_callback : Callable [[ErrorStatistics ], None ] | None = None ,
305310 ):
306311 number_of_runs = 1 if inputs is not None else num_runs
307312 reference_stage = self .stages [StageType .EXPORT ]
308313
309314 stage = stage or self .cur
310315
311- print (f"Comparing Stage { stage } with Stage { reference_stage } " )
312- for run_iteration in range (number_of_runs ):
316+ for _ in range (number_of_runs ):
313317 inputs_to_run = inputs if inputs else next (self .generate_random_inputs ())
314- input_shapes = [
315- generated_input .shape if hasattr (generated_input , "shape" ) else None
316- for generated_input in inputs_to_run
317- ]
318- print (f"Run { run_iteration } with input shapes: { input_shapes } " )
319318
320319 # Reference output (and quantization scale)
321320 (
@@ -328,13 +327,25 @@ def run_method_and_compare_outputs(
328327 # Output from running artifact at stage
329328 stage_output = self .stages [stage ].run_artifact (inputs_to_run )
330329 self ._compare_outputs (
331- reference_output , stage_output , quantization_scale , atol , rtol , qtol
330+ reference_output ,
331+ stage_output ,
332+ quantization_scale ,
333+ atol ,
334+ rtol ,
335+ qtol ,
336+ statistics_callback ,
332337 )
333338
334339 return self
335340
336341 @staticmethod
337- def _assert_outputs_equal (model_output , ref_output , atol = 1e-03 , rtol = 1e-03 ):
342+ def _assert_outputs_equal (
343+ model_output ,
344+ ref_output ,
345+ atol = 1e-03 ,
346+ rtol = 1e-03 ,
347+ statistics_callback : Callable [[ErrorStatistics ], None ] | None = None ,
348+ ):
338349 """
339350 Helper testing function that asserts that the model output and the reference output
340351 are equal with some tolerance. Due to numerical differences between eager mode and
@@ -349,6 +360,11 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
349360 for i in range (len (model_output )):
350361 model = model_output [i ]
351362 ref = ref_output [i ]
363+
364+ error_stats = ErrorStatistics .from_tensors (model , ref )
365+ if statistics_callback is not None :
366+ statistics_callback (error_stats )
367+
352368 assert (
353369 ref .shape == model .shape
354370 ), f"Output { i } shape { model .shape } does not match reference output shape { ref .shape } "
@@ -386,6 +402,7 @@ def _compare_outputs(
386402 atol = 1e-03 ,
387403 rtol = 1e-03 ,
388404 qtol = 0 ,
405+ statistics_callback : Callable [[ErrorStatistics ], None ] | None = None ,
389406 ):
390407 """
391408 Compares the original of the original nn module with the output of the generated artifact.
@@ -408,6 +425,7 @@ def _compare_outputs(
408425 reference_output ,
409426 atol = atol ,
410427 rtol = rtol ,
428+ statistics_callback = statistics_callback ,
411429 )
412430
413431 @staticmethod
0 commit comments