Skip to content

Commit b059590

Browse files
committed
Fix hypot,isclose,integer
1 parent 61f39f1 commit b059590

File tree

6 files changed

+530
-195
lines changed

6 files changed

+530
-195
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Seeds for failure cases proptest has generated in the past. It is
2+
# automatically read and these particular cases re-run before any
3+
# novel cases are generated.
4+
#
5+
# It is recommended to check this file in to source control so that
6+
# everyone who runs the test benefits from these saved cases.
7+
cc 79031914da5204bfc75b0d7cf66e7f76d2e455d6c2837b66cbd11ebf7225be4a # shrinks to n = 32, k = 30

src/cmath/misc.rs

Lines changed: 75 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -98,22 +98,38 @@ pub fn abs(z: Complex64) -> f64 {
9898
}
9999

100100
/// Determine whether two complex numbers are close in value.
101+
///
102+
/// Default tolerances: rel_tol = 1e-09, abs_tol = 0.0
103+
/// Returns Err(EDOM) if rel_tol or abs_tol is negative.
101104
#[inline]
102-
pub fn isclose(a: Complex64, b: Complex64, rel_tol: f64, abs_tol: f64) -> bool {
105+
pub fn isclose(
106+
a: Complex64,
107+
b: Complex64,
108+
rel_tol: Option<f64>,
109+
abs_tol: Option<f64>,
110+
) -> Result<bool> {
111+
let rel_tol = rel_tol.unwrap_or(1e-09);
112+
let abs_tol = abs_tol.unwrap_or(0.0);
113+
114+
// Tolerances must be non-negative
115+
if rel_tol < 0.0 || abs_tol < 0.0 {
116+
return Err(Error::EDOM);
117+
}
118+
103119
// short circuit exact equality
104120
if a.re == b.re && a.im == b.im {
105-
return true;
121+
return Ok(true);
106122
}
107123

108124
// This catches the case of two infinities of opposite sign, or
109125
// one infinity and one finite number.
110126
if a.re.is_infinite() || a.im.is_infinite() || b.re.is_infinite() || b.im.is_infinite() {
111-
return false;
127+
return Ok(false);
112128
}
113129

114130
// now do the regular computation
115131
let diff = abs(Complex64::new(a.re - b.re, a.im - b.im));
116-
(diff <= rel_tol * abs(b)) || (diff <= rel_tol * abs(a)) || (diff <= abs_tol)
132+
Ok((diff <= rel_tol * abs(b)) || (diff <= rel_tol * abs(a)) || (diff <= abs_tol))
117133
}
118134

119135
#[cfg(test)]
@@ -315,39 +331,64 @@ mod tests {
315331
#[test]
316332
fn test_isclose_basic() {
317333
// Equal values
318-
assert!(isclose(
319-
Complex64::new(1.0, 2.0),
320-
Complex64::new(1.0, 2.0),
321-
1e-9,
322-
0.0
323-
));
334+
assert_eq!(
335+
isclose(
336+
Complex64::new(1.0, 2.0),
337+
Complex64::new(1.0, 2.0),
338+
Some(1e-9),
339+
Some(0.0)
340+
),
341+
Ok(true)
342+
);
324343
// Close values
325-
assert!(isclose(
326-
Complex64::new(1.0, 2.0),
327-
Complex64::new(1.0 + 1e-10, 2.0),
328-
1e-9,
329-
0.0
330-
));
344+
assert_eq!(
345+
isclose(
346+
Complex64::new(1.0, 2.0),
347+
Complex64::new(1.0 + 1e-10, 2.0),
348+
Some(1e-9),
349+
Some(0.0)
350+
),
351+
Ok(true)
352+
);
331353
// Not close
332-
assert!(!isclose(
333-
Complex64::new(1.0, 2.0),
334-
Complex64::new(2.0, 2.0),
335-
1e-9,
336-
0.0
337-
));
354+
assert_eq!(
355+
isclose(
356+
Complex64::new(1.0, 2.0),
357+
Complex64::new(2.0, 2.0),
358+
Some(1e-9),
359+
Some(0.0)
360+
),
361+
Ok(false)
362+
);
338363
// Infinities
339-
assert!(isclose(
340-
Complex64::new(f64::INFINITY, 0.0),
341-
Complex64::new(f64::INFINITY, 0.0),
342-
1e-9,
343-
0.0
344-
));
345-
assert!(!isclose(
346-
Complex64::new(f64::INFINITY, 0.0),
347-
Complex64::new(f64::NEG_INFINITY, 0.0),
348-
1e-9,
349-
0.0
350-
));
364+
assert_eq!(
365+
isclose(
366+
Complex64::new(f64::INFINITY, 0.0),
367+
Complex64::new(f64::INFINITY, 0.0),
368+
Some(1e-9),
369+
Some(0.0)
370+
),
371+
Ok(true)
372+
);
373+
assert_eq!(
374+
isclose(
375+
Complex64::new(f64::INFINITY, 0.0),
376+
Complex64::new(f64::NEG_INFINITY, 0.0),
377+
Some(1e-9),
378+
Some(0.0)
379+
),
380+
Ok(false)
381+
);
382+
// Negative tolerance
383+
assert!(
384+
isclose(
385+
Complex64::new(1.0, 2.0),
386+
Complex64::new(1.0, 2.0),
387+
Some(-1.0),
388+
Some(0.0)
389+
)
390+
.is_err()
391+
);
351392
}
352393

353394
proptest::proptest! {

src/math.rs

Lines changed: 57 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -96,17 +96,33 @@ pub(crate) fn math_1a(x: f64, func: fn(f64) -> f64) -> crate::Result<f64> {
9696
crate::err::is_error(r)
9797
}
9898

99-
/// Return the Euclidean distance, sqrt(x*x + y*y).
99+
/// Return the Euclidean norm of n-dimensional coordinates.
100100
///
101-
/// Uses high-precision vector_norm algorithm instead of libm hypot()
102-
/// for consistent results across platforms and better handling of overflow/underflow.
101+
/// Equivalent to sqrt(sum(x**2 for x in coords)).
102+
/// Uses high-precision vector_norm algorithm for consistent results
103+
/// across platforms and better handling of overflow/underflow.
103104
#[inline]
104-
pub fn hypot(x: f64, y: f64) -> f64 {
105-
let ax = x.abs();
106-
let ay = y.abs();
107-
let max = if ax > ay { ax } else { ay };
108-
let found_nan = x.is_nan() || y.is_nan();
109-
aggregate::vector_norm_2(ax, ay, max, found_nan)
105+
pub fn hypot(coords: &[f64]) -> f64 {
106+
let n = coords.len();
107+
if n == 0 {
108+
return 0.0;
109+
}
110+
111+
let mut max = 0.0_f64;
112+
let mut found_nan = false;
113+
let abs_coords: Vec<f64> = coords
114+
.iter()
115+
.map(|&x| {
116+
let ax = x.abs();
117+
found_nan |= ax.is_nan();
118+
if ax > max {
119+
max = ax;
120+
}
121+
ax
122+
})
123+
.collect();
124+
125+
aggregate::vector_norm(&abs_coords, max, found_nan)
110126
}
111127

112128
// Mathematical constants
@@ -208,16 +224,43 @@ mod tests {
208224
assert!(NAN.is_nan());
209225
}
210226

211-
// hypot tests
212-
fn test_hypot(x: f64, y: f64) {
213-
crate::test::test_math_2(x, y, "hypot", |x, y| Ok(hypot(x, y)));
227+
// hypot tests - n-dimensional
228+
fn test_hypot(coords: &[f64]) {
229+
let rs_result = hypot(coords);
230+
231+
pyo3::Python::attach(|py| {
232+
let math = PyModule::import(py, "math").unwrap();
233+
let py_func = math.getattr("hypot").unwrap();
234+
let py_args = pyo3::types::PyTuple::new(py, coords).unwrap();
235+
let py_result: f64 = py_func.call1(py_args).unwrap().extract().unwrap();
236+
237+
if py_result.is_nan() && rs_result.is_nan() {
238+
return;
239+
}
240+
assert_eq!(
241+
py_result.to_bits(),
242+
rs_result.to_bits(),
243+
"hypot({:?}): py={} vs rs={}",
244+
coords,
245+
py_result,
246+
rs_result
247+
);
248+
});
249+
}
250+
251+
#[test]
252+
fn test_hypot_basic() {
253+
test_hypot(&[3.0, 4.0]); // 5.0
254+
test_hypot(&[1.0, 2.0, 2.0]); // 3.0
255+
test_hypot(&[]); // 0.0
256+
test_hypot(&[5.0]); // 5.0
214257
}
215258

216259
#[test]
217260
fn edgetest_hypot() {
218261
for &x in &crate::test::EDGE_VALUES {
219262
for &y in &crate::test::EDGE_VALUES {
220-
test_hypot(x, y);
263+
test_hypot(&[x, y]);
221264
}
222265
}
223266
}
@@ -235,7 +278,7 @@ mod tests {
235278

236279
#[test]
237280
fn proptest_hypot(x: f64, y: f64) {
238-
test_hypot(x, y);
281+
test_hypot(&[x, y]);
239282
}
240283
}
241284
}

src/math/aggregate.rs

Lines changed: 1 addition & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -170,80 +170,8 @@ pub fn fsum(iter: impl IntoIterator<Item = f64>) -> crate::Result<f64> {
170170

171171
// VECTOR_NORM - for dist and hypot
172172

173-
/// Compute the Euclidean norm of two values with high precision.
174-
/// Optimized version for hypot(x, y).
175-
pub(super) fn vector_norm_2(x: f64, y: f64, max: f64, found_nan: bool) -> f64 {
176-
// Check for infinity first (inf wins over nan)
177-
if x.is_infinite() || y.is_infinite() {
178-
return f64::INFINITY;
179-
}
180-
if found_nan {
181-
return f64::NAN;
182-
}
183-
if max == 0.0 {
184-
return 0.0;
185-
}
186-
// n == 1 case: only one non-zero value
187-
if x == 0.0 || y == 0.0 {
188-
return max;
189-
}
190-
191-
let mut max_e: i32 = 0;
192-
crate::m::frexp(max, &mut max_e);
193-
194-
if max_e < -1023 {
195-
// When max_e < -1023, ldexp(1.0, -max_e) would overflow
196-
return f64::MIN_POSITIVE
197-
* vector_norm_2(
198-
x / f64::MIN_POSITIVE,
199-
y / f64::MIN_POSITIVE,
200-
max / f64::MIN_POSITIVE,
201-
found_nan,
202-
);
203-
}
204-
205-
let scale = crate::m::ldexp(1.0, -max_e);
206-
debug_assert!(max * scale >= 0.5);
207-
debug_assert!(max * scale < 1.0);
208-
209-
let mut csum = 1.0;
210-
let mut frac1 = 0.0;
211-
let mut frac2 = 0.0;
212-
213-
// Process x
214-
let xs = x * scale;
215-
debug_assert!(xs.abs() < 1.0);
216-
let pr = dl_mul(xs, xs);
217-
debug_assert!(pr.hi <= 1.0);
218-
let sm = dl_fast_sum(csum, pr.hi);
219-
csum = sm.hi;
220-
frac1 += pr.lo;
221-
frac2 += sm.lo;
222-
223-
// Process y
224-
let ys = y * scale;
225-
debug_assert!(ys.abs() < 1.0);
226-
let pr = dl_mul(ys, ys);
227-
debug_assert!(pr.hi <= 1.0);
228-
let sm = dl_fast_sum(csum, pr.hi);
229-
csum = sm.hi;
230-
frac1 += pr.lo;
231-
frac2 += sm.lo;
232-
233-
let mut h = (csum - 1.0 + (frac1 + frac2)).sqrt();
234-
let pr = dl_mul(-h, h);
235-
let sm = dl_fast_sum(csum, pr.hi);
236-
csum = sm.hi;
237-
frac1 += pr.lo;
238-
frac2 += sm.lo;
239-
let x = csum - 1.0 + (frac1 + frac2);
240-
h += x / (2.0 * h); // differential correction
241-
242-
h / scale
243-
}
244-
245173
/// Compute the Euclidean norm of a vector with high precision.
246-
fn vector_norm(vec: &[f64], max: f64, found_nan: bool) -> f64 {
174+
pub fn vector_norm(vec: &[f64], max: f64, found_nan: bool) -> f64 {
247175
let n = vec.len();
248176

249177
if max.is_infinite() {

0 commit comments

Comments
 (0)