Skip to content

Commit 11ffdb5

Browse files
authored
Merge pull request rust-lang#4592 from RalfJung/sqrt
implement sqrt for f16 and f128
2 parents 045e5e3 + 77f2d86 commit 11ffdb5

File tree

3 files changed

+54
-38
lines changed

3 files changed

+54
-38
lines changed

src/tools/miri/src/intrinsics/math.rs

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,26 @@
11
use rand::Rng;
2-
use rustc_apfloat::{self, Float, Round};
2+
use rustc_apfloat::{self, Float, FloatConvert, Round};
33
use rustc_middle::mir;
44
use rustc_middle::ty::{self, FloatTy};
55

66
use self::helpers::{ToHost, ToSoft};
77
use super::check_intrinsic_arg_count;
88
use crate::*;
99

10+
fn sqrt<'tcx, F: Float + FloatConvert<F> + Into<Scalar>>(
11+
this: &mut MiriInterpCx<'tcx>,
12+
args: &[OpTy<'tcx>],
13+
dest: &MPlaceTy<'tcx>,
14+
) -> InterpResult<'tcx> {
15+
let [f] = check_intrinsic_arg_count(args)?;
16+
let f = this.read_scalar(f)?;
17+
let f: F = f.to_float()?;
18+
// Sqrt is specified to be fully precise.
19+
let res = math::sqrt(f);
20+
let res = this.adjust_nan(res, &[f]);
21+
this.write_scalar(res, dest)
22+
}
23+
1024
impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
1125
pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
1226
fn emulate_math_intrinsic(
@@ -20,22 +34,10 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
2034

2135
match intrinsic_name {
2236
// Operations we can do with soft-floats.
23-
"sqrtf32" => {
24-
let [f] = check_intrinsic_arg_count(args)?;
25-
let f = this.read_scalar(f)?.to_f32()?;
26-
// Sqrt is specified to be fully precise.
27-
let res = math::sqrt(f);
28-
let res = this.adjust_nan(res, &[f]);
29-
this.write_scalar(res, dest)?;
30-
}
31-
"sqrtf64" => {
32-
let [f] = check_intrinsic_arg_count(args)?;
33-
let f = this.read_scalar(f)?.to_f64()?;
34-
// Sqrt is specified to be fully precise.
35-
let res = math::sqrt(f);
36-
let res = this.adjust_nan(res, &[f]);
37-
this.write_scalar(res, dest)?;
38-
}
37+
"sqrtf16" => sqrt::<rustc_apfloat::ieee::Half>(this, args, dest)?,
38+
"sqrtf32" => sqrt::<rustc_apfloat::ieee::Single>(this, args, dest)?,
39+
"sqrtf64" => sqrt::<rustc_apfloat::ieee::Double>(this, args, dest)?,
40+
"sqrtf128" => sqrt::<rustc_apfloat::ieee::Quad>(this, args, dest)?,
3941

4042
"fmaf32" => {
4143
let [a, b, c] = check_intrinsic_arg_count(args)?;

src/tools/miri/src/math.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::ops::Neg;
22
use std::{f32, f64};
33

44
use rand::Rng as _;
5-
use rustc_apfloat::Float as _;
5+
use rustc_apfloat::Float;
66
use rustc_apfloat::ieee::{DoubleS, IeeeFloat, Semantics, SingleS};
77
use rustc_middle::ty::{self, FloatTy, ScalarInt};
88

@@ -317,19 +317,19 @@ where
317317
}
318318
}
319319

320-
pub(crate) fn sqrt<S: rustc_apfloat::ieee::Semantics>(x: IeeeFloat<S>) -> IeeeFloat<S> {
320+
pub(crate) fn sqrt<F: Float>(x: F) -> F {
321321
match x.category() {
322322
// preserve zero sign
323323
rustc_apfloat::Category::Zero => x,
324324
// propagate NaN
325325
rustc_apfloat::Category::NaN => x,
326326
// sqrt of negative number is NaN
327-
_ if x.is_negative() => IeeeFloat::NAN,
327+
_ if x.is_negative() => F::NAN,
328328
// sqrt(∞) = ∞
329-
rustc_apfloat::Category::Infinity => IeeeFloat::INFINITY,
329+
rustc_apfloat::Category::Infinity => F::INFINITY,
330330
rustc_apfloat::Category::Normal => {
331331
// Floating point precision, excluding the integer bit
332-
let prec = i32::try_from(S::PRECISION).unwrap() - 1;
332+
let prec = i32::try_from(F::PRECISION).unwrap() - 1;
333333

334334
// x = 2^(exp - prec) * mant
335335
// where mant is an integer with prec+1 bits
@@ -394,7 +394,7 @@ pub(crate) fn sqrt<S: rustc_apfloat::ieee::Semantics>(x: IeeeFloat<S>) -> IeeeFl
394394
res = (res + 1) >> 1;
395395

396396
// Build resulting value with res as mantissa and exp/2 as exponent
397-
IeeeFloat::from_u128(res).value.scalbn(exp / 2 - prec)
397+
F::from_u128(res).value.scalbn(exp / 2 - prec)
398398
}
399399
}
400400
}

src/tools/miri/tests/pass/float.rs

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,35 @@ fn basic() {
281281
assert_eq!(34.2f64.abs(), 34.2f64);
282282
assert_eq!((-1.0f128).abs(), 1.0f128);
283283
assert_eq!(34.2f128.abs(), 34.2f128);
284+
285+
assert_eq!(64_f16.sqrt(), 8_f16);
286+
assert_eq!(64_f32.sqrt(), 8_f32);
287+
assert_eq!(64_f64.sqrt(), 8_f64);
288+
assert_eq!(64_f128.sqrt(), 8_f128);
289+
assert_eq!(f16::INFINITY.sqrt(), f16::INFINITY);
290+
assert_eq!(f32::INFINITY.sqrt(), f32::INFINITY);
291+
assert_eq!(f64::INFINITY.sqrt(), f64::INFINITY);
292+
assert_eq!(f128::INFINITY.sqrt(), f128::INFINITY);
293+
assert_eq!(0.0_f16.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal);
294+
assert_eq!(0.0_f32.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal);
295+
assert_eq!(0.0_f64.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal);
296+
assert_eq!(0.0_f128.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal);
297+
assert_eq!((-0.0_f16).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal);
298+
assert_eq!((-0.0_f32).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal);
299+
assert_eq!((-0.0_f64).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal);
300+
assert_eq!((-0.0_f128).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal);
301+
assert!((-5.0_f16).sqrt().is_nan());
302+
assert!((-5.0_f32).sqrt().is_nan());
303+
assert!((-5.0_f64).sqrt().is_nan());
304+
assert!((-5.0_f128).sqrt().is_nan());
305+
assert!(f16::NEG_INFINITY.sqrt().is_nan());
306+
assert!(f32::NEG_INFINITY.sqrt().is_nan());
307+
assert!(f64::NEG_INFINITY.sqrt().is_nan());
308+
assert!(f128::NEG_INFINITY.sqrt().is_nan());
309+
assert!(f16::NAN.sqrt().is_nan());
310+
assert!(f32::NAN.sqrt().is_nan());
311+
assert!(f64::NAN.sqrt().is_nan());
312+
assert!(f128::NAN.sqrt().is_nan());
284313
}
285314

286315
/// Test casts from floats to ints and back
@@ -1012,21 +1041,6 @@ pub fn libm() {
10121041
unsafe { ldexp(a, b) }
10131042
}
10141043

1015-
assert_eq!(64_f32.sqrt(), 8_f32);
1016-
assert_eq!(64_f64.sqrt(), 8_f64);
1017-
assert_eq!(f32::INFINITY.sqrt(), f32::INFINITY);
1018-
assert_eq!(f64::INFINITY.sqrt(), f64::INFINITY);
1019-
assert_eq!(0.0_f32.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal);
1020-
assert_eq!(0.0_f64.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal);
1021-
assert_eq!((-0.0_f32).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal);
1022-
assert_eq!((-0.0_f64).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal);
1023-
assert!((-5.0_f32).sqrt().is_nan());
1024-
assert!((-5.0_f64).sqrt().is_nan());
1025-
assert!(f32::NEG_INFINITY.sqrt().is_nan());
1026-
assert!(f64::NEG_INFINITY.sqrt().is_nan());
1027-
assert!(f32::NAN.sqrt().is_nan());
1028-
assert!(f64::NAN.sqrt().is_nan());
1029-
10301044
assert_approx_eq!(25f32.powi(-2), 0.0016f32);
10311045
assert_approx_eq!(23.2f64.powi(2), 538.24f64);
10321046

0 commit comments

Comments
 (0)