1- import math
21import random
32from collections import Counter , OrderedDict
43from typing import Any , Callable , Dict , List , Optional , Tuple
1817 ToExecutorch ,
1918)
2019from executorch .exir .dim_order_utils import get_memory_format
21- from torch .ao .ns .fx .utils import compute_sqnr
2220
2321from torch .export import ExportedProgram
2422from torch .testing import FileCheck
@@ -304,13 +302,13 @@ def run_method_and_compare_outputs(
304302 atol = 1e-03 ,
305303 rtol = 1e-03 ,
306304 qtol = 0 ,
307- snr : float | None = None ,
308305 ):
309306 number_of_runs = 1 if inputs is not None else num_runs
310307 reference_stage = self .stages [StageType .EXPORT ]
311308
312309 stage = stage or self .cur
313310
311+ print (f"Comparing Stage { stage } with Stage { reference_stage } " )
314312 for run_iteration in range (number_of_runs ):
315313 inputs_to_run = inputs if inputs else next (self .generate_random_inputs ())
316314 input_shapes = [
@@ -330,21 +328,13 @@ def run_method_and_compare_outputs(
330328 # Output from running artifact at stage
331329 stage_output = self .stages [stage ].run_artifact (inputs_to_run )
332330 self ._compare_outputs (
333- reference_output ,
334- stage_output ,
335- quantization_scale ,
336- atol ,
337- rtol ,
338- qtol ,
339- snr ,
331+ reference_output , stage_output , quantization_scale , atol , rtol , qtol
340332 )
341333
342334 return self
343335
344336 @staticmethod
345- def _assert_outputs_equal (
346- model_output , ref_output , atol = 1e-03 , rtol = 1e-03 , snr : float | None = None
347- ):
337+ def _assert_outputs_equal (model_output , ref_output , atol = 1e-03 , rtol = 1e-03 ):
348338 """
349339 Helper testing function that asserts that the model output and the reference output
350340 are equal with some tolerance. Due to numerical differences between eager mode and
@@ -369,22 +359,15 @@ def _assert_outputs_equal(
369359 f"\t Mismatched count: { (model != ref ).sum ().item ()} / { model .numel ()} \n "
370360 )
371361 else :
372- computed_snr = compute_sqnr (model .to (torch .float ), ref .to (torch .float ))
373- snr = snr or float ("-inf" )
374-
375- assert (
376- torch .allclose (
377- model ,
378- ref ,
379- atol = atol ,
380- rtol = rtol ,
381- equal_nan = True ,
382- )
383- and computed_snr >= snr
384- or math .isnan (computed_snr )
362+ assert torch .allclose (
363+ model ,
364+ ref ,
365+ atol = atol ,
366+ rtol = rtol ,
367+ equal_nan = True ,
385368 ), (
386369 f"Output { i } does not match reference output.\n "
387- f"\t Given atol: { atol } , rtol: { rtol } , snr: { snr } .\n "
370+ f"\t Given atol: { atol } , rtol: { rtol } .\n "
388371 f"\t Output tensor shape: { model .shape } , dtype: { model .dtype } \n "
389372 f"\t Difference: max: { torch .max (model - ref )} , abs: { torch .max (torch .abs (model - ref ))} , mean abs error: { torch .mean (torch .abs (model - ref ).to (torch .double ))} .\n "
390373 f"\t -- Model vs. Reference --\n "
@@ -393,7 +376,6 @@ def _assert_outputs_equal(
393376 f"\t Mean: { model .to (torch .double ).mean ()} , { ref .to (torch .double ).mean ()} \n "
394377 f"\t Max: { model .max ()} , { ref .max ()} \n "
395378 f"\t Min: { model .min ()} , { ref .min ()} \n "
396- f"\t SNR: { computed_snr } \n "
397379 )
398380
399381 @staticmethod
@@ -404,7 +386,6 @@ def _compare_outputs(
404386 atol = 1e-03 ,
405387 rtol = 1e-03 ,
406388 qtol = 0 ,
407- snr : float | None = None ,
408389 ):
409390 """
410391 Compares the original of the original nn module with the output of the generated artifact.
@@ -427,7 +408,6 @@ def _compare_outputs(
427408 reference_output ,
428409 atol = atol ,
429410 rtol = rtol ,
430- snr = snr ,
431411 )
432412
433413 @staticmethod
0 commit comments