Skip to content

Commit 76a8906

Browse files
authored
Adjust atol/rtol for ring attention's quantized kv cache test (#13909)
Summary: In another PR, #13722, for whatever reason, this test was failing. Adjusting the margin here since I have seen this fail before on trunk but somehow it got resolved. So there is some level of flakiness particularly around quantized kv cache + ring attention Test Plan: CI Reviewers: Subscribers: Tasks: Tags: ### Summary [PLEASE REMOVE] See [CONTRIBUTING.md's Pull Requests](https://github.com/pytorch/executorch/blob/main/CONTRIBUTING.md#pull-requests) for ExecuTorch PR guidelines. [PLEASE REMOVE] If this PR closes an issue, please add a `Fixes #<issue-id>` line. [PLEASE REMOVE] If this PR introduces a fix or feature that should be the upcoming release notes, please add a "Release notes: <area>" label. For a list of available release notes labels, check out [CONTRIBUTING.md's Pull Requests](https://github.com/pytorch/executorch/blob/main/CONTRIBUTING.md#pull-requests). ### Test plan [PLEASE REMOVE] How did you test this PR? Please write down any manual commands you used and note down tests that you have written if applicable.
1 parent aa08df5 commit 76a8906

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

examples/models/llama/tests/test_ring_attention.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,10 +163,17 @@ def test_single_token_processing(
163163
)
164164

165165
# Check that outputs are the same
166-
self.assertTrue(
167-
torch.allclose(baseline_out, ring_out, rtol=1e-7, atol=1e-7),
168-
f"Outputs differ at position {pos}",
169-
)
166+
if kv_cache_type == KVCacheType.REGULAR:
167+
self.assertTrue(
168+
torch.allclose(baseline_out, ring_out, rtol=1e-7, atol=1e-7),
169+
f"Outputs differ at position {pos}",
170+
)
171+
else:
172+
# For quantized kv cache we need bigger margin
173+
self.assertTrue(
174+
torch.allclose(baseline_out, ring_out, rtol=1e-6, atol=1e-6),
175+
f"Outputs differ at position {pos}",
176+
)
170177

171178
def test_single_token_processing_quantized(self):
172179
"""Test single token processing with QuantizedKVCache."""

0 commit comments

Comments
 (0)