@@ -398,9 +398,19 @@ def test_trtllm_batch_prefill(
398
398
else :
399
399
rtol , atol = 1e-2 , 1e-2
400
400
401
+ # Arbitrary small mismatch rate
402
+ allowed_mismatch_rate = 1e-7
403
+ # Calculate max allowed mismatched elements based on tensor size
404
+ total_elements = (output .float () * o_scale ).numel ()
405
+ max_mismatched_elements = int (allowed_mismatch_rate * total_elements )
406
+
401
407
# convert to float32 because fp8 is not supported by assert_close
402
- torch .testing .assert_close (
403
- output .float () * o_scale , output_ref .float (), rtol = rtol , atol = atol
408
+ assert_close_with_mismatch_tolerance (
409
+ output .float () * o_scale ,
410
+ output_ref .float (),
411
+ rtol = rtol ,
412
+ atol = atol ,
413
+ max_mismatched_elements = max_mismatched_elements ,
404
414
)
405
415
406
416
if o_dtype != "nvfp4" : # wrapper api does not support fp4 output yet.
@@ -621,11 +631,18 @@ def test_trtllm_batch_decode(
621
631
if q_len_per_req > 1 :
622
632
rtol , atol = rtol * 2 , atol * 2
623
633
624
- torch .testing .assert_close (
634
+ # Arbitrary small mismatch rate
635
+ allowed_mismatch_rate = 5e-5
636
+ # Calculate max allowed mismatched elements based on tensor size
637
+ total_elements = (output .float () * o_scale ).numel ()
638
+ max_mismatched_elements = int (allowed_mismatch_rate * total_elements )
639
+
640
+ assert_close_with_mismatch_tolerance (
625
641
output .float () * o_scale ,
626
642
output_ref .float (),
627
643
rtol = rtol ,
628
644
atol = atol ,
645
+ max_mismatched_elements = max_mismatched_elements ,
629
646
)
630
647
631
648
if o_dtype != "nvfp4" : # wrapper api does not support fp4 output yet.
0 commit comments