1- import random
21from collections import Counter , OrderedDict
2+ from torch .ao .ns .fx .utils import compute_sqnr
33from typing import Any , Callable , Dict , List , Optional , Tuple
44
5+ import math
6+ import random
57import torch
68
79from executorch .backends .test .harness .stages import (
@@ -302,17 +304,18 @@ def run_method_and_compare_outputs(
302304 atol = 1e-03 ,
303305 rtol = 1e-03 ,
304306 qtol = 0 ,
307+ snr : float | None = None ,
305308 ):
306309 number_of_runs = 1 if inputs is not None else num_runs
307310 reference_stage = self .stages [StageType .EXPORT ]
308311
309312 stage = stage or self .cur
310313
311- print (f"Comparing Stage { stage } with Stage { reference_stage } " )
314+ # print(f"Comparing Stage {stage} with Stage {reference_stage}")
312315 for run_iteration in range (number_of_runs ):
313316 inputs_to_run = inputs if inputs else next (self .generate_random_inputs ())
314317 input_shapes = [generated_input .shape for generated_input in inputs_to_run ]
315- print (f"Run { run_iteration } with input shapes: { input_shapes } " )
318+ # print(f"Run {run_iteration} with input shapes: {input_shapes}")
316319
317320 # Reference output (and quantization scale)
318321 (
@@ -325,13 +328,13 @@ def run_method_and_compare_outputs(
325328 # Output from running artifact at stage
326329 stage_output = self .stages [stage ].run_artifact (inputs_to_run )
327330 self ._compare_outputs (
328- reference_output , stage_output , quantization_scale , atol , rtol , qtol
331+ reference_output , stage_output , quantization_scale , atol , rtol , qtol , snr
329332 )
330333
331334 return self
332335
333336 @staticmethod
334- def _assert_outputs_equal (model_output , ref_output , atol = 1e-03 , rtol = 1e-03 ):
337+ def _assert_outputs_equal (model_output , ref_output , atol = 1e-03 , rtol = 1e-03 , snr : float | None = None ):
335338 """
336339 Helper testing function that asserts that the model output and the reference output
337340 are equal with some tolerance. Due to numerical differences between eager mode and
@@ -356,15 +359,18 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
356359 f"\t Mismatched count: { (model != ref ).sum ().item ()} / { model .numel ()} \n "
357360 )
358361 else :
362+ computed_snr = compute_sqnr (model .to (torch .float ), ref .to (torch .float ))
363+ snr = snr or float ("-inf" )
364+
359365 assert torch .allclose (
360366 model ,
361367 ref ,
362368 atol = atol ,
363369 rtol = rtol ,
364370 equal_nan = True ,
365- ), (
371+ ) and computed_snr >= snr or math . isnan ( computed_snr ) , (
366372 f"Output { i } does not match reference output.\n "
367- f"\t Given atol: { atol } , rtol: { rtol } .\n "
373+ f"\t Given atol: { atol } , rtol: { rtol } , snr: { snr } .\n "
368374 f"\t Output tensor shape: { model .shape } , dtype: { model .dtype } \n "
369375 f"\t Difference: max: { torch .max (model - ref )} , abs: { torch .max (torch .abs (model - ref ))} , mean abs error: { torch .mean (torch .abs (model - ref ).to (torch .double ))} .\n "
370376 f"\t -- Model vs. Reference --\n "
@@ -373,8 +379,10 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
373379 f"\t Mean: { model .to (torch .double ).mean ()} , { ref .to (torch .double ).mean ()} \n "
374380 f"\t Max: { model .max ()} , { ref .max ()} \n "
375381 f"\t Min: { model .min ()} , { ref .min ()} \n "
382+ f"\t SNR: { computed_snr } \n "
376383 )
377384
385+
378386 @staticmethod
379387 def _compare_outputs (
380388 reference_output ,
@@ -383,6 +391,7 @@ def _compare_outputs(
383391 atol = 1e-03 ,
384392 rtol = 1e-03 ,
385393 qtol = 0 ,
394+ snr : float | None = None ,
386395 ):
387396 """
388397 Compares the output of the original nn module with the output of the generated artifact.
@@ -405,6 +414,7 @@ def _compare_outputs(
405414 reference_output ,
406415 atol = atol ,
407416 rtol = rtol ,
417+ snr = snr ,
408418 )
409419
410420 @staticmethod
0 commit comments