Skip to content

Commit 1421247

Browse files
authored
Merge pull request rust-lang#4558 from RalfJung/float-err-fix
fix mangitude of applied float error
2 parents 3015ce1 + 9f0b2a2 commit 1421247

File tree

6 files changed

+103
-69
lines changed

6 files changed

+103
-69
lines changed

src/tools/miri/src/intrinsics/mod.rs

Lines changed: 8 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use rustc_abi::Size;
1010
use rustc_apfloat::ieee::{IeeeFloat, Semantics};
1111
use rustc_apfloat::{self, Float, Round};
1212
use rustc_middle::mir;
13-
use rustc_middle::ty::{self, FloatTy, ScalarInt};
13+
use rustc_middle::ty::{self, FloatTy};
1414
use rustc_span::{Symbol, sym};
1515

1616
use self::atomic::EvalContextExt as _;
@@ -230,7 +230,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
230230
let res = apply_random_float_error_ulp(
231231
this,
232232
res,
233-
2, // log2(4)
233+
4,
234234
);
235235

236236
// Clamp the result to the guaranteed range of this function according to the C standard,
@@ -274,7 +274,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
274274
let res = apply_random_float_error_ulp(
275275
this,
276276
res,
277-
2, // log2(4)
277+
4,
278278
);
279279

280280
// Clamp the result to the guaranteed range of this function according to the C standard,
@@ -336,9 +336,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
336336

337337
// Apply a relative error of 4ULP to introduce some non-determinism
338338
// simulating imprecise implementations and optimizations.
339-
apply_random_float_error_ulp(
340-
this, res, 2, // log2(4)
341-
)
339+
apply_random_float_error_ulp(this, res, 4)
342340
});
343341
let res = this.adjust_nan(res, &[f1, f2]);
344342
this.write_scalar(res, dest)?;
@@ -354,9 +352,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
354352

355353
// Apply a relative error of 4ULP to introduce some non-determinism
356354
// simulating imprecise implementations and optimizations.
357-
apply_random_float_error_ulp(
358-
this, res, 2, // log2(4)
359-
)
355+
apply_random_float_error_ulp(this, res, 4)
360356
});
361357
let res = this.adjust_nan(res, &[f1, f2]);
362358
this.write_scalar(res, dest)?;
@@ -373,9 +369,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
373369

374370
// Apply a relative error of 4ULP to introduce some non-determinism
375371
// simulating imprecise implementations and optimizations.
376-
apply_random_float_error_ulp(
377-
this, res, 2, // log2(4)
378-
)
372+
apply_random_float_error_ulp(this, res, 4)
379373
});
380374
let res = this.adjust_nan(res, &[f]);
381375
this.write_scalar(res, dest)?;
@@ -391,9 +385,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
391385

392386
// Apply a relative error of 4ULP to introduce some non-determinism
393387
// simulating imprecise implementations and optimizations.
394-
apply_random_float_error_ulp(
395-
this, res, 2, // log2(4)
396-
)
388+
apply_random_float_error_ulp(this, res, 4)
397389
});
398390
let res = this.adjust_nan(res, &[f]);
399391
this.write_scalar(res, dest)?;
@@ -448,7 +440,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
448440
}
449441
// Apply a relative error of 4ULP to simulate non-deterministic precision loss
450442
// due to optimizations.
451-
let res = apply_random_float_error_to_imm(this, res, 2 /* log2(4) */)?;
443+
let res = crate::math::apply_random_float_error_to_imm(this, res, 4)?;
452444
this.write_immediate(*res, dest)?;
453445
}
454446

@@ -486,31 +478,6 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
486478
}
487479
}
488480

