@@ -30,7 +30,24 @@ def nearly_equal(
3030 return diff < max (abs_tol , rel_tol * norm )
3131
3232
33- def verify_buffer (output , buf_name , reference , rel_tol = 0.04 , abs_tol = 1e-6 ):
33+ def verify_buffer (
34+ output , buf_name , reference , rel_tol = 0.04 , abs_tol = 1e-6 , max_error_rate = 0.0
35+ ):
36+ """
37+ Verify buffer contents match reference within tolerances.
38+
39+ Args:
40+ output: Output buffer to verify
41+ buf_name: Name of buffer for error messages
42+ reference: Reference data to compare against
43+ rel_tol: Relative tolerance for comparison
44+ abs_tol: Absolute tolerance for comparison
45+ max_error_rate: Maximum fraction of elements allowed to exceed tolerances (0.0 to 1.0)
46+ For example, 0.01 allows up to 1% of elements to fail
47+
48+ Returns:
49+ List of error indices. Empty if verification passes.
50+ """
3451 errors = []
3552 expected_np = torch_to_numpy (reference ).reshape ((- 1 ,))
3653 output = output .reshape ((- 1 ,))
@@ -49,6 +66,21 @@ def verify_buffer(output, buf_name, reference, rel_tol=0.04, abs_tol=1e-6):
4966 print (
5067 f"Mismatch in { buf_name } [{ i } ]: expected { float (expected_np [i ]):.6f} , got { float (output [i ]):.6f} "
5168 )
69+
70+ # Check if error rate is acceptable
71+ if max_error_rate > 0.0 and len (errors ) > 0 :
72+ error_rate = len (errors ) / compare_len
73+ max_allowed_errors = int (compare_len * max_error_rate )
74+ if len (errors ) <= max_allowed_errors :
75+ print (
76+ f"{ buf_name } : { len (errors )} errors ({ error_rate * 100 :.2f} %) within allowed rate of { max_error_rate * 100 :.2f} % ({ max_allowed_errors } errors)"
77+ )
78+ return [] # Pass - within allowed error rate
79+ else :
80+ print (
81+ f"{ buf_name } : { len (errors )} errors ({ error_rate * 100 :.2f} %) exceeds allowed rate of { max_error_rate * 100 :.2f} % ({ max_allowed_errors } errors)"
82+ )
83+
5284 return errors
5385
5486
@@ -59,6 +91,7 @@ def run_test(
5991 intermediate_buffers = None ,
6092 rel_tol = 0.04 ,
6193 abs_tol = 1e-6 ,
94+ max_error_rate = 0.0 ,
6295 warmup_iters = 1 ,
6396 timed_iters = 1 ,
6497):
@@ -72,6 +105,7 @@ def run_test(
72105 intermediate_buffers: Optional dict mapping buffer names to reference arrays for validation
73106 rel_tol: Relative tolerance for comparison of output and intermediate buffers
74107 abs_tol: Absolute tolerance for comparison of output and intermediate buffers
108+ max_error_rate: Maximum fraction of elements allowed to exceed tolerances (0.0 to 1.0)
75109
76110 Returns:
77111 (errors: list, latency_us: float, bandwidth_gbps: float)
@@ -144,7 +178,9 @@ def run_test(
144178 if buf_name in output_map :
145179 buf = output_map [buf_name ]
146180 output_np = buf .view_as_np ()
147- buf_errors = verify_buffer (output_np , buf_name , expected , rel_tol , abs_tol )
181+ buf_errors = verify_buffer (
182+ output_np , buf_name , expected , rel_tol , abs_tol , max_error_rate
183+ )
148184 if buf_errors :
149185 errors [buf_name ] = buf_errors
150186 else :
0 commit comments