diff --git a/tests/attention/test_xqa.py b/tests/attention/test_xqa.py index 3fac89f726..c22cff3f74 100644 --- a/tests/attention/test_xqa.py +++ b/tests/attention/test_xqa.py @@ -458,7 +458,7 @@ def test_xqa( total_elements = ref_output_batch.numel() passing_elements = within_tolerance.sum().item() pass_ratio = passing_elements / total_elements - required_ratio = 0.99 + required_ratio = 0.98 assert pass_ratio >= required_ratio, ( f"Batch validation failed: " diff --git a/tests/utils/test_logits_processor.py b/tests/utils/test_logits_processor.py index 44baca82e3..f4dfbea7ce 100644 --- a/tests/utils/test_logits_processor.py +++ b/tests/utils/test_logits_processor.py @@ -823,9 +823,13 @@ def test_sequential_probs_topk_topp_sample(self, batch_size, vocab_size, p): pipe = LogitsPipe([TopK(), TopP(), Sample()], input_type=TensorType.PROBS) samples_pipe = pipe(probs, top_k=k, top_p=p, generator=gen2) - # Allow small differences due to floating point precision in intermediate steps + # Allow small differences due to floating point precision in intermediate steps. + # Threshold accounts for batch-size granularity (1/batch_size per mismatch). diff_ratio = (samples_pipe != samples_direct).sum().item() / batch_size - assert diff_ratio < 0.02, f"Too many differences: {diff_ratio * 100:.2f}%" + threshold = max(0.03, 2.0 / batch_size) + assert diff_ratio < threshold, ( + f"Too many differences: {diff_ratio * 100:.2f}% (threshold: {threshold * 100:.2f}%)" + ) @pytest.mark.parametrize("batch_size", [1, 99, 989]) @pytest.mark.parametrize("vocab_size", [111, 32000, 128256])