489-
/// Applies a random ULP floating point error to `val` and returns the new value.
490-
/// So if you want an X ULP error, `ulp_exponent` should be log2(X).
491-
///
492-
/// Will fail if `val` is not a floating point number.
493-
fn apply_random_float_error_to_imm<'tcx>(
494-
ecx: &mut MiriInterpCx<'tcx>,
495-
val: ImmTy<'tcx>,
496-
ulp_exponent: u32,
497-
) -> InterpResult<'tcx, ImmTy<'tcx>> {
498-
let scalar = val.to_scalar_int()?;
499-
let res: ScalarInt = match val.layout.ty.kind() {
500-
ty::Float(FloatTy::F16) =>
501-
apply_random_float_error_ulp(ecx, scalar.to_f16(), ulp_exponent).into(),
502-
ty::Float(FloatTy::F32) =>
503-
apply_random_float_error_ulp(ecx, scalar.to_f32(), ulp_exponent).into(),
504-
ty::Float(FloatTy::F64) =>
505-
apply_random_float_error_ulp(ecx, scalar.to_f64(), ulp_exponent).into(),
506-
ty::Float(FloatTy::F128) =>
507-
apply_random_float_error_ulp(ecx, scalar.to_f128(), ulp_exponent).into(),
508-
_ => bug!("intrinsic called with non-float input type"),
509-
};
510-
511-
interp_ok(ImmTy::from_scalar_int(res, val.layout))
512-
}
513-
514481
/// For the intrinsics:
515482
/// - sinf32, sinf64
516483
/// - cosf32, cosf64

src/tools/miri/src/machine.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1278,7 +1278,7 @@ impl<'tcx> Machine<'tcx> for MiriMachine<'tcx> {
12781278
ecx: &mut InterpCx<'tcx, Self>,
12791279
val: ImmTy<'tcx>,
12801280
) -> InterpResult<'tcx, ImmTy<'tcx>> {
1281-
crate::math::apply_random_float_error_to_imm(ecx, val, 2 /* log2(4) */)
1281+
crate::math::apply_random_float_error_to_imm(ecx, val, 4)
12821282
}
12831283

12841284
#[inline(always)]

src/tools/miri/src/math.rs

Lines changed: 44 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,20 @@ use rustc_middle::ty::{self, FloatTy, ScalarInt};
66
use crate::*;
77

