1+ import math
12import random
23from collections import Counter , OrderedDict
34from typing import Any , Callable , Dict , List , Optional , Tuple
1718 ToExecutorch ,
1819)
1920from executorch .exir .dim_order_utils import get_memory_format
21+ from torch .ao .ns .fx .utils import compute_sqnr
2022
2123from torch .export import ExportedProgram
2224from torch .testing import FileCheck
@@ -302,13 +304,13 @@ def run_method_and_compare_outputs(
302304 atol = 1e-03 ,
303305 rtol = 1e-03 ,
304306 qtol = 0 ,
307+ snr : float | None = None ,
305308 ):
306309 number_of_runs = 1 if inputs is not None else num_runs
307310 reference_stage = self .stages [StageType .EXPORT ]
308311
309312 stage = stage or self .cur
310313
311- print (f"Comparing Stage { stage } with Stage { reference_stage } " )
312314 for run_iteration in range (number_of_runs ):
313315 inputs_to_run = inputs if inputs else next (self .generate_random_inputs ())
314316 input_shapes = [
@@ -328,13 +330,21 @@ def run_method_and_compare_outputs(
328330 # Output from running artifact at stage
329331 stage_output = self .stages [stage ].run_artifact (inputs_to_run )
330332 self ._compare_outputs (
331- reference_output , stage_output , quantization_scale , atol , rtol , qtol
333+ reference_output ,
334+ stage_output ,
335+ quantization_scale ,
336+ atol ,
337+ rtol ,
338+ qtol ,
339+ snr ,
332340 )
333341
334342 return self
335343
336344 @staticmethod
337- def _assert_outputs_equal (model_output , ref_output , atol = 1e-03 , rtol = 1e-03 ):
345+ def _assert_outputs_equal (
346+ model_output , ref_output , atol = 1e-03 , rtol = 1e-03 , snr : float | None = None
347+ ):
338348 """
339349 Helper testing function that asserts that the model output and the reference output
340350 are equal with some tolerance. Due to numerical differences between eager mode and
@@ -359,15 +369,22 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
359369 f"\t Mismatched count: { (model != ref ).sum ().item ()} / { model .numel ()} \n "
360370 )
361371 else :
362- assert torch .allclose (
363- model ,
364- ref ,
365- atol = atol ,
366- rtol = rtol ,
367- equal_nan = True ,
372+ computed_snr = compute_sqnr (model .to (torch .float ), ref .to (torch .float ))
373+ snr = snr or float ("-inf" )
374+
375+ assert (
376+ torch .allclose (
377+ model ,
378+ ref ,
379+ atol = atol ,
380+ rtol = rtol ,
381+ equal_nan = True ,
382+ )
383+ and computed_snr >= snr
384+ or math .isnan (computed_snr )
368385 ), (
369386 f"Output { i } does not match reference output.\n "
370- f"\t Given atol: { atol } , rtol: { rtol } .\n "
387+ f"\t Given atol: { atol } , rtol: { rtol } , snr: { snr } .\n "
371388 f"\t Output tensor shape: { model .shape } , dtype: { model .dtype } \n "
372389 f"\t Difference: max: { torch .max (model - ref )} , abs: { torch .max (torch .abs (model - ref ))} , mean abs error: { torch .mean (torch .abs (model - ref ).to (torch .double ))} .\n "
373390 f"\t -- Model vs. Reference --\n "
@@ -376,6 +393,7 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
376393 f"\t Mean: { model .to (torch .double ).mean ()} , { ref .to (torch .double ).mean ()} \n "
377394 f"\t Max: { model .max ()} , { ref .max ()} \n "
378395 f"\t Min: { model .min ()} , { ref .min ()} \n "
396+ f"\t SNR: { computed_snr } \n "
379397 )
380398
381399 @staticmethod
@@ -386,6 +404,7 @@ def _compare_outputs(
386404 atol = 1e-03 ,
387405 rtol = 1e-03 ,
388406 qtol = 0 ,
407+ snr : float | None = None ,
389408 ):
390409 """
391410 Compares the original of the original nn module with the output of the generated artifact.
@@ -408,6 +427,7 @@ def _compare_outputs(
408427 reference_output ,
409428 atol = atol ,
410429 rtol = rtol ,
430+ snr = snr ,
411431 )
412432
413433 @staticmethod
0 commit comments