Skip to content

Commit f91995b

Browse files
committed
as
1 parent 1ad2c09 commit f91995b

File tree

2 files changed

+145
-94
lines changed

2 files changed

+145
-94
lines changed

examples/mlx_spda_optimization/evaluator.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,13 +217,22 @@ def compare_attention_outputs(
217217

218218
# Check MLX's allclose function
219219
allclose_result = bool(mx.allclose(output1, output2, atol=tolerance, rtol=tolerance))
220+
221+
# Additional robust check: if MSE is extremely small, consider it a match
222+
# This handles cases where allclose is too strict due to floating-point precision
223+
mse_perfect = mse < 1e-8
224+
225+
# Final decision: either allclose passes OR MSE is extremely small
226+
final_allclose = allclose_result or mse_perfect
220227

221228
return {
222229
"mse": mse,
223230
"mae": mae,
224231
"max_diff": max_diff,
225232
"relative_error": relative_error,
226-
"allclose": allclose_result,
233+
"allclose": final_allclose,
234+
"allclose_strict": allclose_result,
235+
"mse_perfect": mse_perfect,
227236
"tolerance_used": tolerance,
228237
}
229238

@@ -316,11 +325,12 @@ def test_correctness_by_category(evolved_attention_fn, config: Dict) -> Dict[str
316325
# Adjust tolerance based on category
317326
if category == "short":
318327
# Short sequences should be nearly perfect (using mx.fast.scaled_dot_product_attention)
319-
tolerance = 1e-5
328+
# Use slightly more forgiving tolerance to account for floating-point precision
329+
tolerance = 1e-4
320330
expected_quality = "perfect"
321331
elif category == "transition":
322332
# Transition sequences should still be high quality
323-
tolerance = 1e-4
333+
tolerance = 1e-3
324334
expected_quality = "high"
325335
elif category == "long":
326336
# Long sequences may have some quality degradation due to block approximation

0 commit comments

Comments
 (0)