88
/// Disturbes a floating-point result by a relative error in the range (-2^scale, 2^scale).
9-
///
10-
/// For a 2^N ULP error, you can use an `err_scale` of `-(F::PRECISION - 1 - N)`.
11-
/// In other words, a 1 ULP (absolute) error is the same as a `2^-(F::PRECISION-1)` relative error.
12-
/// (Subtracting 1 compensates for the integer bit.)
139
pub(crate) fn apply_random_float_error<F: rustc_apfloat::Float>(
1410
ecx: &mut crate::MiriInterpCx<'_>,
1511
val: F,
1612
err_scale: i32,
1713
) -> F {
1814
if !ecx.machine.float_nondet
1915
|| matches!(ecx.machine.float_rounding_error, FloatRoundingErrorMode::None)
16+
// relative errors don't do anything to zeros... avoid messing up the sign
17+
|| val.is_zero()
2018
{
2119
return val;
2220
}
23-
2421
let rng = ecx.machine.rng.get_mut();
22+
2523
// Generate a random integer in the range [0, 2^PREC).
2624
// (When read as binary, the position of the first `1` determines the exponent,
2725
// and the remaining bits fill the mantissa. `PREC` is one plus the size of the mantissa,
@@ -37,43 +35,66 @@ pub(crate) fn apply_random_float_error<F: rustc_apfloat::Float>(
3735
let err = r.scalbn(err_scale.strict_sub(F::PRECISION.try_into().unwrap()));
3836
// give it a random sign
3937
let err = if rng.random() { -err } else { err };
40-
// multiple the value with (1+err)
41-
(val * (F::from_u128(1).value + err).value).value
38+
// Compute `val*(1+err)`, distributed out as `val + val*err` to avoid the imprecise addition
39+
// error being amplified by multiplication.
40+
(val + (val * err).value).value
4241
}
4342

44-
/// [`apply_random_float_error`] gives instructions to apply a 2^N ULP error.
45-
/// This function implements these instructions such that applying a 2^N ULP error is less error prone.
46-
/// So for a 2^N ULP error, you would pass N as the `ulp_exponent` argument.
43+
/// Applies an error of `[-N, +N]` ULP to the given value.
4744
pub(crate) fn apply_random_float_error_ulp<F: rustc_apfloat::Float>(
4845
ecx: &mut crate::MiriInterpCx<'_>,
4946
val: F,
50-
ulp_exponent: u32,
47+
max_error: u32,
5148
) -> F {
52-
let n = i32::try_from(ulp_exponent)
53-
.expect("`err_scale_for_ulp`: exponent is too large to create an error scale");
54-
// we know this fits
55-
let prec = i32::try_from(F::PRECISION).unwrap();
56-
let err_scale = -(prec - n - 1);
57-
apply_random_float_error(ecx, val, err_scale)
49+
// We could try to be clever and reuse `apply_random_float_error`, but that is hard to get right
50+
// (see <https://github.com/rust-lang/miri/pull/4558#discussion_r2316838085> for why) so we
51+
// implement the logic directly instead.
52+
if !ecx.machine.float_nondet
53+
|| matches!(ecx.machine.float_rounding_error, FloatRoundingErrorMode::None)
54+
// FIXME: also disturb zeros? That requires a lot more cases in `fixed_float_value`
55+
// and might make the std test suite quite unhappy.
56+
|| val.is_zero()
57+
{
58+
return val;
59+
}
60+
let rng = ecx.machine.rng.get_mut();
61+
62+
let max_error = i64::from(max_error);
63+
let error = match ecx.machine.float_rounding_error {
64+
FloatRoundingErrorMode::Random => rng.random_range(-max_error..=max_error),
65+
FloatRoundingErrorMode::Max =>
66+
if rng.random() {
67+
max_error
68+
} else {
69+
-max_error
70+
},
71+
FloatRoundingErrorMode::None => unreachable!(),
72+
};
73+
// If upwards ULP and downwards ULP differ, we take the average.
74+
let ulp = (((val.next_up().value - val).value + (val - val.next_down().value).value).value
75+
/ F::from_u128(2).value)
76+
.value;
77+
// Shift the value by N times the ULP
78+
(val + (ulp * F::from_i128(error.into()).value).value).value
5879
}
5980

60-
/// Applies a random 16ULP floating point error to `val` and returns the new value.
81+
/// Applies an error of `[-N, +N]` ULP to the given value.
6182
/// Will fail if `val` is not a floating point number.
6283
pub(crate) fn apply_random_float_error_to_imm<'tcx>(
6384
ecx: &mut MiriInterpCx<'tcx>,
6485
val: ImmTy<'tcx>,
65-
ulp_exponent: u32,
86+
max_error: u32,
6687
) -> InterpResult<'tcx, ImmTy<'tcx>> {
6788
let scalar = val.to_scalar_int()?;
6889
let res: ScalarInt = match val.layout.ty.kind() {
6990
ty::Float(FloatTy::F16) =>
70-
apply_random_float_error_ulp(ecx, scalar.to_f16(), ulp_exponent).into(),
91+
apply_random_float_error_ulp(ecx, scalar.to_f16(), max_error).into(),
7192
ty::Float(FloatTy::F32) =>
72-
apply_random_float_error_ulp(ecx, scalar.to_f32(), ulp_exponent).into(),
93+
apply_random_float_error_ulp(ecx, scalar.to_f32(), max_error).into(),
7394
ty::Float(FloatTy::F64) =>
74-
apply_random_float_error_ulp(ecx, scalar.to_f64(), ulp_exponent).into(),
95+
apply_random_float_error_ulp(ecx, scalar.to_f64(), max_error).into(),
7596
ty::Float(FloatTy::F128) =>
76-
apply_random_float_error_ulp(ecx, scalar.to_f128(), ulp_exponent).into(),
97+
apply_random_float_error_ulp(ecx, scalar.to_f128(), max_error).into(),
7798
_ => bug!("intrinsic called with non-float input type"),
7899
};
79100

src/tools/miri/tests/pass/float.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,8 @@ macro_rules! assert_approx_eq {
3939
}};
4040

