Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/attention/test_xqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def test_xqa(
total_elements = ref_output_batch.numel()
passing_elements = within_tolerance.sum().item()
pass_ratio = passing_elements / total_elements
- required_ratio = 0.99
+ required_ratio = 0.98

assert pass_ratio >= required_ratio, (
f"Batch validation failed: "
Expand Down
8 changes: 6 additions & 2 deletions tests/utils/test_logits_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,9 +823,13 @@ def test_sequential_probs_topk_topp_sample(self, batch_size, vocab_size, p):
pipe = LogitsPipe([TopK(), TopP(), Sample()], input_type=TensorType.PROBS)
samples_pipe = pipe(probs, top_k=k, top_p=p, generator=gen2)

- # Allow small differences due to floating point precision in intermediate steps
+ # Allow small differences due to floating point precision in intermediate steps.
+ # Threshold accounts for batch-size granularity (1/batch_size per mismatch).
diff_ratio = (samples_pipe != samples_direct).sum().item() / batch_size
- assert diff_ratio < 0.02, f"Too many differences: {diff_ratio * 100:.2f}%"
+ threshold = max(0.03, 2.0 / batch_size)
+ assert diff_ratio < threshold, (
+ f"Too many differences: {diff_ratio * 100:.2f}% (threshold: {threshold * 100:.2f}%)"
+ )

@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
Expand Down
Loading