4141
($a:expr, $b: expr) => {
42-
// accept up to 12ULP (4ULP for host floats and 4ULP for miri artificial error and 4 for any additional effects
43-
// due to having multiple error sources.
44-
assert_approx_eq!($a, $b, 12);
42+
// accept up to 8ULP (4ULP for host floats and 4ULP for miri artificial error).
43+
assert_approx_eq!($a, $b, 8);
4544
};
4645
}
4746

@@ -176,6 +175,7 @@ fn assert_eq_msg<T: PartialEq + Debug>(x: T, y: T, msg: impl Display) {
176175
}
177176

178177
/// Check that floats have bitwise equality
178+
#[track_caller]
179179
fn assert_biteq<F: Float>(a: F, b: F, msg: impl Display) {
180180
let ab = a.to_bits();
181181
let bb = b.to_bits();
@@ -189,6 +189,7 @@ fn assert_biteq<F: Float>(a: F, b: F, msg: impl Display) {
189189
}
190190

191191
/// Check that two floats have equality
192+
#[track_caller]
192193
fn assert_feq<F: Float>(a: F, b: F, msg: impl Display) {
193194
let ab = a.to_bits();
194195
let bb = b.to_bits();

src/tools/miri/tests/pass/float_extra_rounding_error.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use std::hint::black_box;
99

1010
fn main() {
1111
let expected = cfg_select! {
12-
random => 13, // FIXME: why is it 13?
12+
random => 9, // -4 ..= +4 ULP error
1313
max => 2,
1414
none => 1,
1515
};
@@ -20,4 +20,12 @@ fn main() {
2020
values.insert(val.to_bits());
2121
}
2222
assert_eq!(values.len(), expected);
23+
24+
if !cfg!(none) {
25+
// Ensure the smallest and biggest value are 8 ULP apart.
26+
// We can just subtract the raw bit representations for this.
27+
let min = *values.iter().min().unwrap();
28+
let max = *values.iter().max().unwrap();
29+
assert_eq!(min.abs_diff(max), 8);
30+
}
2331
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// We're testing x86 target specific features
2+
//@only-target: x86_64 i686
3+
4+
//! rsqrt and rcp SSE/AVX operations are approximate. We use that as license to treat them as
5+
//! non-deterministic. Ensure that we do indeed see random results within the expected error bounds.
6+
7+
#[cfg(target_arch = "x86")]
8+
use std::arch::x86::*;
9+
#[cfg(target_arch = "x86_64")]
10+
use std::arch::x86_64::*;
11+
use std::collections::HashSet;
12+
13+
fn main() {
14+
let mut vals = HashSet::new();
15+
for _ in 0..50 {
16+
unsafe {
17+
// Compute the inverse square root of 4.0, four times.
18+
let a = _mm_setr_ps(4.0, 4.0, 4.0, 4.0);
19+
let exact = 0.5;
20+
let r = _mm_rsqrt_ps(a);
21+
let r: [f32; 4] = std::mem::transmute(r);
22+
// Check the results.
23+
for r in r {
24+
vals.insert(r.to_bits());
25+
// Ensure the relative error is less than 2^-12.
26+
let rel_error = (r - exact) / exact;
27+
let log_error = rel_error.abs().log2();
28+
assert!(
29+
rel_error == 0.0 || log_error < -12.0,
30+
"got an error of {rel_error} = 2^{log_error}"
31+
);
32+
}
33+
}
34+
}
35+
// Ensure we saw a bunch of different results.
36+
assert!(vals.len() >= 50);
37+
}

0 commit comments

Comments
 (0)