diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 0b5c9dc..269317b 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -11,24 +11,26 @@ env: jobs: build: - name: Build and test + name: Build and test (${{ matrix.os }}) runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: - os: [ - ubuntu-latest, - windows-latest, - macos-latest, - ] + os: [ubuntu-latest, windows-latest, macos-latest] rust: [stable] + include: + - os: ubuntu-latest + features: complex + - os: windows-latest + features: complex + - os: macos-latest + features: complex,mul_add steps: - uses: actions/checkout@v4 - name: Set up Python uses: actions/setup-python@v5 with: - python-version: "3.13" - check-latest: true + python-version: "3.14" - name: Setup Rust uses: actions-rs/toolchain@v1 with: @@ -37,9 +39,9 @@ jobs: override: true - name: Build run: cargo build --verbose - - name: Run tests - if: matrix.os != 'macos-latest' - run: cargo test --verbose && cargo test --release --verbose - - name: Run tests with FMA (macOS) - if: matrix.os == 'macos-latest' - run: cargo test --verbose --features mul_add && cargo test --release --verbose --features mul_add + - name: Run tests with num-bigint + run: cargo test --verbose --features ${{ matrix.features }},num-bigint + - name: Run tests with num-bigint (Release) + run: cargo test --verbose --features ${{ matrix.features }},num-bigint --release + - name: Run tests with malachite-bigint + run: cargo test --verbose --features ${{ matrix.features }},malachite-bigint diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..6324d40 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.14 diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..0e80bd5 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,195 @@ +# pymath Agent Guidelines + +This project is a strict port of CPython's math and cmath modules to Rust. 
+ +## Core Principle + +**Every function must match CPython exactly** - same logic, same special case handling, same error conditions. + +- `math` module → `Modules/mathmodule.c` +- `cmath` module → `Modules/cmathmodule.c` + +## Porting Rules + +### 1. Use Existing Helpers + +CPython uses helpers like `math_1`, `math_2`, `FUNC1`, `FUNC2`, etc. We have Rust equivalents: + +| CPython | Rust | +|---------|------| +| `FUNC1(name, func, can_overflow, ...)` | `math_1(x, func, can_overflow)` | +| `FUNC2(name, func, ...)` | `math_2(x, y, func)` | +| `m_log`, `m_log10`, `m_log2` | Same names in `exponential.rs` | + +If a function uses a helper in CPython, use the corresponding helper here. + +### 2. Create Missing Helpers + +If CPython has a helper we don't have yet, implement it. Examples: +- `math_1_fn` for Rust function pointers (vs C function pointers) +- Special case handlers for specific functions + +### 3. Error Handling + +CPython sets `errno` and calls `is_error()`. We return `Result` directly: + +```rust +// CPython: +errno = EDOM; +if (errno && is_error(r, 1)) return NULL; + +// Rust: +return Err(crate::Error::EDOM); +``` + +**Never use `set_errno(libc::EDOM)` or similar** - just return `Err()` directly. +The only valid `set_errno` call is `set_errno(0)` to clear errno before libm calls. + +### 4. Special Cases + +CPython has explicit special case handling for IEEE specials (NaN, Inf, etc.). Copy this logic exactly: + +```rust +// Example from pow(): +if !x.is_finite() || !y.is_finite() { + if x.is_nan() { + return Ok(if y == 0.0 { 1.0 } else { x }); // NaN**0 = 1 + } + // ... more special cases +} +``` + +### 5. 
Reference CPython Source + +Always check CPython source in the `cpython/` directory: + +**For math module** (`Modules/mathmodule.c`): +- Function implementations +- Helper macros (`FUNC1`, `FUNC1D`, `FUNC2`) +- Special case comments +- Error conditions + +**For cmath module** (`Modules/cmathmodule.c`): +- `special_type()` enum and function +- 7x7 special value tables (e.g., `tanh_special_values`) +- `SPECIAL_VALUE` macro usage +- Complex-specific error handling + +### 6. Fused Multiply-Add (mul_add) + +For bit-exact matching with CPython, use `crate::mul_add(a, b, c)` instead of `a * b + c` in specific cases. + +**Why this matters**: CPython compiled with clang on macOS may use FMA (fused multiply-add) instructions for expressions like `1.0 + x * x`. FMA computes `a * b + c` in a single operation without intermediate rounding, which can produce results that differ by 1-2 ULP from separate multiply and add operations. + +**When to use `mul_add`**: +- Expressions of the form `a * b + c` or `c + a * b` in complex math functions +- Especially in formulas like `1.0 + x * x` → `mul_add(x, x, 1.0)` + +**Example** (from `c_tanh`): +```rust +// Wrong - may differ from CPython by 1-2 ULP +let denom = 1.0 + txty * txty; +let r_re = tx * (1.0 + ty * ty) / denom; + +// Correct - matches CPython exactly +let denom = mul_add(txty, txty, 1.0); +let r_re = tx * mul_add(ty, ty, 1.0) / denom; +``` + +**Example** (from `c_asinh`): +```rust +// mul_add for cross-product calculations +let r_re = m::asinh(mul_add(s1.re, s2.im, -(s2.re * s1.im))); +let r_im = m::atan2(z.im, mul_add(s1.re, s2.re, -(s1.im * s2.im))); +``` + +**Example** (from `c_atanh`): +```rust +// mul_add for squared terms +let one_minus_re = 1.0 - z.re; +let r_re = m::log1p(4.0 * z.re / mul_add(one_minus_re, one_minus_re, ay * ay)) / 4.0; +let r_im = -m::atan2(-2.0 * z.im, mul_add(one_minus_re, 1.0 + z.re, -(ay * ay))) / 2.0; +``` + +**Feature flag**: The `mul_add` feature controls whether hardware FMA is used: + +- 
`mul_add` enabled: Uses `f64::mul_add()` (hardware FMA instruction) +- `mul_add` disabled (default): Falls back to `a * b + c` (separate operations) + +Note: macOS CI always enables `mul_add` because CPython on macOS uses FMA. + +**How to identify missing mul_add usage**: +1. If a test fails with 1-2 ULP difference +2. Look for `a * b + c` or `c + a * b` patterns in the failing function +3. Replace with `mul_add(a, b, c)` and re-test + +### 7. Platform-specific sincos (macOS, cmath only) + +On macOS, Python's cmath module uses Apple's `__sincos_stret` function, which computes sin and cos together with slightly different results than calling them separately (up to 1 ULP difference). + +For bit-exact matching on macOS, use `m::sincos(x)` which returns `(sin, cos)` tuple: + +```rust +// Instead of: +let sin_x = m::sin(x); +let cos_x = m::cos(x); + +// Use: +let (sin_x, cos_x) = m::sincos(x); +``` + +Required in cmath functions that use both sin and cos of the same angle: + +- `cosh`, `sinh` - for the imaginary argument +- `exp` - for the imaginary argument +- `rect` - for the phi angle + +On non-macOS platforms, `m::sincos(x)` falls back to calling sin and cos separately. + +## Testing + +### EDGE_VALUES + +All float functions must be tested with `crate::test::EDGE_VALUES` which includes: +- Zeros: `0.0`, `-0.0` +- Infinities: `INFINITY`, `NEG_INFINITY` +- NaNs: `NAN`, `-NAN`, and NaN with different payload +- Subnormals +- Boundary values: `MIN_POSITIVE`, `MAX`, `MIN` +- Large values near infinity +- Trigonometric special values: `PI`, `PI/2`, `PI/4`, `TAU` + +### Error Type Verification + +Tests must verify both: +1. Correct values for Ok results +2. 
Correct error types (EDOM vs ERANGE) for Err results + +Python `ValueError` → `Error::EDOM` +Python `OverflowError` → `Error::ERANGE` + +## File Structure + +### Core +- `src/lib.rs` - Root module, `mul_add` function +- `src/err.rs` - Error types (EDOM, ERANGE) +- `src/test.rs` - Test helpers, `EDGE_VALUES`, `EDGE_INTS` + +### System libm bindings +- `src/m_sys.rs` - Raw FFI declarations (`extern "C"`) +- `src/m.rs` - Safe wrappers, platform-specific `sincos` + +### math module +- `src/math.rs` - Main module, `math_1`, `math_2` helpers, `hypot`, constants +- `src/math/exponential.rs` - exp, log, pow, sqrt, cbrt, etc. +- `src/math/trigonometric.rs` - sin, cos, tan, asin, acos, atan, etc. +- `src/math/misc.rs` - frexp, ldexp, modf, fmod, copysign, isclose, ulp, etc. +- `src/math/gamma.rs` - gamma, lgamma, erf, erfc +- `src/math/aggregate.rs` - fsum, prod, sumprod, dist (vector operations) +- `src/math/integer.rs` - gcd, lcm, isqrt, comb, perm, factorial (requires `_bigint` feature) + +### cmath module (requires `complex` feature) +- `src/cmath.rs` - Main module, `special_type`, `special_value!` macro, shared constants +- `src/cmath/exponential.rs` - sqrt, exp, log, log10 +- `src/cmath/trigonometric.rs` - sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, asinh, acosh, atanh +- `src/cmath/misc.rs` - phase, polar, rect, abs, isfinite, isnan, isinf, isclose diff --git a/Cargo.toml b/Cargo.toml index 758d33d..ef8c28d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,12 +1,18 @@ [package] name = "pymath" -version = "0.0.2" +version = "0.0.3" edition = "2024" description = "A binary representation compatible Rust implementation of Python's math library." license = "PSF-2.0" [features] +default = ["complex"] +complex = ["dep:num-complex"] +num-bigint = ["_bigint", "dep:num-bigint"] +malachite-bigint = ["_bigint", "dep:malachite-bigint"] +_bigint = ["dep:num-traits", "dep:num-integer"] # Internal feature. User must use num-bigint or malachite-bigint instead. 
+ # Do not enable this feature unless you really need it. # CPython didn't intend to use FMA for its math library. # This project uses this feature in CI to verify the code doesn't have additional bugs on aarch64-apple-darwin. @@ -17,7 +23,12 @@ mul_add = [] [dependencies] libc = "0.2" +num-complex = { version = "0.4", optional = true } +num-bigint = { version = "0.4", optional = true } +num-traits = { version = "0.2", optional = true } +num-integer = { version = "0.1", optional = true } +malachite-bigint = { version = "0.2", optional = true } [dev-dependencies] proptest = "1.6.0" -pyo3 = { version = "0.24", features = ["abi3"] } +pyo3 = { version = "0.27", features = ["abi3", "auto-initialize"] } diff --git a/README.md b/README.md index be8aa44..85d92ee 100644 --- a/README.md +++ b/README.md @@ -11,10 +11,16 @@ A binary representation compatible Rust implementation of Python's math library. Each function has been carefully translated from CPython's C implementation to Rust, preserving the same algorithms, constants, and corner case handling. The code maintains the same numerical properties, but in Rust! +## Module Structure + +- `pymath::math` - Real number math functions (Python's `math` module) +- `pymath::cmath` - Complex number functions (Python's `cmath` module, requires `complex` feature) +- `pymath::m` - Direct libm bindings + ## Usage ```rust -use pymath::{gamma, lgamma}; +use pymath::math::{gamma, lgamma}; fn main() { // Get the same results as Python's math.gamma and math.lgamma diff --git a/proptest-regressions/cmath.txt b/proptest-regressions/cmath.txt new file mode 100644 index 0000000..2407ef2 --- /dev/null +++ b/proptest-regressions/cmath.txt @@ -0,0 +1,10 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. 
+# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc d9f364ff419553f1fcf92be2834534eb8bd5faacc285ed6587b5a3c827b235bd # shrinks to re = 1.00214445281558e32, im = -3.0523230669839245e-65 +cc 26832b65255697171026596e8fbeee37e5a2bf82133ca5cff83c21b95add285f # shrinks to re = 5.946781139174558e-217, im = 3.4143760786656616e281 +cc 78224345cd95a0451f2f83872fb7ca6f9462c7eed58bf5490d6b9b717c5b8e02 # shrinks to re = -9.234931944778561e90, im = 1.0662656145085839e88 +cc a1b5e651ca7e81b1cf0fee4c7fb4a982982d24a10b4d0b27ae38559b70f0e9db # shrinks to re = -0.6032998606032244, im = -4.778999871811813e-156 diff --git a/proptest-regressions/cmath/misc.txt b/proptest-regressions/cmath/misc.txt new file mode 100644 index 0000000..71b606c --- /dev/null +++ b/proptest-regressions/cmath/misc.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 6b3810fffb8caa48de63b8c593267f33b68ed7c82057a67ee94dd2077e1d1813 # shrinks to re = 4.861109893051668e77, im = -4.975947432969132e-264 diff --git a/proptest-regressions/cmath/trigonometric.txt b/proptest-regressions/cmath/trigonometric.txt new file mode 100644 index 0000000..04c0b1b --- /dev/null +++ b/proptest-regressions/cmath/trigonometric.txt @@ -0,0 +1,11 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. 
+cc a0c58bc19132b1fba3a2edc94204f7582b5dbea7b87139f582da75f58aefa06e # shrinks to re = 2.4494096702311403e-49, im = -1.2388169541772939e300 +cc 409f359905a7e77eacdbd4643c3a1483cc128e46a1c06795c246f3d57a2e61b9 # shrinks to re = -4.17454178893395e-69, im = 7.373554801445074e178 +cc 5acbcf76c3bcbaf0b2b89a987e38c75265434f96fdc11598830e2221df72fc53 # shrinks to re = 0.00010260965539526095, im = 4.984721877290597e-19 +cc ac04191f916633de0408e071f5f6ada8f7fe7a1be02caca7a32f37ecb148a5ac # shrinks to re = 3.049837651806167e74, im = -2.222842222335753e-166 +cc 77ec3c08d2e057fb9467eb13c3b953e4e30f2800259d3653f7bcdfcbaf53614f # shrinks to re = 7.812038268590211e52, im = -2.623972069152808e-109 diff --git a/proptest-regressions/gamma.txt b/proptest-regressions/gamma.txt deleted file mode 100644 index 77162bb..0000000 --- a/proptest-regressions/gamma.txt +++ /dev/null @@ -1,13 +0,0 @@ -# Seeds for failure cases proptest has generated in the past. It is -# automatically read and these particular cases re-run before any -# novel cases are generated. -# -# It is recommended to check this file in to source control so that -# everyone who runs the test benefits from these saved cases. 
-cc e8ed768221998086795d95c68921437e80c4b7fe68fe9da15ca40faa216391b5 # shrinks to x = 0.0 -cc 23c7f86ab299daa966772921d8c615afda11e1b77944bed40e88264a68e62ac3 # shrinks to x = -19.80948467648103 -cc f57954d91904549b9431755f196b630435a43cbefd558b932efad487a403c6c8 # shrinks to x = 0.003585187864492183 -cc 7a9a04aed4ed7e3d23eb7b32b748542b1062e349ae83cc1fad39672a5b2156cd # shrinks to x = -3.8510064710745118 -cc d884d4ef56bcd40d025660e0dec152754fd4fd4e48bc0bdf41e73ea001798fd8 # shrinks to x = 0.9882904125102558 -cc 3f1d36f364ce29810d0c37003465b186c07460861c7a3bf4b8962401b376f2d9 # shrinks to x = 1.402608516799205 -cc 4439ce674d91257d104063e2d5ade7908c83462d195f98a0c304ea25b022d0f4 # shrinks to x = 3.6215752811868267 diff --git a/proptest-regressions/math.txt b/proptest-regressions/math.txt new file mode 100644 index 0000000..df0d5ee --- /dev/null +++ b/proptest-regressions/math.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 6f6becc96f663a83d20def559f516c7c5ce1a90b87c373d6c025dd3ab8f1fc39 # shrinks to x = 5.868849392888587e-309, y = 1.985586796867676e-308 diff --git a/proptest-regressions/math/aggregate.txt b/proptest-regressions/math/aggregate.txt new file mode 100644 index 0000000..cc1e50c --- /dev/null +++ b/proptest-regressions/math/aggregate.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. 
+cc f58d401e380fa9d3a8dd5b137a668be6bd437776f47186a9479aee07c0aa28b8 # shrinks to p1 = 0.0, p2 = 0.0, q1 = -1.156587418587806e301, q2 = -1.315804087909368e-150 diff --git a/proptest-regressions/math/exponential.txt b/proptest-regressions/math/exponential.txt new file mode 100644 index 0000000..109fe6c --- /dev/null +++ b/proptest-regressions/math/exponential.txt @@ -0,0 +1,10 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc c756e015f8e09d21d37e7f805e54f9270cac137e4570d10ecd5302d752e2559f # shrinks to x = -2.9225692145627912e-248, y = 1.565985096226121e-308 +cc 2f4b02dd2c2019dee35dcf980b7b4f077899289a7005bba21f849776768eda7a # shrinks to x = 6.095938323843682e143 +cc b52b190a89a473a7e5269f9a3efa83735318ae3bd52f3340c32bc4a40f4cda88 # shrinks to x = -0.0, y = -9.787367203123051e54 +cc a34a3387ec2547faa88352529f9730d47c6202f9b418a0b4474c9ebb850081e7 # shrinks to x = -1.3916042894622981e-207 diff --git a/proptest-regressions/lib.txt b/proptest-regressions/math/gamma.txt similarity index 73% rename from proptest-regressions/lib.txt rename to proptest-regressions/math/gamma.txt index 25c7cb0..16562d4 100644 --- a/proptest-regressions/lib.txt +++ b/proptest-regressions/math/gamma.txt @@ -4,4 +4,4 @@ # # It is recommended to check this file in to source control so that # everyone who runs the test benefits from these saved cases. -cc 531a136f9fcde9d1da1ba5d173e62eee8ec8f7c877eb34abbc6d47611a641bc7 # shrinks to x = 0.0 +cc ac5c11a6ec8450aef2b03644e8b46a066f81c689b55f9657fcd0ae73d5829bc2 # shrinks to x = -1.8815128643365582 diff --git a/src/cmath.rs b/src/cmath.rs new file mode 100644 index 0000000..90bce66 --- /dev/null +++ b/src/cmath.rs @@ -0,0 +1,181 @@ +//! 
// Shared constants

const M_LN2: f64 = core::f64::consts::LN_2;

/// Used to avoid spurious overflow in sqrt, log, inverse trig/hyperbolic functions.
const CM_LARGE_DOUBLE: f64 = f64::MAX / 4.0;

/// Natural logarithm of [`CM_LARGE_DOUBLE`], i.e. `ln(f64::MAX / 4)`.
///
/// `ln(f64::MAX)` is about 709.7827, so dividing by 4 subtracts `2 * ln(2)`
/// and gives about 708.3964. (709.0895657128241 would be `ln(f64::MAX / 2)`,
/// which does not correspond to `CM_LARGE_DOUBLE`.) Must be kept in sync
/// with `CM_LARGE_DOUBLE`.
// Extra digits are deliberate so the literal rounds to the nearest f64.
#[allow(clippy::excessive_precision)]
const CM_LOG_LARGE_DOUBLE: f64 = 708.396418532264106;

const INF: f64 = f64::INFINITY;

// Special value table constants
const P: f64 = core::f64::consts::PI;
const P14: f64 = 0.25 * core::f64::consts::PI;
const P12: f64 = 0.5 * core::f64::consts::PI;
const P34: f64 = 0.75 * core::f64::consts::PI;
const N: f64 = f64::NAN;
// Placeholder for table slots that are never read for real results; the
// digits are kept verbatim, hence the precision lint is silenced.
#[allow(clippy::excessive_precision)]
const U: f64 = -9.5426319407711027e33; // unlikely value, used as placeholder
special_value { + ($z:expr, $table:expr) => { + if !$z.re.is_finite() || !$z.im.is_finite() { + return Ok($table[special_type($z.re) as usize][special_type($z.im) as usize]); + } + }; +} +pub(crate) use special_value; + +/// Classify a double into one of seven special types. +#[inline] +fn special_type(d: f64) -> SpecialType { + if d.is_finite() { + if d != 0.0 { + if m::copysign(1.0, d) == 1.0 { + SpecialType::Pos + } else { + SpecialType::Neg + } + } else if m::copysign(1.0, d) == 1.0 { + SpecialType::PZero + } else { + SpecialType::NZero + } + } else if d.is_nan() { + SpecialType::Nan + } else if m::copysign(1.0, d) == 1.0 { + SpecialType::PInf + } else { + SpecialType::NInf + } +} + +#[cfg(test)] +pub(crate) mod tests { + use super::*; + + /// Compare complex result with CPython, allowing small ULP differences for finite values. + pub fn assert_complex_eq(py_re: f64, py_im: f64, rs: Complex64, func: &str, re: f64, im: f64) { + let check_component = |py: f64, rs: f64, component: &str| { + if py.is_nan() && rs.is_nan() { + // Both NaN - OK + } else if py.is_nan() || rs.is_nan() { + panic!("{func}({re}, {im}).{component}: py={py} vs rs={rs} (one is NaN)",); + } else if py.is_infinite() && rs.is_infinite() { + // Check sign matches + if py.is_sign_positive() != rs.is_sign_positive() { + panic!("{func}({re}, {im}).{component}: py={py} vs rs={rs} (sign mismatch)",); + } + } else if py.is_infinite() || rs.is_infinite() { + panic!("{func}({re}, {im}).{component}: py={py} vs rs={rs} (one is infinite)",); + } else { + // Both finite - allow small ULP difference + let py_bits = py.to_bits() as i64; + let rs_bits = rs.to_bits() as i64; + let ulp_diff = (py_bits - rs_bits).abs(); + if ulp_diff != 0 { + panic!( + "{func}({re}, {im}).{component}: py={py} (bits={:#x}) vs rs={rs} (bits={:#x}), ULP diff={ulp_diff}", + py.to_bits(), + rs.to_bits() + ); + } + } + }; + check_component(py_re, rs.re, "re"); + check_component(py_im, rs.im, "im"); + } + + pub fn 
test_cmath_func(func_name: &str, rs_func: F, re: f64, im: f64) + where + F: Fn(Complex64) -> Result, + { + use pyo3::prelude::*; + + let rs_result = rs_func(Complex64::new(re, im)); + + pyo3::Python::attach(|py| { + let cmath = pyo3::types::PyModule::import(py, "cmath").unwrap(); + let py_func = cmath.getattr(func_name).unwrap(); + let py_result = py_func.call1((pyo3::types::PyComplex::from_doubles(py, re, im),)); + + match py_result { + Ok(result) => { + use pyo3::types::PyComplexMethods; + let c = result.cast::().unwrap(); + let py_re = c.real(); + let py_im = c.imag(); + match rs_result { + Ok(rs) => { + assert_complex_eq(py_re, py_im, rs, func_name, re, im); + } + Err(e) => { + panic!( + "{func_name}({re}, {im}): py=({py_re}, {py_im}) but rs returned error {e:?}" + ); + } + } + } + Err(e) => { + // CPython raised an exception - check we got an error too + if rs_result.is_ok() { + let rs = rs_result.unwrap(); + // Some special cases may return values for domain errors in Python + // Check if it's a domain error + if e.is_instance_of::(py) { + panic!( + "{func_name}({re}, {im}): py raised ValueError but rs=({}, {})", + rs.re, rs.im + ); + } else if e.is_instance_of::(py) { + panic!( + "{func_name}({re}, {im}): py raised OverflowError but rs=({}, {})", + rs.re, rs.im + ); + } + } + // Both raised errors - OK + } + } + }); + } +} diff --git a/src/cmath/exponential.rs b/src/cmath/exponential.rs new file mode 100644 index 0000000..1fdc30f --- /dev/null +++ b/src/cmath/exponential.rs @@ -0,0 +1,260 @@ +//! Complex exponential and logarithmic functions. + +use super::{ + CM_LARGE_DOUBLE, CM_LOG_LARGE_DOUBLE, INF, M_LN2, N, P, P12, P14, P34, U, c, special_type, + special_value, +}; +use crate::{Error, Result, m}; +use num_complex::Complex64; + +// Local constants +const M_LN10: f64 = core::f64::consts::LN_10; + +/// Scale factors for subnormal handling in sqrt. 
/// Power of two by which `sqrt` scales subnormal inputs up before taking the
/// square root.
///
/// With `f64::MANTISSA_DIGITS == 53` this is `2 * (53 / 2) + 1 == 53`
/// (integer division), i.e. an odd exponent.
const CM_SCALE_UP: i32 = 2 * (f64::MANTISSA_DIGITS as i32 / 2) + 1; // 53 for IEEE 754 binary64

/// Power of two applied to the square root to undo [`CM_SCALE_UP`]:
/// `-(CM_SCALE_UP + 1) / 2`.
const CM_SCALE_DOWN: i32 = -(CM_SCALE_UP + 1) / 2; // -27
INF), c(INF, N)], + [c(INF, -INF), c(N, N), c(N, N), c(N, N), c(N, N), c(INF, INF), c(N, N)], +]; + +/// Complex square root. +/// +/// Uses symmetries to reduce to the case when x = z.real and y = z.imag +/// are nonnegative, with careful handling of overflow and subnormals. +#[inline] +pub fn sqrt(z: Complex64) -> Result { + special_value!(z, SQRT_SPECIAL_VALUES); + + if z.re == 0.0 && z.im == 0.0 { + return Ok(Complex64::new(0.0, z.im)); + } + + let ax = m::fabs(z.re); + let ay = m::fabs(z.im); + + let s = if ax < f64::MIN_POSITIVE && ay < f64::MIN_POSITIVE { + // Handle subnormal case + let ax_scaled = m::ldexp(ax, CM_SCALE_UP); + m::ldexp( + m::sqrt(ax_scaled + m::hypot(ax_scaled, m::ldexp(ay, CM_SCALE_UP))), + CM_SCALE_DOWN, + ) + } else { + let ax8 = ax / 8.0; + 2.0 * m::sqrt(ax8 + m::hypot(ax8, ay / 8.0)) + }; + + let d = ay / (2.0 * s); + + if z.re >= 0.0 { + Ok(Complex64::new(s, m::copysign(d, z.im))) + } else { + Ok(Complex64::new(d, m::copysign(s, z.im))) + } +} + +/// Complex exponential. 
+#[inline] +pub fn exp(z: Complex64) -> Result { + // Handle special values + if !z.re.is_finite() || !z.im.is_finite() { + let r = if z.re.is_infinite() && z.im.is_finite() && z.im != 0.0 { + if z.re > 0.0 { + Complex64::new( + m::copysign(INF, m::cos(z.im)), + m::copysign(INF, m::sin(z.im)), + ) + } else { + Complex64::new( + m::copysign(0.0, m::cos(z.im)), + m::copysign(0.0, m::sin(z.im)), + ) + } + } else { + EXP_SPECIAL_VALUES[special_type(z.re) as usize][special_type(z.im) as usize] + }; + // need to set errno = EDOM if y is +/- infinity and x is not a NaN and not -infinity + if z.im.is_infinite() && (z.re.is_finite() || (z.re.is_infinite() && z.re > 0.0)) { + return Err(Error::EDOM); + } + return Ok(r); + } + + let (sin_im, cos_im) = m::sincos(z.im); + let (r_re, r_im); + if z.re > CM_LOG_LARGE_DOUBLE { + let l = m::exp(z.re - 1.0); + r_re = l * cos_im * core::f64::consts::E; + r_im = l * sin_im * core::f64::consts::E; + } else { + let l = m::exp(z.re); + r_re = l * cos_im; + r_im = l * sin_im; + } + + // detect overflow + if r_re.is_infinite() || r_im.is_infinite() { + return Err(Error::ERANGE); + } + Ok(Complex64::new(r_re, r_im)) +} + +/// Complex natural logarithm. +#[inline] +pub fn log(z: Complex64) -> Result { + special_value!(z, LOG_SPECIAL_VALUES); + + let ax = m::fabs(z.re); + let ay = m::fabs(z.im); + + let r_re = if ax > CM_LARGE_DOUBLE || ay > CM_LARGE_DOUBLE { + m::log(m::hypot(ax / 2.0, ay / 2.0)) + M_LN2 + } else if ax < f64::MIN_POSITIVE && ay < f64::MIN_POSITIVE { + if ax > 0.0 || ay > 0.0 { + // catch cases where hypot(ax, ay) is subnormal + m::log(m::hypot( + m::ldexp(ax, f64::MANTISSA_DIGITS as i32), + m::ldexp(ay, f64::MANTISSA_DIGITS as i32), + )) - f64::MANTISSA_DIGITS as f64 * M_LN2 + } else { + // log(+/-0. 
+/- 0i) + return Err(Error::EDOM); + } + } else { + let h = m::hypot(ax, ay); + if (0.71..=1.73).contains(&h) { + let am = if ax > ay { ax } else { ay }; // max(ax, ay) + let an = if ax > ay { ay } else { ax }; // min(ax, ay) + m::log1p((am - 1.0) * (am + 1.0) + an * an) / 2.0 + } else { + m::log(h) + } + }; + + let r_im = m::atan2(z.im, z.re); + Ok(Complex64::new(r_re, r_im)) +} + +/// Complex base-10 logarithm. +#[inline] +pub fn log10(z: Complex64) -> Result { + let r = log(z)?; + Ok(Complex64::new(r.re / M_LN10, r.im / M_LN10)) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_cmath_func(func_name: &str, rs_func: F, re: f64, im: f64) + where + F: Fn(Complex64) -> Result, + { + crate::cmath::tests::test_cmath_func(func_name, rs_func, re, im); + } + + fn test_sqrt(re: f64, im: f64) { + test_cmath_func("sqrt", sqrt, re, im); + } + fn test_exp(re: f64, im: f64) { + test_cmath_func("exp", exp, re, im); + } + fn test_log(re: f64, im: f64) { + test_cmath_func("log", log, re, im); + } + fn test_log10(re: f64, im: f64) { + test_cmath_func("log10", log10, re, im); + } + + use crate::test::EDGE_VALUES; + + #[test] + fn edgetest_sqrt() { + for &re in &EDGE_VALUES { + for &im in &EDGE_VALUES { + test_sqrt(re, im); + } + } + } + + #[test] + fn edgetest_exp() { + for &re in &EDGE_VALUES { + for &im in &EDGE_VALUES { + test_exp(re, im); + } + } + } + + #[test] + fn edgetest_log() { + for &re in &EDGE_VALUES { + for &im in &EDGE_VALUES { + test_log(re, im); + } + } + } + + #[test] + fn edgetest_log10() { + for &re in &EDGE_VALUES { + for &im in &EDGE_VALUES { + test_log10(re, im); + } + } + } + + proptest::proptest! 
{ + #[test] + fn proptest_sqrt(re: f64, im: f64) { + test_sqrt(re, im); + } + + #[test] + fn proptest_exp(re: f64, im: f64) { + test_exp(re, im); + } + + #[test] + fn proptest_log(re: f64, im: f64) { + test_log(re, im); + } + + #[test] + fn proptest_log10(re: f64, im: f64) { + test_log10(re, im); + } + } +} diff --git a/src/cmath/misc.rs b/src/cmath/misc.rs new file mode 100644 index 0000000..9831cac --- /dev/null +++ b/src/cmath/misc.rs @@ -0,0 +1,369 @@ +//! Complex polar coordinate and utility functions. + +use super::{INF, N, U, c, special_type}; +use crate::{Error, Result, m}; +use num_complex::Complex64; + +#[rustfmt::skip] +static RECT_SPECIAL_VALUES: [[Complex64; 7]; 7] = [ + [c(INF, N), c(U, U), c(-INF, 0.0), c(-INF, -0.0), c(U, U), c(INF, N), c(INF, N)], + [c(N, N), c(U, U), c(U, U), c(U, U), c(U, U), c(N, N), c(N, N)], + [c(0.0, 0.0), c(U, U), c(-0.0, 0.0), c(-0.0, -0.0), c(U, U), c(0.0, 0.0), c(0.0, 0.0)], + [c(0.0, 0.0), c(U, U), c(0.0, -0.0), c(0.0, 0.0), c(U, U), c(0.0, 0.0), c(0.0, 0.0)], + [c(N, N), c(U, U), c(U, U), c(U, U), c(U, U), c(N, N), c(N, N)], + [c(INF, N), c(U, U), c(INF, -0.0), c(INF, 0.0), c(U, U), c(INF, N), c(INF, N)], + [c(N, N), c(N, N), c(N, 0.0), c(N, 0.0), c(N, N), c(N, N), c(N, N)], +]; + +/// Return the phase angle (argument) of z. +#[inline] +pub fn phase(z: Complex64) -> Result { + crate::err::set_errno(0); + let phi = m::atan2(z.im, z.re); + match crate::err::get_errno() { + 0 => Ok(phi), + libc::EDOM => Err(Error::EDOM), + libc::ERANGE => Err(Error::ERANGE), + _ => Err(Error::EDOM), // Unknown errno treated as domain error (like PyErr_SetFromErrno) + } +} + +/// Convert z to polar coordinates (r, phi). +#[inline] +pub fn polar(z: Complex64) -> Result<(f64, f64)> { + let phi = m::atan2(z.im, z.re); + let r = m::hypot(z.re, z.im); + if r.is_infinite() && z.re.is_finite() && z.im.is_finite() { + return Err(Error::ERANGE); + } + Ok((r, phi)) +} + +/// Convert polar coordinates (r, phi) to rectangular form. 
+#[inline] +pub fn rect(r: f64, phi: f64) -> Result { + // Handle special values + if !r.is_finite() || !phi.is_finite() { + // if r is +/-infinity and phi is finite but nonzero then + // result is (+-INF +-INF i), but we need to compute cos(phi) + // and sin(phi) to figure out the signs. + let z = if r.is_infinite() && phi.is_finite() && phi != 0.0 { + if r > 0.0 { + Complex64::new(m::copysign(INF, m::cos(phi)), m::copysign(INF, m::sin(phi))) + } else { + Complex64::new( + -m::copysign(INF, m::cos(phi)), + -m::copysign(INF, m::sin(phi)), + ) + } + } else { + RECT_SPECIAL_VALUES[special_type(r) as usize][special_type(phi) as usize] + }; + // need to set errno = EDOM if r is a nonzero number and phi is infinite + if r != 0.0 && !r.is_nan() && phi.is_infinite() { + return Err(Error::EDOM); + } + return Ok(z); + } else if phi == 0.0 { + // Workaround for buggy results with phi=-0.0 on OS X 10.8. + return Ok(Complex64::new(r, r * phi)); + } + + let (sin_phi, cos_phi) = m::sincos(phi); + Ok(Complex64::new(r * cos_phi, r * sin_phi)) +} + +/// Return True if both real and imaginary parts are finite. +#[inline] +pub fn isfinite(z: Complex64) -> bool { + z.re.is_finite() && z.im.is_finite() +} + +/// Return True if either real or imaginary part is NaN. +#[inline] +pub fn isnan(z: Complex64) -> bool { + z.re.is_nan() || z.im.is_nan() +} + +/// Return True if either real or imaginary part is infinite. +#[inline] +pub fn isinf(z: Complex64) -> bool { + z.re.is_infinite() || z.im.is_infinite() +} + +/// Complex absolute value (magnitude). +#[inline] +pub fn abs(z: Complex64) -> f64 { + m::hypot(z.re, z.im) +} + +/// Determine whether two complex numbers are close in value. +#[inline] +pub fn isclose(a: Complex64, b: Complex64, rel_tol: f64, abs_tol: f64) -> bool { + // short circuit exact equality + if a.re == b.re && a.im == b.im { + return true; + } + + // This catches the case of two infinities of opposite sign, or + // one infinity and one finite number. 
+ if a.re.is_infinite() || a.im.is_infinite() || b.re.is_infinite() || b.im.is_infinite() { + return false; + } + + // now do the regular computation + let diff = abs(Complex64::new(a.re - b.re, a.im - b.im)); + (diff <= rel_tol * abs(b)) || (diff <= rel_tol * abs(a)) || (diff <= abs_tol) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::test::EDGE_VALUES; + + fn test_phase_impl(re: f64, im: f64) { + use pyo3::prelude::*; + + let rs_result = phase(Complex64::new(re, im)); + + pyo3::Python::attach(|py| { + let cmath = pyo3::types::PyModule::import(py, "cmath").unwrap(); + let py_func = cmath.getattr("phase").unwrap(); + let py_result = py_func.call1((pyo3::types::PyComplex::from_doubles(py, re, im),)); + + match py_result { + Ok(result) => { + let py_val: f64 = result.extract().unwrap(); + match rs_result { + Ok(rs_val) => { + if py_val.is_nan() && rs_val.is_nan() { + return; + } + assert_eq!( + py_val.to_bits(), + rs_val.to_bits(), + "phase({re}, {im}): py={py_val} vs rs={rs_val}" + ); + } + Err(e) => { + panic!("phase({re}, {im}): py={py_val} but rs returned error {e:?}"); + } + } + } + Err(e) => { + // Python raised an exception - check we got an error too + if rs_result.is_ok() { + let rs_val = rs_result.unwrap(); + if e.is_instance_of::(py) { + panic!("phase({re}, {im}): py raised ValueError but rs={rs_val}"); + } else if e.is_instance_of::(py) { + panic!("phase({re}, {im}): py raised OverflowError but rs={rs_val}"); + } + } + // Both raised errors - OK + } + } + }); + } + + #[test] + fn edgetest_phase() { + for &re in &EDGE_VALUES { + for &im in &EDGE_VALUES { + test_phase_impl(re, im); + } + } + } + + fn test_polar_impl(re: f64, im: f64) { + use pyo3::prelude::*; + + let rs_result = polar(Complex64::new(re, im)); + + pyo3::Python::attach(|py| { + let cmath = pyo3::types::PyModule::import(py, "cmath").unwrap(); + let py_func = cmath.getattr("polar").unwrap(); + let py_result = py_func.call1((pyo3::types::PyComplex::from_doubles(py, re, im),)); + + 
match py_result { + Ok(result) => { + let (py_r, py_phi): (f64, f64) = result.extract().unwrap(); + match rs_result { + Ok((rs_r, rs_phi)) => { + // Check r + if !py_r.is_nan() || !rs_r.is_nan() { + if py_r.is_nan() || rs_r.is_nan() { + panic!("polar({re}, {im}).r: py={py_r} vs rs={rs_r}"); + } + assert_eq!( + py_r.to_bits(), + rs_r.to_bits(), + "polar({re}, {im}).r: py={py_r} vs rs={rs_r}" + ); + } + // Check phi + if !py_phi.is_nan() || !rs_phi.is_nan() { + if py_phi.is_nan() || rs_phi.is_nan() { + panic!("polar({re}, {im}).phi: py={py_phi} vs rs={rs_phi}"); + } + assert_eq!( + py_phi.to_bits(), + rs_phi.to_bits(), + "polar({re}, {im}).phi: py={py_phi} vs rs={rs_phi}" + ); + } + } + Err(_) => { + panic!( + "polar({re}, {im}): py=({py_r}, {py_phi}) but rs returned error" + ); + } + } + } + Err(_) => { + // CPython raised error - check we did too + assert!( + rs_result.is_err(), + "polar({re}, {im}): py raised error but rs succeeded" + ); + } + } + }); + } + + #[test] + fn edgetest_polar() { + for &re in &EDGE_VALUES { + for &im in &EDGE_VALUES { + test_polar_impl(re, im); + } + } + } + + fn test_rect_impl(r: f64, phi: f64) { + use pyo3::prelude::*; + + let rs_result = rect(r, phi); + + pyo3::Python::attach(|py| { + let cmath = pyo3::types::PyModule::import(py, "cmath").unwrap(); + let py_func = cmath.getattr("rect").unwrap(); + let py_result = py_func.call1((r, phi)); + + match py_result { + Ok(result) => { + use pyo3::types::PyComplexMethods; + let c = result.cast::().unwrap(); + let py_re = c.real(); + let py_im = c.imag(); + match rs_result { + Ok(rs) => { + crate::cmath::tests::assert_complex_eq( + py_re, py_im, rs, "rect", r, phi, + ); + } + Err(_) => { + panic!("rect({r}, {phi}): py=({py_re}, {py_im}) but rs returned error"); + } + } + } + Err(_) => { + // CPython raised error + assert!( + rs_result.is_err(), + "rect({r}, {phi}): py raised error but rs succeeded" + ); + } + } + }); + } + + #[test] + fn edgetest_rect() { + for &r in &EDGE_VALUES { + for &phi 
in &EDGE_VALUES { + test_rect_impl(r, phi); + } + } + } + + #[test] + fn test_isfinite() { + assert!(isfinite(Complex64::new(1.0, 2.0))); + assert!(!isfinite(Complex64::new(f64::INFINITY, 0.0))); + assert!(!isfinite(Complex64::new(0.0, f64::INFINITY))); + assert!(!isfinite(Complex64::new(f64::NAN, 0.0))); + } + + #[test] + fn test_isnan() { + assert!(!isnan(Complex64::new(1.0, 2.0))); + assert!(!isnan(Complex64::new(f64::INFINITY, 0.0))); + assert!(isnan(Complex64::new(f64::NAN, 0.0))); + assert!(isnan(Complex64::new(0.0, f64::NAN))); + } + + #[test] + fn test_isinf() { + assert!(!isinf(Complex64::new(1.0, 2.0))); + assert!(isinf(Complex64::new(f64::INFINITY, 0.0))); + assert!(isinf(Complex64::new(0.0, f64::INFINITY))); + assert!(!isinf(Complex64::new(f64::NAN, 0.0))); + } + + #[test] + fn test_isclose_basic() { + // Equal values + assert!(isclose( + Complex64::new(1.0, 2.0), + Complex64::new(1.0, 2.0), + 1e-9, + 0.0 + )); + // Close values + assert!(isclose( + Complex64::new(1.0, 2.0), + Complex64::new(1.0 + 1e-10, 2.0), + 1e-9, + 0.0 + )); + // Not close + assert!(!isclose( + Complex64::new(1.0, 2.0), + Complex64::new(2.0, 2.0), + 1e-9, + 0.0 + )); + // Infinities + assert!(isclose( + Complex64::new(f64::INFINITY, 0.0), + Complex64::new(f64::INFINITY, 0.0), + 1e-9, + 0.0 + )); + assert!(!isclose( + Complex64::new(f64::INFINITY, 0.0), + Complex64::new(f64::NEG_INFINITY, 0.0), + 1e-9, + 0.0 + )); + } + + proptest::proptest! { + #[test] + fn proptest_phase(re: f64, im: f64) { + test_phase_impl(re, im); + } + + #[test] + fn proptest_polar(re: f64, im: f64) { + test_polar_impl(re, im); + } + + #[test] + fn proptest_rect(r: f64, phi: f64) { + test_rect_impl(r, phi); + } + } +} diff --git a/src/cmath/trigonometric.rs b/src/cmath/trigonometric.rs new file mode 100644 index 0000000..2ffb128 --- /dev/null +++ b/src/cmath/trigonometric.rs @@ -0,0 +1,591 @@ +//! Complex trigonometric and hyperbolic functions. 
+ +use super::{ + CM_LARGE_DOUBLE, CM_LOG_LARGE_DOUBLE, INF, M_LN2, N, P, P12, P14, P34, U, c, special_type, + special_value, sqrt, +}; +use crate::{Error, Result, m, mul_add}; +use num_complex::Complex64; + +// Local constants +const CM_SQRT_LARGE_DOUBLE: f64 = 6.703903964971298e+153; // sqrt(CM_LARGE_DOUBLE) +const CM_SQRT_DBL_MIN: f64 = 1.4916681462400413e-154; // sqrt(f64::MIN_POSITIVE) + +// Special value tables + +#[rustfmt::skip] +static ACOS_SPECIAL_VALUES: [[Complex64; 7]; 7] = [ + [c(P34, INF), c(P, INF), c(P, INF), c(P, -INF), c(P, -INF), c(P34, -INF), c(N, INF)], + [c(P12, INF), c(U, U), c(U, U), c(U, U), c(U, U), c(P12, -INF), c(N, N)], + [c(P12, INF), c(U, U), c(P12, 0.0), c(P12, -0.0),c(U, U), c(P12, -INF), c(P12, N)], + [c(P12, INF), c(U, U), c(P12, 0.0), c(P12, -0.0),c(U, U), c(P12, -INF), c(P12, N)], + [c(P12, INF), c(U, U), c(U, U), c(U, U), c(U, U), c(P12, -INF), c(N, N)], + [c(P14, INF), c(0.0, INF), c(0.0, INF), c(0.0, -INF),c(0.0, -INF),c(P14, -INF), c(N, INF)], + [c(N, INF), c(N, N), c(N, N), c(N, N), c(N, N), c(N, -INF), c(N, N)], +]; + +#[rustfmt::skip] +static ACOSH_SPECIAL_VALUES: [[Complex64; 7]; 7] = [ + [c(INF, -P34), c(INF, -P), c(INF, -P), c(INF, P), c(INF, P), c(INF, P34), c(INF, N)], + [c(INF, -P12), c(U, U), c(U, U), c(U, U), c(U, U), c(INF, P12), c(N, N)], + [c(INF, -P12), c(U, U), c(0.0, -P12), c(0.0, P12), c(U, U), c(INF, P12), c(N, P12)], + [c(INF, -P12), c(U, U), c(0.0, -P12), c(0.0, P12), c(U, U), c(INF, P12), c(N, P12)], + [c(INF, -P12), c(U, U), c(U, U), c(U, U), c(U, U), c(INF, P12), c(N, N)], + [c(INF, -P14), c(INF, -0.0), c(INF, -0.0), c(INF, 0.0), c(INF, 0.0), c(INF, P14), c(INF, N)], + [c(INF, N), c(N, N), c(N, N), c(N, N), c(N, N), c(INF, N), c(N, N)], +]; + +#[rustfmt::skip] +static ASINH_SPECIAL_VALUES: [[Complex64; 7]; 7] = [ + [c(-INF, -P14), c(-INF, -0.0), c(-INF, -0.0), c(-INF, 0.0), c(-INF, 0.0), c(-INF, P14), c(-INF, N)], + [c(-INF, -P12), c(U, U), c(U, U), c(U, U), c(U, U), c(-INF, P12), c(N, N)], + 
[c(-INF, -P12), c(U, U), c(-0.0, -0.0), c(-0.0, 0.0), c(U, U), c(-INF, P12), c(N, N)], + [c(INF, -P12), c(U, U), c(0.0, -0.0), c(0.0, 0.0), c(U, U), c(INF, P12), c(N, N)], + [c(INF, -P12), c(U, U), c(U, U), c(U, U), c(U, U), c(INF, P12), c(N, N)], + [c(INF, -P14), c(INF, -0.0), c(INF, -0.0), c(INF, 0.0), c(INF, 0.0), c(INF, P14), c(INF, N)], + [c(INF, N), c(N, N), c(N, -0.0), c(N, 0.0), c(N, N), c(INF, N), c(N, N)], +]; + +#[rustfmt::skip] +static ATANH_SPECIAL_VALUES: [[Complex64; 7]; 7] = [ + [c(-0.0, -P12), c(-0.0, -P12), c(-0.0, -P12), c(-0.0, P12), c(-0.0, P12), c(-0.0, P12), c(-0.0, N)], + [c(-0.0, -P12), c(U, U), c(U, U), c(U, U), c(U, U), c(-0.0, P12), c(N, N)], + [c(-0.0, -P12), c(U, U), c(-0.0, -0.0), c(-0.0, 0.0), c(U, U), c(-0.0, P12), c(-0.0, N)], + [c(0.0, -P12), c(U, U), c(0.0, -0.0), c(0.0, 0.0), c(U, U), c(0.0, P12), c(0.0, N)], + [c(0.0, -P12), c(U, U), c(U, U), c(U, U), c(U, U), c(0.0, P12), c(N, N)], + [c(0.0, -P12), c(0.0, -P12), c(0.0, -P12), c(0.0, P12), c(0.0, P12), c(0.0, P12), c(0.0, N)], + [c(0.0, -P12), c(N, N), c(N, N), c(N, N), c(N, N), c(0.0, P12), c(N, N)], +]; + +#[rustfmt::skip] +static COSH_SPECIAL_VALUES: [[Complex64; 7]; 7] = [ + [c(INF, N), c(U, U), c(INF, 0.0), c(INF, -0.0), c(U, U), c(INF, N), c(INF, N)], + [c(N, N), c(U, U), c(U, U), c(U, U), c(U, U), c(N, N), c(N, N)], + [c(N, 0.0), c(U, U), c(1.0, 0.0), c(1.0, -0.0), c(U, U), c(N, 0.0), c(N, 0.0)], + [c(N, 0.0), c(U, U), c(1.0, -0.0), c(1.0, 0.0), c(U, U), c(N, 0.0), c(N, 0.0)], + [c(N, N), c(U, U), c(U, U), c(U, U), c(U, U), c(N, N), c(N, N)], + [c(INF, N), c(U, U), c(INF, -0.0), c(INF, 0.0), c(U, U), c(INF, N), c(INF, N)], + [c(N, N), c(N, N), c(N, 0.0), c(N, 0.0), c(N, N), c(N, N), c(N, N)], +]; + +#[rustfmt::skip] +static SINH_SPECIAL_VALUES: [[Complex64; 7]; 7] = [ + [c(INF, N), c(U, U), c(-INF, -0.0), c(-INF, 0.0), c(U, U), c(INF, N), c(INF, N)], + [c(N, N), c(U, U), c(U, U), c(U, U), c(U, U), c(N, N), c(N, N)], + [c(0.0, N), c(U, U), c(-0.0, -0.0), c(-0.0, 0.0), 
c(U, U), c(0.0, N), c(0.0, N)], + [c(0.0, N), c(U, U), c(0.0, -0.0), c(0.0, 0.0), c(U, U), c(0.0, N), c(0.0, N)], + [c(N, N), c(U, U), c(U, U), c(U, U), c(U, U), c(N, N), c(N, N)], + [c(INF, N), c(U, U), c(INF, -0.0), c(INF, 0.0), c(U, U), c(INF, N), c(INF, N)], + [c(N, N), c(N, N), c(N, -0.0), c(N, 0.0), c(N, N), c(N, N), c(N, N)], +]; + +#[rustfmt::skip] +static TANH_SPECIAL_VALUES: [[Complex64; 7]; 7] = [ + [c(-1.0, 0.0), c(U, U), c(-1.0, -0.0), c(-1.0, 0.0), c(U, U), c(-1.0, 0.0), c(-1.0, 0.0)], + [c(N, N), c(U, U), c(U, U), c(U, U), c(U, U), c(N, N), c(N, N)], + [c(-0.0, N), c(U, U), c(-0.0, -0.0), c(-0.0, 0.0), c(U, U), c(-0.0, N), c(-0.0, N)], + [c(0.0, N), c(U, U), c(0.0, -0.0), c(0.0, 0.0), c(U, U), c(0.0, N), c(0.0, N)], + [c(N, N), c(U, U), c(U, U), c(U, U), c(U, U), c(N, N), c(N, N)], + [c(1.0, 0.0), c(U, U), c(1.0, -0.0), c(1.0, 0.0), c(U, U), c(1.0, 0.0), c(1.0, 0.0)], + [c(N, N), c(N, N), c(N, -0.0), c(N, 0.0), c(N, N), c(N, N), c(N, N)], +]; + +/// Complex hyperbolic cosine. 
+#[inline] +pub fn cosh(z: Complex64) -> Result { + // Special treatment for cosh(+/-inf + iy) if y is finite and nonzero + if !z.re.is_finite() || !z.im.is_finite() { + let r = if z.re.is_infinite() && z.im.is_finite() && z.im != 0.0 { + if z.re > 0.0 { + Complex64::new( + m::copysign(INF, m::cos(z.im)), + m::copysign(INF, m::sin(z.im)), + ) + } else { + Complex64::new( + m::copysign(INF, m::cos(z.im)), + -m::copysign(INF, m::sin(z.im)), + ) + } + } else { + COSH_SPECIAL_VALUES[special_type(z.re) as usize][special_type(z.im) as usize] + }; + // need to set errno = EDOM if y is +/- infinity and x is not a NaN + if z.im.is_infinite() && !z.re.is_nan() { + return Err(Error::EDOM); + } + return Ok(r); + } + + let (r_re, r_im); + let (sin_im, cos_im) = m::sincos(z.im); + if m::fabs(z.re) > CM_LOG_LARGE_DOUBLE { + // deal correctly with cases where cosh(z.real) overflows but cosh(z) does not + let x_minus_one = z.re - m::copysign(1.0, z.re); + r_re = cos_im * m::cosh(x_minus_one) * core::f64::consts::E; + r_im = sin_im * m::sinh(x_minus_one) * core::f64::consts::E; + } else { + r_re = cos_im * m::cosh(z.re); + r_im = sin_im * m::sinh(z.re); + } + + // detect overflow + if r_re.is_infinite() || r_im.is_infinite() { + return Err(Error::ERANGE); + } + Ok(Complex64::new(r_re, r_im)) +} + +/// Complex hyperbolic sine. 
+#[inline] +pub fn sinh(z: Complex64) -> Result { + // Special treatment for sinh(+/-inf + iy) if y is finite and nonzero + if !z.re.is_finite() || !z.im.is_finite() { + let r = if z.re.is_infinite() && z.im.is_finite() && z.im != 0.0 { + if z.re > 0.0 { + Complex64::new( + m::copysign(INF, m::cos(z.im)), + m::copysign(INF, m::sin(z.im)), + ) + } else { + Complex64::new( + -m::copysign(INF, m::cos(z.im)), + m::copysign(INF, m::sin(z.im)), + ) + } + } else { + SINH_SPECIAL_VALUES[special_type(z.re) as usize][special_type(z.im) as usize] + }; + // need to set errno = EDOM if y is +/- infinity and x is not a NaN + if z.im.is_infinite() && !z.re.is_nan() { + return Err(Error::EDOM); + } + return Ok(r); + } + + let (r_re, r_im); + let (sin_im, cos_im) = m::sincos(z.im); + if m::fabs(z.re) > CM_LOG_LARGE_DOUBLE { + let x_minus_one = z.re - m::copysign(1.0, z.re); + r_re = cos_im * m::sinh(x_minus_one) * core::f64::consts::E; + r_im = sin_im * m::cosh(x_minus_one) * core::f64::consts::E; + } else { + r_re = cos_im * m::sinh(z.re); + r_im = sin_im * m::cosh(z.re); + } + + // detect overflow + if r_re.is_infinite() || r_im.is_infinite() { + return Err(Error::ERANGE); + } + Ok(Complex64::new(r_re, r_im)) +} + +/// Complex hyperbolic tangent. +#[inline] +pub fn tanh(z: Complex64) -> Result { + // Special treatment for tanh(+/-inf + iy) if y is finite and nonzero + if !z.re.is_finite() || !z.im.is_finite() { + let r = if z.re.is_infinite() && z.im.is_finite() && z.im != 0.0 { + if z.re > 0.0 { + Complex64::new(1.0, m::copysign(0.0, 2.0 * m::sin(z.im) * m::cos(z.im))) + } else { + Complex64::new(-1.0, m::copysign(0.0, 2.0 * m::sin(z.im) * m::cos(z.im))) + } + } else { + TANH_SPECIAL_VALUES[special_type(z.re) as usize][special_type(z.im) as usize] + }; + // need to set errno = EDOM if z.imag is +/-infinity and z.real is finite + if z.im.is_infinite() && z.re.is_finite() { + return Err(Error::EDOM); + } + return Ok(r); + } + + // danger of overflow in 2.*z.im! 
+ if m::fabs(z.re) > CM_LOG_LARGE_DOUBLE { + let r = Complex64::new( + m::copysign(1.0, z.re), + 4.0 * m::sin(z.im) * m::cos(z.im) * m::exp(-2.0 * m::fabs(z.re)), + ); + return Ok(r); + } + + let tx = m::tanh(z.re); + let ty = m::tan(z.im); + let cx = 1.0 / m::cosh(z.re); + let txty = tx * ty; + let denom = mul_add(txty, txty, 1.0); + let r = Complex64::new(tx * mul_add(ty, ty, 1.0) / denom, ((ty / denom) * cx) * cx); + Ok(r) +} + +/// Complex cosine. +/// cos(z) = cosh(iz) +#[inline] +pub fn cos(z: Complex64) -> Result { + let r = Complex64::new(-z.im, z.re); + cosh(r) +} + +/// Complex sine. +/// sin(z) = -i * sinh(iz) +#[inline] +pub fn sin(z: Complex64) -> Result { + let s = Complex64::new(-z.im, z.re); + let s = sinh(s)?; + Ok(Complex64::new(s.im, -s.re)) +} + +/// Complex tangent. +/// tan(z) = -i * tanh(iz) +#[inline] +pub fn tan(z: Complex64) -> Result { + let s = Complex64::new(-z.im, z.re); + let s = tanh(s)?; + Ok(Complex64::new(s.im, -s.re)) +} + +/// Complex inverse hyperbolic sine. +#[inline] +pub fn asinh(z: Complex64) -> Result { + special_value!(z, ASINH_SPECIAL_VALUES); + + if m::fabs(z.re) > CM_LARGE_DOUBLE || m::fabs(z.im) > CM_LARGE_DOUBLE { + // Avoid overflow for large arguments + let r_re = if z.im >= 0.0 { + m::copysign(m::log(m::hypot(z.re / 2.0, z.im / 2.0)) + M_LN2 * 2.0, z.re) + } else { + -m::copysign( + m::log(m::hypot(z.re / 2.0, z.im / 2.0)) + M_LN2 * 2.0, + -z.re, + ) + }; + let r_im = m::atan2(z.im, m::fabs(z.re)); + return Ok(Complex64::new(r_re, r_im)); + } + + let s1 = sqrt(Complex64::new(1.0 + z.im, -z.re))?; + let s2 = sqrt(Complex64::new(1.0 - z.im, z.re))?; + let r_re = m::asinh(mul_add(s1.re, s2.im, -(s2.re * s1.im))); + let r_im = m::atan2(z.im, mul_add(s1.re, s2.re, -(s1.im * s2.im))); + Ok(Complex64::new(r_re, r_im)) +} + +/// Complex inverse hyperbolic cosine. 
+#[inline] +pub fn acosh(z: Complex64) -> Result { + special_value!(z, ACOSH_SPECIAL_VALUES); + + if m::fabs(z.re) > CM_LARGE_DOUBLE || m::fabs(z.im) > CM_LARGE_DOUBLE { + // Avoid overflow for large arguments + let r_re = m::log(m::hypot(z.re / 2.0, z.im / 2.0)) + M_LN2 * 2.0; + let r_im = m::atan2(z.im, z.re); + return Ok(Complex64::new(r_re, r_im)); + } + + let s1 = sqrt(Complex64::new(z.re - 1.0, z.im))?; + let s2 = sqrt(Complex64::new(z.re + 1.0, z.im))?; + let r_re = m::asinh(mul_add(s1.re, s2.re, s1.im * s2.im)); + let r_im = 2.0 * m::atan2(s1.im, s2.re); + Ok(Complex64::new(r_re, r_im)) +} + +/// Complex inverse hyperbolic tangent. +#[inline] +pub fn atanh(z: Complex64) -> Result { + special_value!(z, ATANH_SPECIAL_VALUES); + + // Reduce to case where z.real >= 0., using atanh(z) = -atanh(-z) + if z.re < 0.0 { + let r = atanh(Complex64::new(-z.re, -z.im))?; + return Ok(Complex64::new(-r.re, -r.im)); + } + + let ay = m::fabs(z.im); + if z.re > CM_SQRT_LARGE_DOUBLE || ay > CM_SQRT_LARGE_DOUBLE { + // if abs(z) is large then we use the approximation + // atanh(z) ~ 1/z +/- i*pi/2 (+/- depending on sign of z.imag) + let h = m::hypot(z.re / 2.0, z.im / 2.0); + let r_re = z.re / 4.0 / h / h; + let r_im = m::copysign(P12, z.im); + return Ok(Complex64::new(r_re, r_im)); + } else if z.re == 1.0 && ay < CM_SQRT_DBL_MIN { + // C99 standard says: atanh(1+/-0.) should be inf +/- 0i + if ay == 0.0 { + return Err(Error::EDOM); + } else { + let r_re = -m::log(m::sqrt(ay) / m::sqrt(m::hypot(ay, 2.0))); + let r_im = m::copysign(m::atan2(2.0, -ay) / 2.0, z.im); + return Ok(Complex64::new(r_re, r_im)); + } + } + + let one_minus_re = 1.0 - z.re; + let r_re = m::log1p(4.0 * z.re / mul_add(one_minus_re, one_minus_re, ay * ay)) / 4.0; + let r_im = -m::atan2(-2.0 * z.im, mul_add(one_minus_re, 1.0 + z.re, -(ay * ay))) / 2.0; + Ok(Complex64::new(r_re, r_im)) +} + +/// Complex arc cosine. 
+#[inline] +pub fn acos(z: Complex64) -> Result { + special_value!(z, ACOS_SPECIAL_VALUES); + + if m::fabs(z.re) > CM_LARGE_DOUBLE || m::fabs(z.im) > CM_LARGE_DOUBLE { + // Avoid overflow for large arguments + let r_re = m::atan2(m::fabs(z.im), z.re); + let r_im = if z.re < 0.0 { + -m::copysign(m::log(m::hypot(z.re / 2.0, z.im / 2.0)) + M_LN2 * 2.0, z.im) + } else { + m::copysign( + m::log(m::hypot(z.re / 2.0, z.im / 2.0)) + M_LN2 * 2.0, + -z.im, + ) + }; + return Ok(Complex64::new(r_re, r_im)); + } + + let s1 = sqrt(Complex64::new(1.0 - z.re, -z.im))?; + let s2 = sqrt(Complex64::new(1.0 + z.re, z.im))?; + let r_re = 2.0 * m::atan2(s1.re, s2.re); + let r_im = m::asinh(mul_add(s2.re, s1.im, -(s2.im * s1.re))); + Ok(Complex64::new(r_re, r_im)) +} + +/// Complex arc sine. +/// asin(z) = -i * asinh(iz) +#[inline] +pub fn asin(z: Complex64) -> Result { + let s = asinh(Complex64::new(-z.im, z.re))?; + Ok(Complex64::new(s.im, -s.re)) +} + +/// Complex arc tangent. +/// atan(z) = -i * atanh(iz) +#[inline] +pub fn atan(z: Complex64) -> Result { + let s = atanh(Complex64::new(-z.im, z.re))?; + Ok(Complex64::new(s.im, -s.re)) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_cmath_func(func_name: &str, rs_func: F, re: f64, im: f64) + where + F: Fn(Complex64) -> Result, + { + crate::cmath::tests::test_cmath_func(func_name, rs_func, re, im); + } + + fn test_sin(re: f64, im: f64) { + test_cmath_func("sin", sin, re, im); + } + fn test_cos(re: f64, im: f64) { + test_cmath_func("cos", cos, re, im); + } + fn test_tan(re: f64, im: f64) { + test_cmath_func("tan", tan, re, im); + } + fn test_sinh(re: f64, im: f64) { + test_cmath_func("sinh", sinh, re, im); + } + fn test_cosh(re: f64, im: f64) { + test_cmath_func("cosh", cosh, re, im); + } + fn test_tanh(re: f64, im: f64) { + test_cmath_func("tanh", tanh, re, im); + } + fn test_asin(re: f64, im: f64) { + test_cmath_func("asin", asin, re, im); + } + fn test_acos(re: f64, im: f64) { + test_cmath_func("acos", acos, re, im); + } + 
fn test_atan(re: f64, im: f64) { + test_cmath_func("atan", atan, re, im); + } + fn test_asinh(re: f64, im: f64) { + test_cmath_func("asinh", asinh, re, im); + } + fn test_acosh(re: f64, im: f64) { + test_cmath_func("acosh", acosh, re, im); + } + fn test_atanh(re: f64, im: f64) { + test_cmath_func("atanh", atanh, re, im); + } + + use crate::test::EDGE_VALUES; + + #[test] + fn edgetest_sin() { + for &re in &EDGE_VALUES { + for &im in &EDGE_VALUES { + test_sin(re, im); + } + } + } + + #[test] + fn edgetest_cos() { + for &re in &EDGE_VALUES { + for &im in &EDGE_VALUES { + test_cos(re, im); + } + } + } + + #[test] + fn edgetest_tan() { + for &re in &EDGE_VALUES { + for &im in &EDGE_VALUES { + test_tan(re, im); + } + } + } + + #[test] + fn edgetest_sinh() { + for &re in &EDGE_VALUES { + for &im in &EDGE_VALUES { + test_sinh(re, im); + } + } + } + + #[test] + fn edgetest_cosh() { + for &re in &EDGE_VALUES { + for &im in &EDGE_VALUES { + test_cosh(re, im); + } + } + } + + #[test] + fn edgetest_tanh() { + for &re in &EDGE_VALUES { + for &im in &EDGE_VALUES { + test_tanh(re, im); + } + } + } + + #[test] + fn edgetest_asin() { + for &re in &EDGE_VALUES { + for &im in &EDGE_VALUES { + test_asin(re, im); + } + } + } + + #[test] + fn edgetest_acos() { + for &re in &EDGE_VALUES { + for &im in &EDGE_VALUES { + test_acos(re, im); + } + } + } + + #[test] + fn edgetest_atan() { + for &re in &EDGE_VALUES { + for &im in &EDGE_VALUES { + test_atan(re, im); + } + } + } + + #[test] + fn edgetest_asinh() { + for &re in &EDGE_VALUES { + for &im in &EDGE_VALUES { + test_asinh(re, im); + } + } + } + + #[test] + fn edgetest_acosh() { + for &re in &EDGE_VALUES { + for &im in &EDGE_VALUES { + test_acosh(re, im); + } + } + } + + #[test] + fn edgetest_atanh() { + for &re in &EDGE_VALUES { + for &im in &EDGE_VALUES { + test_atanh(re, im); + } + } + } + + proptest::proptest! 
{ + #[test] + fn proptest_sin(re: f64, im: f64) { + test_sin(re, im); + } + + #[test] + fn proptest_cos(re: f64, im: f64) { + test_cos(re, im); + } + + #[test] + fn proptest_tan(re: f64, im: f64) { + test_tan(re, im); + } + + #[test] + fn proptest_sinh(re: f64, im: f64) { + test_sinh(re, im); + } + + #[test] + fn proptest_cosh(re: f64, im: f64) { + test_cosh(re, im); + } + + #[test] + fn proptest_tanh(re: f64, im: f64) { + test_tanh(re, im); + } + + #[test] + fn proptest_asin(re: f64, im: f64) { + test_asin(re, im); + } + + #[test] + fn proptest_acos(re: f64, im: f64) { + test_acos(re, im); + } + + #[test] + fn proptest_atan(re: f64, im: f64) { + test_atan(re, im); + } + + #[test] + fn proptest_asinh(re: f64, im: f64) { + test_asinh(re, im); + } + + #[test] + fn proptest_acosh(re: f64, im: f64) { + test_acosh(re, im); + } + + #[test] + fn proptest_atanh(re: f64, im: f64) { + test_atanh(re, im); + } + } +} diff --git a/src/err.rs b/src/err.rs index ffe23ae..59d802c 100644 --- a/src/err.rs +++ b/src/err.rs @@ -20,3 +20,78 @@ impl TryFrom for Error { } } } + +/// Set errno to the given value. +#[inline] +pub(crate) fn set_errno(value: i32) { + unsafe { + #[cfg(target_os = "linux")] + { + *libc::__errno_location() = value; + } + #[cfg(target_os = "macos")] + { + *libc::__error() = value; + } + #[cfg(target_os = "windows")] + { + unsafe extern "C" { + safe fn _errno() -> *mut i32; + } + *_errno() = value; + } + #[cfg(all(unix, not(any(target_os = "linux", target_os = "macos"))))] + { + // FreeBSD, NetBSD, OpenBSD, etc. use __error() + *libc::__error() = value; + } + } +} + +/// Get the current errno value. 
+#[inline] +pub(crate) fn get_errno() -> i32 { + unsafe { + #[cfg(target_os = "linux")] + { + *libc::__errno_location() + } + #[cfg(target_os = "macos")] + { + *libc::__error() + } + #[cfg(target_os = "windows")] + { + unsafe extern "C" { + safe fn _errno() -> *mut i32; + } + *_errno() + } + #[cfg(all(unix, not(any(target_os = "linux", target_os = "macos"))))] + { + // FreeBSD, NetBSD, OpenBSD, etc. use __error() + *libc::__error() + } + } +} + +/// Check errno after libm call and convert to Result. +#[inline] +pub(crate) fn is_error(x: f64) -> Result { + match get_errno() { + 0 => Ok(x), + libc::EDOM => Err(Error::EDOM), + libc::ERANGE => { + // Underflow to zero is not an error. + // Use 1.5 threshold to handle subnormal results that don't underflow to zero + // (e.g., on Ubuntu/ia64) and to correctly detect underflows in expm1() + // which may underflow toward -1.0 rather than 0.0. (bpo-46018) + if x.abs() < 1.5 { + Ok(x) + } else { + Err(Error::ERANGE) + } + } + _ => Ok(x), // Unknown errno, just return the value + } +} diff --git a/src/lib.rs b/src/lib.rs index 14363cb..c3a0db4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,29 +1,28 @@ +// Public modules +#[cfg(feature = "complex")] +pub mod cmath; +pub mod math; + +// Internal modules mod err; -mod gamma; -mod m; +pub(crate) mod m; +mod m_sys; #[cfg(test)] mod test; +// Re-export error types at root level pub use err::{Error, Result}; -pub use gamma::{gamma, lgamma}; -macro_rules! 
libm { - // Reset errno and handle errno when return type contains Result - (fn $name:ident($arg:ident: $ty:ty) -> Result<$ret:ty>) => { - #[inline(always)] - pub fn $name($arg: $ty) -> Result<$ret> { - errno::set_errno(errno::Errno(0)); - let r = unsafe { m::$name($arg) }; - crate::is_error(r) - } - }; - // Skip errno checking when return type is not Result - (fn $name:ident($arg:ident: $ty:ty) -> $ret:ty) => { - #[inline(always)] - pub fn $name($arg: $ty) -> $ret { - unsafe { m::$name($arg) } - } - }; +/// Fused multiply-add operation. +/// When `mul_add` feature is enabled, uses hardware FMA instruction. +/// Otherwise, uses separate multiply and add operations. +#[inline(always)] +pub(crate) fn mul_add(a: f64, b: f64, c: f64) -> f64 { + if cfg!(feature = "mul_add") { + a.mul_add(b, c) + } else { + a * b + c + } } macro_rules! pyo3_proptest { @@ -34,8 +33,7 @@ macro_rules! pyo3_proptest { let rs_result = $fn_name(x); - pyo3::prepare_freethreaded_python(); - Python::with_gil(|py| { + pyo3::Python::attach(|py| { let math = PyModule::import(py, "math").unwrap(); let py_func = math .getattr(stringify!($fn_name)) @@ -60,8 +58,7 @@ macro_rules! pyo3_proptest { let rs_result = Ok($fn_name(x)); - pyo3::prepare_freethreaded_python(); - Python::with_gil(|py| { + pyo3::Python::attach(|py| { let math = PyModule::import(py, "math").unwrap(); let py_func = math .getattr(stringify!($fn_name)) @@ -101,20 +98,4 @@ macro_rules! pyo3_proptest { }; } -libm!(fn erf(n: f64) -> f64); -pyo3_proptest!(erf(_), test_erf, proptest_erf, edgetest_erf); - -libm!(fn erfc(n: f64) -> f64); -pyo3_proptest!(erfc(_), test_erfc, proptest_erfc, edgetest_erfc); - -/// Call is_error when errno != 0, and where x is the result libm -/// returned. is_error will usually set up an exception and return -/// true (1), but may return false (0) without setting up an exception. 
-// fn is_error(x: f64) -> crate::Result { -// match errno::errno() { -// errno::Errno(0) => Ok(x), -// errno::Errno(libc::ERANGE) if x.abs() < 1.5 => Ok(0f64), -// errno::Errno(errno) => Err(errno.try_into().unwrap()), -// } -// } use pyo3_proptest; diff --git a/src/m.rs b/src/m.rs index 0dff12a..259ef22 100644 --- a/src/m.rs +++ b/src/m.rs @@ -1,65 +1,247 @@ -//! Partial copy of std::sys::_cmath +//! Safe wrappers for system libm functions. -// These symbols are all defined by `libm`, -// or by `compiler-builtins` on unsupported platforms. -#[allow(dead_code)] +use crate::m_sys; + +// Trigonometric functions + +#[inline(always)] +pub fn acos(n: f64) -> f64 { + unsafe { m_sys::acos(n) } +} + +#[inline(always)] +pub fn asin(n: f64) -> f64 { + unsafe { m_sys::asin(n) } +} + +#[inline(always)] +pub fn atan(n: f64) -> f64 { + unsafe { m_sys::atan(n) } +} + +#[inline(always)] +pub fn atan2(y: f64, x: f64) -> f64 { + unsafe { m_sys::atan2(y, x) } +} + +#[inline(always)] +pub fn cos(n: f64) -> f64 { + unsafe { m_sys::cos(n) } +} + +#[inline(always)] +pub fn sin(n: f64) -> f64 { + unsafe { m_sys::sin(n) } +} + +#[inline(always)] +pub fn tan(n: f64) -> f64 { + unsafe { m_sys::tan(n) } +} + +// Hyperbolic functions + +#[inline(always)] +pub fn acosh(n: f64) -> f64 { + unsafe { m_sys::acosh(n) } +} + +#[inline(always)] +pub fn asinh(n: f64) -> f64 { + unsafe { m_sys::asinh(n) } +} + +#[inline(always)] +pub fn atanh(n: f64) -> f64 { + unsafe { m_sys::atanh(n) } +} + +#[inline(always)] +pub fn cosh(n: f64) -> f64 { + unsafe { m_sys::cosh(n) } +} + +#[inline(always)] +pub fn sinh(n: f64) -> f64 { + unsafe { m_sys::sinh(n) } +} + +#[inline(always)] +pub fn tanh(n: f64) -> f64 { + unsafe { m_sys::tanh(n) } +} + +// Exponential and logarithmic functions + +#[inline(always)] +pub fn exp(n: f64) -> f64 { + unsafe { m_sys::exp(n) } +} + +#[inline(always)] +pub fn exp2(n: f64) -> f64 { + unsafe { m_sys::exp2(n) } +} + +#[inline(always)] +pub fn expm1(n: f64) -> f64 { + unsafe { 
m_sys::expm1(n) } +} + +#[inline(always)] +pub fn log(n: f64) -> f64 { + unsafe { m_sys::log(n) } +} + +#[inline(always)] +pub fn log10(n: f64) -> f64 { + unsafe { m_sys::log10(n) } +} + +#[inline(always)] +pub fn log1p(n: f64) -> f64 { + unsafe { m_sys::log1p(n) } +} + +#[inline(always)] +pub fn log2(n: f64) -> f64 { + unsafe { m_sys::log2(n) } +} + +// Power functions + +#[inline(always)] +pub fn cbrt(n: f64) -> f64 { + unsafe { m_sys::cbrt(n) } +} + +#[inline(always)] +pub fn hypot(x: f64, y: f64) -> f64 { + unsafe { m_sys::hypot(x, y) } +} + +#[inline(always)] +pub fn pow(x: f64, y: f64) -> f64 { + unsafe { m_sys::pow(x, y) } +} + +#[inline(always)] +pub fn sqrt(n: f64) -> f64 { + unsafe { m_sys::sqrt(n) } +} + +// Floating-point manipulation functions + +#[inline(always)] +pub fn ceil(n: f64) -> f64 { + unsafe { m_sys::ceil(n) } +} + +#[inline(always)] +pub fn copysign(x: f64, y: f64) -> f64 { + unsafe { m_sys::copysign(x, y) } +} + +#[inline(always)] +pub fn fabs(n: f64) -> f64 { + unsafe { m_sys::fabs(n) } +} + +// #[inline(always)] +// pub fn fdim(a: f64, b: f64) -> f64 { +// unsafe { m_sys::fdim(a, b) } +// } + +#[inline(always)] +pub fn floor(n: f64) -> f64 { + unsafe { m_sys::floor(n) } +} + +#[inline(always)] +pub fn fmod(x: f64, y: f64) -> f64 { + unsafe { m_sys::fmod(x, y) } +} + +#[inline(always)] +pub fn frexp(n: f64, exp: &mut i32) -> f64 { + unsafe { m_sys::frexp(n, exp) } +} + +#[inline(always)] +pub fn ldexp(x: f64, n: i32) -> f64 { + unsafe { m_sys::ldexp(x, n) } +} + +#[inline(always)] +pub fn modf(n: f64, iptr: &mut f64) -> f64 { + unsafe { m_sys::modf(n, iptr) } +} + +#[inline(always)] +pub fn nextafter(x: f64, y: f64) -> f64 { + unsafe { m_sys::nextafter(x, y) } +} + +#[inline(always)] +pub fn remainder(x: f64, y: f64) -> f64 { + unsafe { m_sys::remainder(x, y) } +} + +#[inline(always)] +pub fn trunc(n: f64) -> f64 { + unsafe { m_sys::trunc(n) } +} + +// Special functions + +#[inline(always)] +pub fn erf(n: f64) -> f64 { + unsafe { 
m_sys::erf(n) } +} + +#[inline(always)] +pub fn erfc(n: f64) -> f64 { + unsafe { m_sys::erfc(n) } +} + +// #[inline(always)] +// pub fn lgamma_r(n: f64, s: &mut i32) -> f64 { +// unsafe { m_sys::lgamma_r(n, s) } +// } + +// #[inline(always)] +// pub fn tgamma(n: f64) -> f64 { +// unsafe { m_sys::tgamma(n) } +// } + +// Platform-specific sincos + +/// Result type for sincos function (matches Apple's __double2) +#[cfg(all(feature = "complex", target_os = "macos"))] +#[repr(C)] +struct SinCosResult { + sin: f64, + cos: f64, +} + +#[cfg(all(feature = "complex", target_os = "macos"))] unsafe extern "C" { - pub fn acos(n: f64) -> f64; - pub fn asin(n: f64) -> f64; - pub fn atan(n: f64) -> f64; - pub fn atan2(a: f64, b: f64) -> f64; - pub fn cbrt(n: f64) -> f64; - pub fn cbrtf(n: f32) -> f32; - pub fn cosh(n: f64) -> f64; - pub fn expm1(n: f64) -> f64; - pub fn expm1f(n: f32) -> f32; - pub fn fdim(a: f64, b: f64) -> f64; - pub fn fdimf(a: f32, b: f32) -> f32; - #[cfg_attr(target_env = "msvc", link_name = "_hypot")] - pub fn hypot(x: f64, y: f64) -> f64; - #[cfg_attr(target_env = "msvc", link_name = "_hypotf")] - pub fn hypotf(x: f32, y: f32) -> f32; - pub fn log1p(n: f64) -> f64; - pub fn log1pf(n: f32) -> f32; - pub fn sinh(n: f64) -> f64; - pub fn tan(n: f64) -> f64; - pub fn tanh(n: f64) -> f64; - pub fn tgamma(n: f64) -> f64; - pub fn tgammaf(n: f32) -> f32; - pub fn lgamma_r(n: f64, s: &mut i32) -> f64; - #[cfg(not(target_os = "aix"))] - pub fn lgammaf_r(n: f32, s: &mut i32) -> f32; - pub fn erf(n: f64) -> f64; - pub fn erff(n: f32) -> f32; - pub fn erfc(n: f64) -> f64; - pub fn erfcf(n: f32) -> f32; - - // pub fn acosf128(n: f128) -> f128; - // pub fn asinf128(n: f128) -> f128; - // pub fn atanf128(n: f128) -> f128; - // pub fn atan2f128(a: f128, b: f128) -> f128; - // pub fn cbrtf128(n: f128) -> f128; - // pub fn coshf128(n: f128) -> f128; - // pub fn expm1f128(n: f128) -> f128; - // pub fn hypotf128(x: f128, y: f128) -> f128; - // pub fn log1pf128(n: f128) -> 
f128; - // pub fn sinhf128(n: f128) -> f128; - // pub fn tanf128(n: f128) -> f128; - // pub fn tanhf128(n: f128) -> f128; - // pub fn tgammaf128(n: f128) -> f128; - // pub fn lgammaf128_r(n: f128, s: &mut i32) -> f128; - // pub fn erff128(n: f128) -> f128; - // pub fn erfcf128(n: f128) -> f128; - - // cfg_if::cfg_if! { - // if #[cfg(not(all(target_os = "windows", target_env = "msvc", target_arch = "x86")))] { - // pub fn acosf(n: f32) -> f32; - // pub fn asinf(n: f32) -> f32; - // pub fn atan2f(a: f32, b: f32) -> f32; - // pub fn atanf(n: f32) -> f32; - // pub fn coshf(n: f32) -> f32; - // pub fn sinhf(n: f32) -> f32; - // pub fn tanf(n: f32) -> f32; - // pub fn tanhf(n: f32) -> f32; - // }} + #[link_name = "__sincos_stret"] + fn sincos_stret(x: f64) -> SinCosResult; +} + +/// Compute sin and cos together using Apple's optimized sincos. +/// This matches Python's cmath behavior on macOS. +#[cfg(all(feature = "complex", target_os = "macos"))] +#[inline(always)] +pub fn sincos(x: f64) -> (f64, f64) { + let sc = unsafe { sincos_stret(x) }; + (sc.sin, sc.cos) +} + +/// Fallback for non-macOS: call sin and cos separately +#[cfg(all(feature = "complex", not(target_os = "macos")))] +#[inline(always)] +pub fn sincos(x: f64) -> (f64, f64) { + (sin(x), cos(x)) } diff --git a/src/m_sys.rs b/src/m_sys.rs new file mode 100644 index 0000000..60406a8 --- /dev/null +++ b/src/m_sys.rs @@ -0,0 +1,98 @@ +// These symbols are all defined by `libm`, +// or by `compiler-builtins` on unsupported platforms. 
+#[cfg_attr(unix, link(name = "m"))] +#[allow(dead_code)] +unsafe extern "C" { + // Trigonometric functions + pub fn acos(n: f64) -> f64; + pub fn asin(n: f64) -> f64; + pub fn atan(n: f64) -> f64; + pub fn atan2(a: f64, b: f64) -> f64; + pub fn cos(n: f64) -> f64; + pub fn sin(n: f64) -> f64; + pub fn tan(n: f64) -> f64; + + // Hyperbolic functions + pub fn acosh(n: f64) -> f64; + pub fn asinh(n: f64) -> f64; + pub fn atanh(n: f64) -> f64; + pub fn cosh(n: f64) -> f64; + pub fn sinh(n: f64) -> f64; + pub fn tanh(n: f64) -> f64; + + // Exponential and logarithmic functions + pub fn exp(n: f64) -> f64; + pub fn exp2(n: f64) -> f64; + pub fn expm1(n: f64) -> f64; + pub fn expm1f(n: f32) -> f32; + pub fn log(n: f64) -> f64; + pub fn log10(n: f64) -> f64; + pub fn log1p(n: f64) -> f64; + pub fn log1pf(n: f32) -> f32; + pub fn log2(n: f64) -> f64; + + // Power functions + pub fn cbrt(n: f64) -> f64; + pub fn cbrtf(n: f32) -> f32; + #[cfg_attr(target_env = "msvc", link_name = "_hypot")] + pub fn hypot(x: f64, y: f64) -> f64; + #[cfg_attr(target_env = "msvc", link_name = "_hypotf")] + pub fn hypotf(x: f32, y: f32) -> f32; + pub fn pow(x: f64, y: f64) -> f64; + pub fn sqrt(n: f64) -> f64; + + // Floating-point manipulation functions + pub fn ceil(n: f64) -> f64; + pub fn copysign(x: f64, y: f64) -> f64; + pub fn fabs(n: f64) -> f64; + // pub fn fdim(a: f64, b: f64) -> f64; + pub fn fdimf(a: f32, b: f32) -> f32; + pub fn floor(n: f64) -> f64; + pub fn fmod(x: f64, y: f64) -> f64; + pub fn frexp(n: f64, exp: *mut i32) -> f64; + pub fn ldexp(x: f64, n: i32) -> f64; + pub fn modf(n: f64, iptr: *mut f64) -> f64; + pub fn nextafter(x: f64, y: f64) -> f64; + pub fn remainder(x: f64, y: f64) -> f64; + pub fn trunc(n: f64) -> f64; + + // Special functions + pub fn erf(n: f64) -> f64; + pub fn erfc(n: f64) -> f64; + pub fn erff(n: f32) -> f32; + pub fn erfcf(n: f32) -> f32; + // pub fn lgamma_r(n: f64, s: &mut i32) -> f64; + #[cfg(not(target_os = "aix"))] + pub fn lgammaf_r(n: f32, 
s: &mut i32) -> f32; + // pub fn tgamma(n: f64) -> f64; + pub fn tgammaf(n: f32) -> f32; + + // pub fn acosf128(n: f128) -> f128; + // pub fn asinf128(n: f128) -> f128; + // pub fn atanf128(n: f128) -> f128; + // pub fn atan2f128(a: f128, b: f128) -> f128; + // pub fn cbrtf128(n: f128) -> f128; + // pub fn coshf128(n: f128) -> f128; + // pub fn expm1f128(n: f128) -> f128; + // pub fn hypotf128(x: f128, y: f128) -> f128; + // pub fn log1pf128(n: f128) -> f128; + // pub fn sinhf128(n: f128) -> f128; + // pub fn tanf128(n: f128) -> f128; + // pub fn tanhf128(n: f128) -> f128; + // pub fn tgammaf128(n: f128) -> f128; + // pub fn lgammaf128_r(n: f128, s: &mut i32) -> f128; + // pub fn erff128(n: f128) -> f128; + // pub fn erfcf128(n: f128) -> f128; + + // cfg_if::cfg_if! { + // if #[cfg(not(all(target_os = "windows", target_env = "msvc", target_arch = "x86")))] { + // pub fn acosf(n: f32) -> f32; + // pub fn asinf(n: f32) -> f32; + // pub fn atan2f(a: f32, b: f32) -> f32; + // pub fn atanf(n: f32) -> f32; + // pub fn coshf(n: f32) -> f32; + // pub fn sinhf(n: f32) -> f32; + // pub fn tanf(n: f32) -> f32; + // pub fn tanhf(n: f32) -> f32; + // }} +} diff --git a/src/math.rs b/src/math.rs new file mode 100644 index 0000000..e65dece --- /dev/null +++ b/src/math.rs @@ -0,0 +1,241 @@ +//! Real number mathematical functions matching Python's math module. 
+ +// Submodules +mod aggregate; +mod exponential; +mod gamma; +#[cfg(feature = "_bigint")] +pub mod integer; +mod misc; +mod trigonometric; + +// Re-export from submodules +pub use aggregate::{dist, fsum, prod, sumprod}; +pub use exponential::{cbrt, exp, exp2, expm1, log, log1p, log2, log10, pow, sqrt}; +pub use gamma::{erf, erfc, gamma, lgamma}; +pub use misc::{ + ceil, copysign, fabs, floor, fmod, frexp, isclose, isfinite, isinf, isnan, ldexp, modf, + nextafter, remainder, trunc, ulp, +}; +pub use trigonometric::{ + acos, acosh, asin, asinh, atan, atan2, atanh, cos, cosh, sin, sinh, tan, tanh, +}; + +/// Simple libm wrapper macro for functions that don't need errno handling. +macro_rules! libm_simple { + // 1-arg: (f64) -> f64 + (@1 $($name:ident),* $(,)?) => { + $( + #[inline] + pub fn $name(x: f64) -> f64 { + crate::m::$name(x) + } + )* + }; + // 2-arg: (f64, f64) -> f64 + (@2 $($name:ident),* $(,)?) => { + $( + #[inline] + pub fn $name(x: f64, y: f64) -> f64 { + crate::m::$name(x, y) + } + )* + }; +} + +pub(crate) use libm_simple; + +/// math_1: wrapper for 1-arg functions +/// - isnan(r) && !isnan(x) -> domain error +/// - isinf(r) && isfinite(x) -> overflow (can_overflow=true) or domain error (can_overflow=false) +/// - isfinite(r) && errno -> check errno (unnecessary on most platforms) +#[inline] +pub(crate) fn math_1(x: f64, func: fn(f64) -> f64, can_overflow: bool) -> crate::Result { + crate::err::set_errno(0); + let r = func(x); + if r.is_nan() && !x.is_nan() { + return Err(crate::Error::EDOM); + } + if r.is_infinite() && x.is_finite() { + return Err(if can_overflow { + crate::Error::ERANGE + } else { + crate::Error::EDOM + }); + } + // This branch unnecessary on most platforms + #[cfg(not(any(windows, target_os = "macos")))] + if r.is_finite() && crate::err::get_errno() != 0 { + return crate::err::is_error(r); + } + Ok(r) +} + +/// math_2: wrapper for 2-arg functions +/// - isnan(r) && !isnan(x) && !isnan(y) -> domain error +/// - isinf(r) && 
isfinite(x) && isfinite(y) -> range error +#[inline] +pub(crate) fn math_2(x: f64, y: f64, func: fn(f64, f64) -> f64) -> crate::Result { + let r = func(x, y); + if r.is_nan() && !x.is_nan() && !y.is_nan() { + return Err(crate::Error::EDOM); + } + if r.is_infinite() && x.is_finite() && y.is_finite() { + return Err(crate::Error::ERANGE); + } + Ok(r) +} + +/// math_1a: wrapper for 1-arg functions that set errno properly. +/// Used when the libm function is known to set errno correctly +/// (EDOM for invalid, ERANGE for overflow). +#[inline] +pub(crate) fn math_1a(x: f64, func: fn(f64) -> f64) -> crate::Result { + crate::err::set_errno(0); + let r = func(x); + crate::err::is_error(r) +} + +/// Return the Euclidean distance, sqrt(x*x + y*y). +/// +/// Uses high-precision vector_norm algorithm instead of libm hypot() +/// for consistent results across platforms and better handling of overflow/underflow. +#[inline] +pub fn hypot(x: f64, y: f64) -> f64 { + let ax = x.abs(); + let ay = y.abs(); + let max = if ax > ay { ax } else { ay }; + let found_nan = x.is_nan() || y.is_nan(); + aggregate::vector_norm_2(ax, ay, max, found_nan) +} + +// Mathematical constants + +/// The mathematical constant π = 3.141592... +pub const PI: f64 = std::f64::consts::PI; + +/// The mathematical constant e = 2.718281... +pub const E: f64 = std::f64::consts::E; + +/// The mathematical constant τ = 6.283185... +pub const TAU: f64 = std::f64::consts::TAU; + +/// Positive infinity. +pub const INF: f64 = f64::INFINITY; + +/// A floating point "not a number" (NaN) value. +pub const NAN: f64 = f64::NAN; + +// Angle conversion functions + +/// Convert angle x from radians to degrees. +#[inline] +pub fn degrees(x: f64) -> f64 { + x * (180.0 / PI) +} + +/// Convert angle x from degrees to radians. 
+#[inline] +pub fn radians(x: f64) -> f64 { + x * (PI / 180.0) +} + +#[cfg(test)] +mod tests { + use super::*; + use pyo3::prelude::*; + + // Angle conversion tests + fn test_degrees(x: f64) { + let rs_result = Ok(degrees(x)); + + pyo3::Python::attach(|py| { + let math = PyModule::import(py, "math").unwrap(); + let py_func = math.getattr("degrees").unwrap(); + let r = py_func.call1((x,)); + let Some((py_result, rs_result)) = crate::test::unwrap(py, r, rs_result) else { + return; + }; + let py_result_repr = py_result.to_bits(); + let rs_result_repr = rs_result.to_bits(); + assert_eq!( + py_result_repr, rs_result_repr, + "x = {x}, py_result = {py_result}, rs_result = {rs_result}" + ); + }); + } + + fn test_radians(x: f64) { + let rs_result = Ok(radians(x)); + + pyo3::Python::attach(|py| { + let math = PyModule::import(py, "math").unwrap(); + let py_func = math.getattr("radians").unwrap(); + let r = py_func.call1((x,)); + let Some((py_result, rs_result)) = crate::test::unwrap(py, r, rs_result) else { + return; + }; + let py_result_repr = py_result.to_bits(); + let rs_result_repr = rs_result.to_bits(); + assert_eq!( + py_result_repr, rs_result_repr, + "x = {x}, py_result = {py_result}, rs_result = {rs_result}" + ); + }); + } + + #[test] + fn edgetest_degrees() { + for &x in &crate::test::EDGE_VALUES { + test_degrees(x); + } + } + + #[test] + fn edgetest_radians() { + for &x in &crate::test::EDGE_VALUES { + test_radians(x); + } + } + + // Constants test + #[test] + fn test_constants() { + assert!((PI - 3.141592653589793).abs() < 1e-15); + assert!((E - 2.718281828459045).abs() < 1e-15); + assert!((TAU - 6.283185307179586).abs() < 1e-15); + assert!(INF.is_infinite() && INF > 0.0); + assert!(NAN.is_nan()); + } + + // hypot tests + fn test_hypot(x: f64, y: f64) { + crate::test::test_math_2(x, y, "hypot", |x, y| Ok(hypot(x, y))); + } + + #[test] + fn edgetest_hypot() { + for &x in &crate::test::EDGE_VALUES { + for &y in &crate::test::EDGE_VALUES { + test_hypot(x, y); + } + } 
+ } + + proptest::proptest! { + #[test] + fn proptest_degrees(x: f64) { + test_degrees(x); + } + + #[test] + fn proptest_radians(x: f64) { + test_radians(x); + } + + #[test] + fn proptest_hypot(x: f64, y: f64) { + test_hypot(x, y); + } + } +} diff --git a/src/math/aggregate.rs b/src/math/aggregate.rs new file mode 100644 index 0000000..2676d60 --- /dev/null +++ b/src/math/aggregate.rs @@ -0,0 +1,574 @@ +//! Aggregate functions for sequences. + +/// Double-length number represented as hi + lo +#[derive(Clone, Copy)] +struct DoubleLength { + hi: f64, + lo: f64, +} + +/// Algorithm 1.1. Compensated summation of two floating-point numbers. +/// Requires: |a| >= |b| +#[inline] +fn dl_fast_sum(a: f64, b: f64) -> DoubleLength { + debug_assert!(a.abs() >= b.abs()); + let x = a + b; + let y = (a - x) + b; + DoubleLength { hi: x, lo: y } +} + +/// Algorithm 3.1 Error-free transformation of the sum +#[inline] +fn dl_sum(a: f64, b: f64) -> DoubleLength { + let x = a + b; + let z = x - a; + let y = (a - (x - z)) + (b - z); + DoubleLength { hi: x, lo: y } +} + +/// Algorithm 3.5. 
Error-free transformation of a product using FMA +#[inline] +fn dl_mul(x: f64, y: f64) -> DoubleLength { + let z = x * y; + let zz = x.mul_add(y, -z); + DoubleLength { hi: z, lo: zz } +} + +/// Triple-length number for extra precision +#[derive(Clone, Copy)] +struct TripleLength { + hi: f64, + lo: f64, + tiny: f64, +} + +const TL_ZERO: TripleLength = TripleLength { + hi: 0.0, + lo: 0.0, + tiny: 0.0, +}; + +/// Algorithm 5.10 with SumKVert for K=3 +#[inline] +fn tl_fma(x: f64, y: f64, total: TripleLength) -> TripleLength { + let pr = dl_mul(x, y); + let sm = dl_sum(total.hi, pr.hi); + let r1 = dl_sum(total.lo, pr.lo); + let r2 = dl_sum(r1.hi, sm.lo); + TripleLength { + hi: sm.hi, + lo: r2.hi, + tiny: total.tiny + r1.lo + r2.lo, + } +} + +#[inline] +fn tl_to_d(total: TripleLength) -> f64 { + let last = dl_sum(total.lo, total.hi); + total.tiny + last.lo + last.hi +} + +// FSUM - Shewchuk's algorithm + +const NUM_PARTIALS: usize = 32; + +/// Return an accurate floating-point sum of values in the iterable. +/// +/// Uses Shewchuk's algorithm for full precision summation. +/// Assumes IEEE-754 floating-point arithmetic. +/// +/// Returns ERANGE for intermediate overflow, EDOM for -inf + inf. 
+pub fn fsum(iter: impl IntoIterator) -> crate::Result { + let mut p: Vec = Vec::with_capacity(NUM_PARTIALS); + let mut special_sum = 0.0; + let mut inf_sum = 0.0; + + for x in iter { + let xsave = x; + let mut x = x; + let mut i = 0; + + for j in 0..p.len() { + let mut y = p[j]; + if x.abs() < y.abs() { + std::mem::swap(&mut x, &mut y); + } + let hi = x + y; + let yr = hi - x; + let lo = y - yr; + if lo != 0.0 { + p[i] = lo; + i += 1; + } + x = hi; + } + + p.truncate(i); + if x != 0.0 { + if !x.is_finite() { + // a nonfinite x could arise either as a result of + // intermediate overflow, or as a result of a nan or inf + // in the summands + if xsave.is_finite() { + // intermediate overflow + return Err(crate::Error::ERANGE); + } + if xsave.is_infinite() { + inf_sum += xsave; + } + special_sum += xsave; + // reset partials + p.clear(); + } else { + p.push(x); + } + } + } + + if special_sum != 0.0 { + if inf_sum.is_nan() { + // -inf + inf + return Err(crate::Error::EDOM); + } + return Ok(special_sum); + } + + let n = p.len(); + let mut hi = 0.0; + let mut lo = 0.0; + + if n > 0 { + let mut idx = n - 1; + hi = p[idx]; + + // sum_exact(ps, hi) from the top, stop when the sum becomes inexact + while idx > 0 { + idx -= 1; + let x = hi; + let y = p[idx]; + hi = x + y; + let yr = hi - x; + lo = y - yr; + if lo != 0.0 { + break; + } + } + + // Make half-even rounding work across multiple partials. + if idx > 0 && ((lo < 0.0 && p[idx - 1] < 0.0) || (lo > 0.0 && p[idx - 1] > 0.0)) { + let y = lo * 2.0; + let x = hi + y; + let yr = x - hi; + if y == yr { + hi = x; + } + } + } + + Ok(hi) +} + +// VECTOR_NORM - for dist and hypot + +/// Compute the Euclidean norm of two values with high precision. +/// Optimized version for hypot(x, y). 
+pub(super) fn vector_norm_2(x: f64, y: f64, max: f64, found_nan: bool) -> f64 { + // Check for infinity first (inf wins over nan) + if x.is_infinite() || y.is_infinite() { + return f64::INFINITY; + } + if found_nan { + return f64::NAN; + } + if max == 0.0 { + return 0.0; + } + // n == 1 case: only one non-zero value + if x == 0.0 || y == 0.0 { + return max; + } + + let mut max_e: i32 = 0; + crate::m::frexp(max, &mut max_e); + + if max_e < -1023 { + // When max_e < -1023, ldexp(1.0, -max_e) would overflow + return f64::MIN_POSITIVE + * vector_norm_2( + x / f64::MIN_POSITIVE, + y / f64::MIN_POSITIVE, + max / f64::MIN_POSITIVE, + found_nan, + ); + } + + let scale = crate::m::ldexp(1.0, -max_e); + debug_assert!(max * scale >= 0.5); + debug_assert!(max * scale < 1.0); + + let mut csum = 1.0; + let mut frac1 = 0.0; + let mut frac2 = 0.0; + + // Process x + let xs = x * scale; + debug_assert!(xs.abs() < 1.0); + let pr = dl_mul(xs, xs); + debug_assert!(pr.hi <= 1.0); + let sm = dl_fast_sum(csum, pr.hi); + csum = sm.hi; + frac1 += pr.lo; + frac2 += sm.lo; + + // Process y + let ys = y * scale; + debug_assert!(ys.abs() < 1.0); + let pr = dl_mul(ys, ys); + debug_assert!(pr.hi <= 1.0); + let sm = dl_fast_sum(csum, pr.hi); + csum = sm.hi; + frac1 += pr.lo; + frac2 += sm.lo; + + let mut h = (csum - 1.0 + (frac1 + frac2)).sqrt(); + let pr = dl_mul(-h, h); + let sm = dl_fast_sum(csum, pr.hi); + csum = sm.hi; + frac1 += pr.lo; + frac2 += sm.lo; + let x = csum - 1.0 + (frac1 + frac2); + h += x / (2.0 * h); // differential correction + + h / scale +} + +/// Compute the Euclidean norm of a vector with high precision. +fn vector_norm(vec: &[f64], max: f64, found_nan: bool) -> f64 { + let n = vec.len(); + + if max.is_infinite() { + return max; + } + if found_nan { + return f64::NAN; + } + if max == 0.0 || n <= 1 { + return max; + } + + let mut max_e: i32 = 0; + crate::m::frexp(max, &mut max_e); + + if max_e < -1023 { + // When max_e < -1023, ldexp(1.0, -max_e) would overflow. 
+ // TODO: This can be in-place ops, but we allocate a copy since we take &[f64]. + // This is acceptable because subnormal inputs are extremely rare in practice. + let vec_copy: Vec = vec.iter().map(|&x| x / f64::MIN_POSITIVE).collect(); + return f64::MIN_POSITIVE * vector_norm(&vec_copy, max / f64::MIN_POSITIVE, found_nan); + } + + let scale = crate::m::ldexp(1.0, -max_e); + debug_assert!(max * scale >= 0.5); + debug_assert!(max * scale < 1.0); + + let mut csum = 1.0; + let mut frac1 = 0.0; + let mut frac2 = 0.0; + + for &v in vec { + debug_assert!(v.is_finite() && v.abs() <= max); + let x = v * scale; // lossless scaling + debug_assert!(x.abs() < 1.0); + let pr = dl_mul(x, x); // lossless squaring + debug_assert!(pr.hi <= 1.0); + let sm = dl_fast_sum(csum, pr.hi); // lossless addition + csum = sm.hi; + frac1 += pr.lo; // lossy addition + frac2 += sm.lo; // lossy addition + } + + let mut h = (csum - 1.0 + (frac1 + frac2)).sqrt(); + let pr = dl_mul(-h, h); + let sm = dl_fast_sum(csum, pr.hi); + csum = sm.hi; + frac1 += pr.lo; + frac2 += sm.lo; + let x = csum - 1.0 + (frac1 + frac2); + h += x / (2.0 * h); // differential correction + + h / scale +} + +/// Return the Euclidean distance between two points. +/// +/// The points are given as sequences of coordinates. +/// Uses high-precision vector_norm algorithm. +pub fn dist(p: &[f64], q: &[f64]) -> f64 { + assert_eq!( + p.len(), + q.len(), + "both points must have the same number of dimensions" + ); + + let n = p.len(); + if n == 0 { + return 0.0; + } + + let mut max = 0.0; + let mut found_nan = false; + let mut diffs: Vec = Vec::with_capacity(n); + + for i in 0..n { + let x = (p[i] - q[i]).abs(); + diffs.push(x); + found_nan |= x.is_nan(); + if x > max { + max = x; + } + } + + vector_norm(&diffs, max, found_nan) +} + +/// Return the sum of products of values from two sequences. +/// +/// Uses TripleLength arithmetic for high precision. +/// Equivalent to sum(p[i] * q[i] for i in range(len(p))). 
+pub fn sumprod(p: &[f64], q: &[f64]) -> f64 { + assert_eq!(p.len(), q.len(), "Inputs are not the same length"); + + let mut flt_total = TL_ZERO; + + for (&pi, &qi) in p.iter().zip(q.iter()) { + let new_flt_total = tl_fma(pi, qi, flt_total); + if new_flt_total.hi.is_finite() { + flt_total = new_flt_total; + } else { + // Overflow or special value, fall back to simple sum + return p.iter().zip(q.iter()).map(|(a, b)| a * b).sum(); + } + } + + tl_to_d(flt_total) +} + +/// Return the product of all elements in the iterable. +/// +/// If start is None, uses 1.0 as the start value. +pub fn prod(iter: impl IntoIterator, start: Option) -> f64 { + let mut result = start.unwrap_or(1.0); + for x in iter { + result *= x; + } + result +} + +#[cfg(test)] +mod tests { + use super::*; + use pyo3::prelude::*; + + fn test_fsum_impl(values: &[f64]) { + let rs_result = fsum(values.iter().copied()); + + pyo3::Python::attach(|py| { + let math = pyo3::types::PyModule::import(py, "math").unwrap(); + let py_func = math.getattr("fsum").unwrap(); + let py_list = pyo3::types::PyList::new(py, values).unwrap(); + let r = py_func.call1((py_list,)); + + match r { + Ok(py_val) => { + let py_result: f64 = py_val.extract().unwrap(); + let rs_val = rs_result.unwrap_or_else(|e| { + panic!( + "fsum({:?}): py={} but rs returned error {:?}", + values, py_result, e + ) + }); + if py_result.is_nan() && rs_val.is_nan() { + return; + } + assert_eq!( + py_result.to_bits(), + rs_val.to_bits(), + "fsum({:?}): py={} vs rs={}", + values, + py_result, + rs_val + ); + } + Err(e) => { + let rs_err = rs_result.as_ref().err(); + if e.is_instance_of::(py) { + assert_eq!( + rs_err, + Some(&crate::Error::EDOM), + "fsum({:?}): py raised ValueError but rs={:?}", + values, + rs_err + ); + } else if e.is_instance_of::(py) { + assert_eq!( + rs_err, + Some(&crate::Error::ERANGE), + "fsum({:?}): py raised OverflowError but rs={:?}", + values, + rs_err + ); + } else { + panic!("fsum({:?}): py raised unexpected error {}", values, 
e); + } + } + } + }); + } + + #[test] + fn test_fsum() { + test_fsum_impl(&[1.0, 2.0, 3.0]); + test_fsum_impl(&[]); + test_fsum_impl(&[0.1, 0.2, 0.3]); + test_fsum_impl(&[1e100, 1.0, -1e100, 1e-100, 1e50, -1e50]); + test_fsum_impl(&[f64::INFINITY, 1.0]); + test_fsum_impl(&[f64::NEG_INFINITY, 1.0]); + test_fsum_impl(&[f64::INFINITY, f64::NEG_INFINITY]); // -inf + inf -> ValueError (EDOM) + test_fsum_impl(&[f64::NAN, 1.0]); + // Intermediate overflow cases + test_fsum_impl(&[1e308, 1e308]); // intermediate overflow -> OverflowError (ERANGE) + test_fsum_impl(&[1e308, 1e308, -1e308]); // intermediate overflow + } + + fn test_dist_impl(p: &[f64], q: &[f64]) { + let rs_result = dist(p, q); + + pyo3::Python::attach(|py| { + let math = pyo3::types::PyModule::import(py, "math").unwrap(); + let py_func = math.getattr("dist").unwrap(); + let py_p = pyo3::types::PyList::new(py, p).unwrap(); + let py_q = pyo3::types::PyList::new(py, q).unwrap(); + let py_result: f64 = py_func.call1((py_p, py_q)).unwrap().extract().unwrap(); + + if py_result.is_nan() && rs_result.is_nan() { + return; + } + assert_eq!( + py_result.to_bits(), + rs_result.to_bits(), + "dist({:?}, {:?}): py={} vs rs={}", + p, + q, + py_result, + rs_result + ); + }); + } + + #[test] + fn test_dist() { + test_dist_impl(&[0.0, 0.0], &[3.0, 4.0]); // 3-4-5 triangle + test_dist_impl(&[1.0, 2.0], &[1.0, 2.0]); // same point + test_dist_impl(&[0.0], &[5.0]); // 1D + test_dist_impl(&[0.0, 0.0, 0.0], &[1.0, 1.0, 1.0]); // 3D + } + + fn test_sumprod_impl(p: &[f64], q: &[f64]) { + let rs_result = sumprod(p, q); + + pyo3::Python::attach(|py| { + let math = pyo3::types::PyModule::import(py, "math").unwrap(); + let py_func = math.getattr("sumprod").unwrap(); + let py_p = pyo3::types::PyList::new(py, p).unwrap(); + let py_q = pyo3::types::PyList::new(py, q).unwrap(); + let py_result: f64 = py_func.call1((py_p, py_q)).unwrap().extract().unwrap(); + + if py_result.is_nan() && rs_result.is_nan() { + return; + } + assert_eq!( + 
py_result.to_bits(), + rs_result.to_bits(), + "sumprod({:?}, {:?}): py={} vs rs={}", + p, + q, + py_result, + rs_result + ); + }); + } + + #[test] + fn test_sumprod() { + test_sumprod_impl(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]); + test_sumprod_impl(&[], &[]); + test_sumprod_impl(&[1.0], &[2.0]); + test_sumprod_impl(&[1e100, 1e100], &[1e100, -1e100]); + } + + fn test_prod_impl(values: &[f64], start: Option) { + let rs_result = prod(values.iter().copied(), start); + + pyo3::Python::attach(|py| { + let math = pyo3::types::PyModule::import(py, "math").unwrap(); + let py_func = math.getattr("prod").unwrap(); + let py_list = pyo3::types::PyList::new(py, values).unwrap(); + let py_result: f64 = match start { + Some(s) => { + let kwargs = pyo3::types::PyDict::new(py); + kwargs.set_item("start", s).unwrap(); + py_func + .call((py_list,), Some(&kwargs)) + .unwrap() + .extract() + .unwrap() + } + None => py_func.call1((py_list,)).unwrap().extract().unwrap(), + }; + + if py_result.is_nan() && rs_result.is_nan() { + return; + } + assert_eq!( + py_result.to_bits(), + rs_result.to_bits(), + "prod({:?}, {:?}): py={} vs rs={}", + values, + start, + py_result, + rs_result + ); + }); + } + + #[test] + fn test_prod() { + test_prod_impl(&[1.0, 2.0, 3.0, 4.0], None); + test_prod_impl(&[2.0, 3.0], None); + test_prod_impl(&[], None); + test_prod_impl(&[1.0, 2.0, 3.0], Some(2.0)); + test_prod_impl(&[], Some(5.0)); + } + + proptest::proptest! 
{ + #[test] + fn proptest_fsum(v1: f64, v2: f64, v3: f64, v4: f64) { + test_fsum_impl(&[v1, v2, v3, v4]); + } + + #[test] + fn proptest_dist(p1: f64, p2: f64, q1: f64, q2: f64) { + test_dist_impl(&[p1, p2], &[q1, q2]); + } + + #[test] + fn proptest_sumprod(p1: f64, p2: f64, q1: f64, q2: f64) { + test_sumprod_impl(&[p1, p2], &[q1, q2]); + } + + #[test] + fn proptest_prod(v1: f64, v2: f64, v3: f64) { + test_prod_impl(&[v1, v2, v3], None); + } + } +} diff --git a/src/math/exponential.rs b/src/math/exponential.rs new file mode 100644 index 0000000..7a817a5 --- /dev/null +++ b/src/math/exponential.rs @@ -0,0 +1,406 @@ +//! Exponential, logarithmic, and power functions. + +use crate::Result; + +// Exponential functions + +/// Return e raised to the power of x. +#[inline] +pub fn exp(x: f64) -> Result { + super::math_1(x, crate::m::exp, true) +} + +/// Return 2 raised to the power of x. +#[inline] +pub fn exp2(x: f64) -> Result { + super::math_1(x, crate::m::exp2, true) +} + +/// Return exp(x) - 1. 
+#[inline] +pub fn expm1(x: f64) -> Result { + super::math_1(x, crate::m::expm1, true) +} + +// Logarithmic functions + +/// m_log: log implementation +#[inline] +fn m_log(x: f64) -> f64 { + if x.is_finite() { + if x > 0.0 { + crate::m::log(x) + } else if x == 0.0 { + f64::NEG_INFINITY // log(0) = -inf + } else { + f64::NAN // log(-ve) = nan + } + } else if x.is_nan() || x > 0.0 { + x // log(nan) = nan, log(inf) = inf + } else { + f64::NAN // log(-inf) = nan + } +} + +/// m_log10: log10 implementation +#[inline] +fn m_log10(x: f64) -> f64 { + if x.is_finite() { + if x > 0.0 { + crate::m::log10(x) + } else if x == 0.0 { + f64::NEG_INFINITY // log10(0) = -inf + } else { + f64::NAN // log10(-ve) = nan + } + } else if x.is_nan() || x > 0.0 { + x // log10(nan) = nan, log10(inf) = inf + } else { + f64::NAN // log10(-inf) = nan + } +} + +/// m_log2: log2 implementation +#[inline] +fn m_log2(x: f64) -> f64 { + if !x.is_finite() { + if x.is_nan() || x > 0.0 { + x // log2(nan) = nan, log2(+inf) = +inf + } else { + f64::NAN // log2(-inf) = nan + } + } else if x > 0.0 { + crate::m::log2(x) + } else if x == 0.0 { + f64::NEG_INFINITY // log2(0) = -inf + } else { + f64::NAN // log2(-ve) = nan + } +} + +/// m_log1p: CPython's m_log1p implementation +#[inline] +fn m_log1p(x: f64) -> f64 { + // For x > -1, log1p is well-defined + // For x == -1, result is -inf + // For x < -1, result is nan + if x.is_nan() { + return x; + } + crate::m::log1p(x) +} + +/// Return the logarithm of x to the given base. 
+#[inline] +pub fn log(x: f64, base: Option) -> Result { + let num = m_log(x); + + // math_1 logic: check for domain errors + if num.is_nan() && !x.is_nan() { + return Err(crate::Error::EDOM); + } + if num.is_infinite() && x.is_finite() { + return Err(crate::Error::EDOM); + } + + match base { + None => Ok(num), + Some(b) => { + let den = m_log(b); + if den.is_nan() && !b.is_nan() { + return Err(crate::Error::EDOM); + } + if den.is_infinite() && b.is_finite() { + return Err(crate::Error::EDOM); + } + // log(x, 1) -> division by zero + if den == 0.0 { + return Err(crate::Error::EDOM); + } + Ok(num / den) + } + } +} + +/// math_1 for Rust functions (m_log10, m_log2, m_log1p) +#[inline] +fn math_1_fn(x: f64, func: fn(f64) -> f64, can_overflow: bool) -> Result { + let r = func(x); + if r.is_nan() && !x.is_nan() { + return Err(crate::Error::EDOM); + } + if r.is_infinite() && x.is_finite() { + return Err(if can_overflow { + crate::Error::ERANGE + } else { + crate::Error::EDOM + }); + } + Ok(r) +} + +/// Return the base-10 logarithm of x. +#[inline] +pub fn log10(x: f64) -> Result { + math_1_fn(x, m_log10, false) +} + +/// Return the base-2 logarithm of x. +#[inline] +pub fn log2(x: f64) -> Result { + math_1_fn(x, m_log2, false) +} + +/// Return the natural logarithm of 1+x (base e). +#[inline] +pub fn log1p(x: f64) -> Result { + math_1_fn(x, m_log1p, false) +} + +// Power functions + +/// Return the square root of x. +#[inline] +pub fn sqrt(x: f64) -> Result { + super::math_1(x, crate::m::sqrt, false) +} + +/// Return the cube root of x. +#[inline] +pub fn cbrt(x: f64) -> Result { + super::math_1(x, crate::m::cbrt, false) +} + +/// Return x raised to the power y. 
+#[inline] +pub fn pow(x: f64, y: f64) -> Result { + // Deal directly with IEEE specials, to cope with problems on various + // platforms whose semantics don't exactly match C99 + if !x.is_finite() || !y.is_finite() { + if x.is_nan() { + // NaN**0 = 1 + return Ok(if y == 0.0 { 1.0 } else { x }); + } else if y.is_nan() { + // 1**NaN = 1 + return Ok(if x == 1.0 { 1.0 } else { y }); + } else if x.is_infinite() { + let odd_y = y.is_finite() && crate::m::fmod(y.abs(), 2.0) == 1.0; + if y > 0.0 { + return Ok(if odd_y { x } else { x.abs() }); + } else if y == 0.0 { + return Ok(1.0); + } else { + // y < 0 + return Ok(if odd_y { + crate::m::copysign(0.0, x) + } else { + 0.0 + }); + } + } else { + // y is infinite + debug_assert!(y.is_infinite()); + if x.abs() == 1.0 { + return Ok(1.0); + } else if y > 0.0 && x.abs() > 1.0 { + return Ok(y); + } else if y < 0.0 && x.abs() < 1.0 { + return Ok(-y); // result is +inf + } else { + return Ok(0.0); + } + } + } + + // Let libm handle finite**finite + let r = crate::m::pow(x, y); + + // A NaN result should arise only from (-ve)**(finite non-integer); + // in this case we want to raise ValueError. 
+ if !r.is_finite() { + if r.is_nan() { + return Err(crate::Error::EDOM); + } else if r.is_infinite() { + // An infinite result here arises either from: + // (A) (+/-0.)**negative (-> divide-by-zero) + // (B) overflow of x**y with x and y finite + if x == 0.0 { + return Err(crate::Error::EDOM); + } else { + return Err(crate::Error::ERANGE); + } + } + } + + Ok(r) +} + +// Tests + +#[cfg(test)] +mod tests { + use super::*; + + fn test_exp(x: f64) { + crate::test::test_math_1(x, "exp", exp); + } + fn test_exp2(x: f64) { + crate::test::test_math_1(x, "exp2", exp2); + } + fn test_expm1(x: f64) { + crate::test::test_math_1(x, "expm1", expm1); + } + fn test_log_without_base(x: f64) { + crate::test::test_math_1(x, "log", |x| log(x, None)); + } + fn test_log(x: f64, base: f64) { + crate::test::test_math_2(x, base, "log", |x, b| log(x, Some(b))); + } + fn test_log10(x: f64) { + crate::test::test_math_1(x, "log10", log10); + } + fn test_log2(x: f64) { + crate::test::test_math_1(x, "log2", log2); + } + fn test_log1p(x: f64) { + crate::test::test_math_1(x, "log1p", log1p); + } + fn test_sqrt(x: f64) { + crate::test::test_math_1(x, "sqrt", sqrt); + } + fn test_cbrt(x: f64) { + crate::test::test_math_1(x, "cbrt", cbrt); + } + fn test_pow(x: f64, y: f64) { + crate::test::test_math_2(x, y, "pow", pow); + } + + proptest::proptest! 
{ + #[test] + fn proptest_exp(x: f64) { + test_exp(x); + } + + #[test] + fn proptest_exp2(x: f64) { + test_exp2(x); + } + + #[test] + fn proptest_sqrt(x: f64) { + test_sqrt(x); + } + + #[test] + fn proptest_cbrt(x: f64) { + test_cbrt(x); + } + + #[test] + fn proptest_expm1(x: f64) { + test_expm1(x); + } + + #[test] + fn proptest_log_without_base(x: f64) { + test_log_without_base(x); + } + + #[test] + fn proptest_log(x: f64, base: f64) { + test_log(x, base); + } + + #[test] + fn proptest_log10(x: f64) { + test_log10(x); + } + + #[test] + fn proptest_log2(x: f64) { + test_log2(x); + } + + #[test] + fn proptest_log1p(x: f64) { + test_log1p(x); + } + + #[test] + fn proptest_pow(x: f64, y: f64) { + test_pow(x, y); + } + } + + #[test] + fn edgetest_exp() { + for &x in &crate::test::EDGE_VALUES { + test_exp(x); + } + } + + #[test] + fn edgetest_exp2() { + for &x in &crate::test::EDGE_VALUES { + test_exp2(x); + } + } + + #[test] + fn edgetest_expm1() { + for &x in &crate::test::EDGE_VALUES { + test_expm1(x); + } + } + + #[test] + fn edgetest_sqrt() { + for &x in &crate::test::EDGE_VALUES { + test_sqrt(x); + } + } + + #[test] + fn edgetest_cbrt() { + for &x in &crate::test::EDGE_VALUES { + test_cbrt(x); + } + } + + #[test] + fn edgetest_log_without_base() { + for &x in &crate::test::EDGE_VALUES { + test_log_without_base(x); + } + } + + #[test] + fn edgetest_log10() { + for &x in &crate::test::EDGE_VALUES { + test_log10(x); + } + } + + #[test] + fn edgetest_log2() { + for &x in &crate::test::EDGE_VALUES { + test_log2(x); + } + } + + #[test] + fn edgetest_log1p() { + for &x in &crate::test::EDGE_VALUES { + test_log1p(x); + } + } + + #[test] + fn edgetest_pow() { + for &x in &crate::test::EDGE_VALUES { + for &y in &crate::test::EDGE_VALUES { + test_pow(x, y); + } + } + } +} diff --git a/src/gamma.rs b/src/math/gamma.rs similarity index 91% rename from src/gamma.rs rename to src/math/gamma.rs index 534f7ff..028853f 100644 --- a/src/gamma.rs +++ b/src/math/gamma.rs @@ -1,11 
+1,28 @@ -use crate::Error; +//! Special functions: gamma, lgamma, erf, erfc. + +use crate::{Error, mul_add}; use std::f64::consts::PI; +/// Error function. +#[inline] +pub fn erf(x: f64) -> crate::Result { + super::math_1a(x, crate::m::erf) +} + +/// Complementary error function. +#[inline] +pub fn erfc(x: f64) -> crate::Result { + super::math_1a(x, crate::m::erfc) +} + const LOG_PI: f64 = 1.144729885849400174143427351353058711647; const LANCZOS_N: usize = 13; +#[allow(clippy::excessive_precision)] const LANCZOS_G: f64 = 6.024680040776729583740234375; +#[allow(clippy::excessive_precision)] const LANCZOS_G_MINUS_HALF: f64 = 5.524680040776729583740234375; +#[allow(clippy::excessive_precision)] const LANCZOS_NUM_COEFFS: [f64; LANCZOS_N] = [ 23531376880.410759688572007674451636754734846804940, 42919803642.649098768957899047001988850926355848959, @@ -37,14 +54,6 @@ const LANCZOS_DEN_COEFFS: [f64; LANCZOS_N] = [ 1.0, ]; -fn mul_add(a: f64, b: f64, c: f64) -> f64 { - if cfg!(feature = "mul_add") { - a.mul_add(b, c) - } else { - a * b + c - } -} - fn lanczos_sum(x: f64) -> f64 { let mut num = 0.0; let mut den = 0.0; @@ -206,9 +215,9 @@ pub fn gamma(x: f64) -> crate::Result { r }; if r.is_infinite() { - return Err((f64::INFINITY, Error::ERANGE).1); + Err((f64::INFINITY, Error::ERANGE).1) } else { - return Ok(r); + Ok(r) } } @@ -259,10 +268,12 @@ pub fn lgamma(x: f64) -> crate::Result { Ok(r) } -super::pyo3_proptest!(gamma(Result<_>), test_gamma, proptest_gamma, fulltest_gamma); -super::pyo3_proptest!( +crate::pyo3_proptest!(gamma(Result<_>), test_gamma, proptest_gamma, fulltest_gamma); +crate::pyo3_proptest!( lgamma(Result<_>), test_lgamma, proptest_lgamma, fulltest_lgamma ); +crate::pyo3_proptest!(erf(Result<_>), test_erf, proptest_erf, fulltest_erf); +crate::pyo3_proptest!(erfc(Result<_>), test_erfc, proptest_erfc, fulltest_erfc); diff --git a/src/math/integer.rs b/src/math/integer.rs new file mode 100644 index 0000000..ea6148c --- /dev/null +++ b/src/math/integer.rs @@ 
-0,0 +1,770 @@ +//! math.integer +//! +//! Integer-related mathematical functions. +//! This module requires either `num-bigint` or `malachite-bigint` feature. + +#[cfg(feature = "malachite-bigint")] +use malachite_bigint::{BigInt, BigUint}; +#[cfg(feature = "num-bigint")] +use num_bigint::{BigInt, BigUint}; + +use num_integer::Integer; +use num_traits::{One, Signed, ToPrimitive, Zero}; + +/// Return the greatest common divisor of a and b. +/// +/// Uses the optimized GCD implementation from num-integer, +#[inline] +pub fn gcd(a: &BigInt, b: &BigInt) -> BigInt { + a.gcd(b) +} + +/// Return the least common multiple of a and b. +#[inline] +pub fn lcm(a: &BigInt, b: &BigInt) -> BigInt { + if a.is_zero() || b.is_zero() { + return BigInt::zero(); + } + let g = gcd(a, b); + let f = a / &g; + (&f * b).abs() +} + +/// Approximate square roots for 16-bit integers. +/// For any n in range 2**14 <= n < 2**16, the value +/// a = APPROXIMATE_ISQRT_TAB[(n >> 8) - 64] +/// is an approximate square root of n, satisfying (a - 1)**2 < n < (a + 1)**2. 
/// Approximate square roots for 16-bit integers.
///
/// For any `n` in the range `2**14 <= n < 2**16`, the value
/// `a = APPROXIMATE_ISQRT_TAB[(n >> 8) - 64]` is an approximate square root
/// of `n`, satisfying `(a - 1)**2 < n < (a + 1)**2`.
///
/// The table was computed in Python using:
/// `[min(round(sqrt(256*n + 128)), 255) for n in range(64, 256)]`
const APPROXIMATE_ISQRT_TAB: [u8; 192] = [
    128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 144, 145,
    146, 147, 148, 149, 150, 151, 151, 152, 153, 154, 155, 156, 156, 157, 158, 159, 160, 160, 161,
    162, 163, 164, 164, 165, 166, 167, 167, 168, 169, 170, 170, 171, 172, 173, 173, 174, 175, 176,
    176, 177, 178, 179, 179, 180, 181, 181, 182, 183, 183, 184, 185, 186, 186, 187, 188, 188, 189,
    190, 190, 191, 192, 192, 193, 194, 194, 195, 196, 196, 197, 198, 198, 199, 200, 200, 201, 201,
    202, 203, 203, 204, 205, 205, 206, 206, 207, 208, 208, 209, 210, 210, 211, 211, 212, 213, 213,
    214, 214, 215, 216, 216, 217, 217, 218, 219, 219, 220, 220, 221, 221, 222, 223, 223, 224, 224,
    225, 225, 226, 227, 227, 228, 228, 229, 229, 230, 230, 231, 232, 232, 233, 233, 234, 234, 235,
    235, 236, 237, 237, 238, 238, 239, 239, 240, 240, 241, 241, 242, 242, 243, 243, 244, 244, 245,
    246, 246, 247, 247, 248, 248, 249, 249, 250, 250, 251, 251, 252, 252, 253, 253, 254, 254, 255,
    255, 255,
];

/// Approximate square root of a large 64-bit integer.
///
/// Given `n` satisfying `2**62 <= n < 2**64`, return `a` satisfying
/// `(a - 1)**2 < n < (a + 1)**2`.
///
/// The estimate starts from an 8-bit table lookup on the top byte of `n` and
/// is widened twice, each time by dividing a longer prefix of `n` by the
/// current estimate.
#[inline]
fn approximate_isqrt(n: u64) -> u32 {
    // The top byte of n lies in 64..=255 for in-range input.
    let slot = ((n >> 56) - 64) as usize;
    let seed = u32::from(APPROXIMATE_ISQRT_TAB[slot]);

    // Widen the 8-bit seed to 16 bits.
    let mid = (seed << 7) + (n >> 41) as u32 / seed;

    // Widen the 16-bit estimate to 32 bits.
    let low = ((n >> 17) / u64::from(mid)) as u32;
    (mid << 15) + low
}
+pub fn isqrt(n: &BigUint) -> BigUint { + if n.is_zero() { + return BigUint::zero(); + } + + // c = (n.bit_length() - 1) // 2 + let c = (n.bits() - 1) / 2; + + // Fast path: if c <= 31 then n < 2**64 and we can compute directly + if c <= 31 { + let shift = 31 - c as u32; + let m = n.to_u64().unwrap(); + let mut u = approximate_isqrt(m << (2 * shift)) >> shift; + if (u as u64) * (u as u64) > m { + u -= 1; + } + return BigUint::from(u); + } + + // Slow path: n >= 2**64 + // We perform the first five iterations in u64 arithmetic, + // then switch to using BigUint. + + // From n >= 2**64 it follows that c.bit_length() >= 6 + let mut c_bit_length = 6u64; + while (c >> c_bit_length) > 0 { + c_bit_length += 1; + } + + // Initialize d and a + let d = c >> (c_bit_length - 5); + let m = (n >> (2 * c - 62)).to_u64().unwrap(); + let u = approximate_isqrt(m) >> (31 - d as u32); + let mut a = BigUint::from(u); + + let mut prev_d = d; + for s in (0..=(c_bit_length - 6)).rev() { + let e = prev_d; + let d = c >> s; + + // q = (n >> 2*c - e - d + 1) // a + let shift = 2 * c - d - e + 1; + let q = (n >> shift) / &a; + + // a = (a << d - 1 - e) + q + a = (a << (d - 1 - e) as usize) + q; + + prev_d = d; + } + + // The correct result is either a or a - 1 + if &a * &a > *n { + a -= 1u32; + } + + a +} + +// FACTORIAL + +/// Lookup table for small factorial values +const SMALL_FACTORIALS: [u64; 21] = [ + 1, + 1, + 2, + 6, + 24, + 120, + 720, + 5040, + 40320, + 362880, + 3628800, + 39916800, + 479001600, + 6227020800, + 87178291200, + 1307674368000, + 20922789888000, + 355687428096000, + 6402373705728000, + 121645100408832000, + 2432902008176640000, +]; + +/// Count the number of set bits in n +#[inline] +fn count_set_bits(n: u64) -> u64 { + n.count_ones() as u64 +} + +/// Compute product(range(start, stop, 2)) using divide and conquer. +/// Assumes start and stop are odd and stop > start. 
+fn factorial_partial_product(start: u64, stop: u64, max_bits: u32) -> BigUint { + let num_operands = (stop - start) / 2; + + // If the result fits in a u64, multiply directly + if num_operands <= 64 && num_operands * (max_bits as u64) <= 64 { + let mut total = start; + let mut j = start + 2; + while j < stop { + total *= j; + j += 2; + } + return BigUint::from(total); + } + + // Find midpoint of range(start, stop), rounded up to next odd number + let midpoint = (start + num_operands) | 1; + let left = factorial_partial_product( + start, + midpoint, + (64 - midpoint.leading_zeros()).saturating_sub(1), + ); + let right = factorial_partial_product(midpoint, stop, max_bits); + left * right +} + +/// Compute the odd part of factorial(n). +fn factorial_odd_part(n: u64) -> BigUint { + let mut inner = BigUint::one(); + let mut outer = BigUint::one(); + + let mut upper = 3u64; + let n_bit_length = 64 - n.leading_zeros(); + + for i in (0..=(n_bit_length.saturating_sub(2))).rev() { + let v = n >> i; + if v <= 2 { + continue; + } + let lower = upper; + // (v + 1) | 1 = least odd integer strictly larger than n / 2**i + upper = (v + 1) | 1; + let partial = factorial_partial_product( + lower, + upper, + (64 - (upper - 2).leading_zeros()).saturating_sub(1), + ); + inner *= partial; + outer *= &inner; + } + + outer +} + +/// Return n factorial (n!). +/// +/// Uses the divide-and-conquer algorithm. +/// Based on: http://www.luschny.de/math/factorial/binarysplitfact.html +pub fn factorial(n: u64) -> BigUint { + // Use lookup table for small values + if n < SMALL_FACTORIALS.len() as u64 { + return BigUint::from(SMALL_FACTORIALS[n as usize]); + } + + // Express as odd_part * 2**two_valuation + let odd_part = factorial_odd_part(n); + let two_valuation = n - count_set_bits(n); + odd_part << two_valuation as usize +} + +// COMB / PERM + +/// Least significant 64 bits of the odd part of factorial(n), for n in range(128). 
+const REDUCED_FACTORIAL_ODD_PART: [u64; 128] = [ + 0x0000000000000001, + 0x0000000000000001, + 0x0000000000000001, + 0x0000000000000003, + 0x0000000000000003, + 0x000000000000000f, + 0x000000000000002d, + 0x000000000000013b, + 0x000000000000013b, + 0x0000000000000b13, + 0x000000000000375f, + 0x0000000000026115, + 0x000000000007233f, + 0x00000000005cca33, + 0x0000000002898765, + 0x00000000260eeeeb, + 0x00000000260eeeeb, + 0x0000000286fddd9b, + 0x00000016beecca73, + 0x000001b02b930689, + 0x00000870d9df20ad, + 0x0000b141df4dae31, + 0x00079dd498567c1b, + 0x00af2e19afc5266d, + 0x020d8a4d0f4f7347, + 0x335281867ec241ef, + 0x9b3093d46fdd5923, + 0x5e1f9767cc5866b1, + 0x92dd23d6966aced7, + 0xa30d0f4f0a196e5b, + 0x8dc3e5a1977d7755, + 0x2ab8ce915831734b, + 0x2ab8ce915831734b, + 0x81d2a0bc5e5fdcab, + 0x9efcac82445da75b, + 0xbc8b95cf58cde171, + 0xa0e8444a1f3cecf9, + 0x4191deb683ce3ffd, + 0xddd3878bc84ebfc7, + 0xcb39a64b83ff3751, + 0xf8203f7993fc1495, + 0xbd2a2a78b35f4bdd, + 0x84757be6b6d13921, + 0x3fbbcfc0b524988b, + 0xbd11ed47c8928df9, + 0x3c26b59e41c2f4c5, + 0x677a5137e883fdb3, + 0xff74e943b03b93dd, + 0xfe5ebbcb10b2bb97, + 0xb021f1de3235e7e7, + 0x33509eb2e743a58f, + 0x390f9da41279fb7d, + 0xe5cb0154f031c559, + 0x93074695ba4ddb6d, + 0x81c471caa636247f, + 0xe1347289b5a1d749, + 0x286f21c3f76ce2ff, + 0x00be84a2173e8ac7, + 0x1595065ca215b88b, + 0xf95877595b018809, + 0x9c2efe3c5516f887, + 0x373294604679382b, + 0xaf1ff7a888adcd35, + 0x18ddf279a2c5800b, + 0x18ddf279a2c5800b, + 0x505a90e2542582cb, + 0x5bacad2cd8d5dc2b, + 0xfe3152bcbff89f41, + 0xe1467e88bf829351, + 0xb8001adb9e31b4d5, + 0x2803ac06a0cbb91f, + 0x1904b5d698805799, + 0xe12a648b5c831461, + 0x3516abbd6160cfa9, + 0xac46d25f12fe036d, + 0x78bfa1da906b00ef, + 0xf6390338b7f111bd, + 0x0f25f80f538255d9, + 0x4ec8ca55b8db140f, + 0x4ff670740b9b30a1, + 0x8fd032443a07f325, + 0x80dfe7965c83eeb5, + 0xa3dc1714d1213afd, + 0x205b7bbfcdc62007, + 0xa78126bbe140a093, + 0x9de1dc61ca7550cf, + 0x84f0046d01b492c5, + 0x2d91810b945de0f3, + 
0xf5408b7f6008aa71, + 0x43707f4863034149, + 0xdac65fb9679279d5, + 0xc48406e7d1114eb7, + 0xa7dc9ed3c88e1271, + 0xfb25b2efdb9cb30d, + 0x1bebda0951c4df63, + 0x5c85e975580ee5bd, + 0x1591bc60082cb137, + 0x2c38606318ef25d7, + 0x76ca72f7c5c63e27, + 0xf04a75d17baa0915, + 0x77458175139ae30d, + 0x0e6c1330bc1b9421, + 0xdf87d2b5797e8293, + 0xefa5c703e1e68925, + 0x2b6b1b3278b4f6e1, + 0xceee27b382394249, + 0xd74e3829f5dab91d, + 0xfdb17989c26b5f1f, + 0xc1b7d18781530845, + 0x7b4436b2105a8561, + 0x7ba7c0418372a7d7, + 0x9dbc5c67feb6c639, + 0x502686d7f6ff6b8f, + 0x6101855406be7a1f, + 0x9956afb5806930e7, + 0xe1f0ee88af40f7c5, + 0x984b057bda5c1151, + 0x9a49819acc13ea05, + 0x8ef0dead0896ef27, + 0x71f7826efe292b21, + 0xad80a480e46986ef, + 0x01cdc0ebf5e0c6f7, + 0x6e06f839968f68db, + 0xdd5943ab56e76139, + 0xcdcf31bf8604c5e7, + 0x7e2b4a847054a1cb, + 0x0ca75697a4d3d0f5, + 0x4703f53ac514a98b, +]; + +/// Inverses of reduced_factorial_odd_part values modulo 2**64. +const INVERTED_FACTORIAL_ODD_PART: [u64; 128] = [ + 0x0000000000000001, + 0x0000000000000001, + 0x0000000000000001, + 0xaaaaaaaaaaaaaaab, + 0xaaaaaaaaaaaaaaab, + 0xeeeeeeeeeeeeeeef, + 0x4fa4fa4fa4fa4fa5, + 0x2ff2ff2ff2ff2ff3, + 0x2ff2ff2ff2ff2ff3, + 0x938cc70553e3771b, + 0xb71c27cddd93e49f, + 0xb38e3229fcdee63d, + 0xe684bb63544a4cbf, + 0xc2f684917ca340fb, + 0xf747c9cba417526d, + 0xbb26eb51d7bd49c3, + 0xbb26eb51d7bd49c3, + 0xb0a7efb985294093, + 0xbe4b8c69f259eabb, + 0x6854d17ed6dc4fb9, + 0xe1aa904c915f4325, + 0x3b8206df131cead1, + 0x79c6009fea76fe13, + 0xd8c5d381633cd365, + 0x4841f12b21144677, + 0x4a91ff68200b0d0f, + 0x8f9513a58c4f9e8b, + 0x2b3e690621a42251, + 0x4f520f00e03c04e7, + 0x2edf84ee600211d3, + 0xadcaa2764aaacffd, + 0x161f4f9033f4fe63, + 0x161f4f9033f4fe63, + 0xbada2932ea4d3e03, + 0xcec189f3efaa30d3, + 0xf7475bb68330bf91, + 0x37eb7bf7d5b01549, + 0x46b35660a4e91555, + 0xa567c12d81f151f7, + 0x4c724007bb2071b1, + 0x0f4a0cce58a016bd, + 0xfa21068e66106475, + 0x244ab72b5a318ae1, + 0x366ce67e080d0f23, + 0xd666fdae5dd2a449, + 
0xd740ddd0acc06a0d, + 0xb050bbbb28e6f97b, + 0x70b003fe890a5c75, + 0xd03aabff83037427, + 0x13ec4ca72c783bd7, + 0x90282c06afdbd96f, + 0x4414ddb9db4a95d5, + 0xa2c68735ae6832e9, + 0xbf72d71455676665, + 0xa8469fab6b759b7f, + 0xc1e55b56e606caf9, + 0x40455630fc4a1cff, + 0x0120a7b0046d16f7, + 0xa7c3553b08faef23, + 0x9f0bfd1b08d48639, + 0xa433ffce9a304d37, + 0xa22ad1d53915c683, + 0xcb6cbc723ba5dd1d, + 0x547fb1b8ab9d0ba3, + 0x547fb1b8ab9d0ba3, + 0x8f15a826498852e3, + 0x32e1a03f38880283, + 0x3de4cce63283f0c1, + 0x5dfe6667e4da95b1, + 0xfda6eeeef479e47d, + 0xf14de991cc7882df, + 0xe68db79247630ca9, + 0xa7d6db8207ee8fa1, + 0x255e1f0fcf034499, + 0xc9a8990e43dd7e65, + 0x3279b6f289702e0f, + 0xe7b5905d9b71b195, + 0x03025ba41ff0da69, + 0xb7df3d6d3be55aef, + 0xf89b212ebff2b361, + 0xfe856d095996f0ad, + 0xd6e533e9fdf20f9d, + 0xf8c0e84a63da3255, + 0xa677876cd91b4db7, + 0x07ed4f97780d7d9b, + 0x90a8705f258db62f, + 0xa41bbb2be31b1c0d, + 0x6ec28690b038383b, + 0xdb860c3bb2edd691, + 0x0838286838a980f9, + 0x558417a74b36f77d, + 0x71779afc3646ef07, + 0x743cda377ccb6e91, + 0x7fdf9f3fe89153c5, + 0xdc97d25df49b9a4b, + 0x76321a778eb37d95, + 0x7cbb5e27da3bd487, + 0x9cff4ade1a009de7, + 0x70eb166d05c15197, + 0xdcf0460b71d5fe3d, + 0x5ac1ee5260b6a3c5, + 0xc922dedfdd78efe1, + 0xe5d381dc3b8eeb9b, + 0xd57e5347bafc6aad, + 0x86939040983acd21, + 0x395b9d69740a4ff9, + 0x1467299c8e43d135, + 0x5fe440fcad975cdf, + 0xcaa9a39794a6ca8d, + 0xf61dbd640868dea1, + 0xac09d98d74843be7, + 0x2b103b9e1a6b4809, + 0x2ab92d16960f536f, + 0x6653323d5e3681df, + 0xefd48c1c0624e2d7, + 0xa496fefe04816f0d, + 0x1754a7b07bbdd7b1, + 0x23353c829a3852cd, + 0xbf831261abd59097, + 0x57a8e656df0618e1, + 0x16e9206c3100680f, + 0xadad4c6ee921dac7, + 0x635f2b3860265353, + 0xdd6d0059f44b3d09, + 0xac4dd6b894447dd7, + 0x42ea183eeaa87be3, + 0x15612d1550ee5b5d, + 0x226fa19d656cb623, +]; + +/// Exponent of the largest power of 2 dividing factorial(n), for n in range(128). 
+const FACTORIAL_TRAILING_ZEROS: [u8; 128] = [ + 0, 0, 1, 1, 3, 3, 4, 4, 7, 7, 8, 8, 10, 10, 11, 11, // 0-15 + 15, 15, 16, 16, 18, 18, 19, 19, 22, 22, 23, 23, 25, 25, 26, 26, // 16-31 + 31, 31, 32, 32, 34, 34, 35, 35, 38, 38, 39, 39, 41, 41, 42, 42, // 32-47 + 46, 46, 47, 47, 49, 49, 50, 50, 53, 53, 54, 54, 56, 56, 57, 57, // 48-63 + 63, 63, 64, 64, 66, 66, 67, 67, 70, 70, 71, 71, 73, 73, 74, 74, // 64-79 + 78, 78, 79, 79, 81, 81, 82, 82, 85, 85, 86, 86, 88, 88, 89, 89, // 80-95 + 94, 94, 95, 95, 97, 97, 98, 98, 101, 101, 102, 102, 104, 104, 105, 105, // 96-111 + 109, 109, 110, 110, 112, 112, 113, 113, 116, 116, 117, 117, 119, 119, 120, 120, // 112-127 +]; + +/// Maximal n so that 2*k-1 <= n <= 127 and C(n, k) fits into a u64. +const FAST_COMB_LIMITS1: [u8; 35] = [ + 0, 0, 127, 127, 127, 127, 127, 127, // 0-7 + 127, 127, 127, 127, 127, 127, 127, 127, // 8-15 + 116, 105, 97, 91, 86, 82, 78, 76, // 16-23 + 74, 72, 71, 70, 69, 68, 68, 67, // 24-31 + 67, 67, 67, // 32-34 +]; + +/// Maximal n so that 2*k-1 <= n <= 127 and C(n, k)*k fits into a u64. +const FAST_COMB_LIMITS2: [u64; 14] = [ + 0, + u64::MAX, + 4294967296, + 3329022, + 102570, + 13467, + 3612, + 1449, // 0-7 + 746, + 453, + 308, + 227, + 178, + 147, // 8-13 +]; + +/// Maximal n so that k <= n and P(n, k) fits into a u64. +const FAST_PERM_LIMITS: [u64; 21] = [ + 0, + u64::MAX, + 4294967296, + 2642246, + 65537, + 7133, + 1627, + 568, // 0-7 + 259, + 142, + 88, + 61, + 45, + 36, + 30, + 26, // 8-15 + 24, + 22, + 21, + 20, + 20, // 16-20 +]; + +/// Calculate C(n, k) or P(n, k) for n in the 63-bit range. 
+fn perm_comb_small(n: u64, k: u64, is_comb: bool) -> BigUint { + if k == 0 { + return BigUint::one(); + } + + if is_comb { + // Fast path 1: use lookup tables for small n + if (k as usize) < FAST_COMB_LIMITS1.len() && n <= FAST_COMB_LIMITS1[k as usize] as u64 { + let comb_odd_part = REDUCED_FACTORIAL_ODD_PART[n as usize] + .wrapping_mul(INVERTED_FACTORIAL_ODD_PART[k as usize]) + .wrapping_mul(INVERTED_FACTORIAL_ODD_PART[(n - k) as usize]); + let shift = FACTORIAL_TRAILING_ZEROS[n as usize] as i32 + - FACTORIAL_TRAILING_ZEROS[k as usize] as i32 + - FACTORIAL_TRAILING_ZEROS[(n - k) as usize] as i32; + return BigUint::from(comb_odd_part << shift); + } + + // Fast path 2: sequential multiplication for medium values + if (k as usize) < FAST_COMB_LIMITS2.len() && n <= FAST_COMB_LIMITS2[k as usize] { + let mut result = n; + let mut n = n; + let mut i = 1u64; + while i < k { + n -= 1; + result *= n; + i += 1; + result /= i; + } + return BigUint::from(result); + } + } else { + // Permutation fast paths + if (k as usize) < FAST_PERM_LIMITS.len() && n <= FAST_PERM_LIMITS[k as usize] { + if n <= 127 { + let perm_odd_part = REDUCED_FACTORIAL_ODD_PART[n as usize] + .wrapping_mul(INVERTED_FACTORIAL_ODD_PART[(n - k) as usize]); + let shift = FACTORIAL_TRAILING_ZEROS[n as usize] as i32 + - FACTORIAL_TRAILING_ZEROS[(n - k) as usize] as i32; + return BigUint::from(perm_odd_part << shift); + } + + let mut result = n; + let mut n = n; + let mut i = 1u64; + while i < k { + n -= 1; + result *= n; + i += 1; + } + return BigUint::from(result); + } + } + + // For larger n use recursive formulas: + // P(n, k) = P(n, j) * P(n-j, k-j) + // C(n, k) = C(n, j) * C(n-j, k-j) / C(k, j) + let j = k / 2; + let a = perm_comb_small(n, j, is_comb); + let b = perm_comb_small(n - j, k - j, is_comb); + let mut result = a * b; + if is_comb { + let c = perm_comb_small(k, j, true); + result /= c; + } + result +} + +/// Calculate P(n, k) or C(n, k) using recursive formulas for big n. 
+/// Reserved for future BigUint n support. +#[allow(dead_code)] +fn perm_comb(n: &BigUint, k: u64, is_comb: bool) -> BigUint { + if k == 0 { + return BigUint::one(); + } + if k == 1 { + return n.clone(); + } + + // P(n, k) = P(n, j) * P(n-j, k-j) + // C(n, k) = C(n, j) * C(n-j, k-j) / C(k, j) + let j = k / 2; + let a = perm_comb(n, j, is_comb); + let n_minus_j = n - BigUint::from(j); + let b = perm_comb(&n_minus_j, k - j, is_comb); + let mut result = a * b; + if is_comb { + let c = perm_comb_small(k, j, true); + result /= c; + } + result +} + +/// Return the number of ways to choose k items from n items (n choose k). +/// +/// Evaluates to n! / (k! * (n - k)!) when k <= n and evaluates +/// to zero when k > n. +pub fn comb(n: u64, k: u64) -> BigUint { + if k > n { + return BigUint::zero(); + } + + // Use smaller k for efficiency + let k = k.min(n - k); + + if k <= 1 { + if k == 0 { + return BigUint::one(); + } + return BigUint::from(n); + } + + perm_comb_small(n, k, true) +} + +/// Return the number of ways to arrange k items from n items. +/// +/// Evaluates to n! / (n - k)! when k <= n and evaluates +/// to zero when k > n. +/// +/// If k is not specified (None), then k defaults to n +/// and the function returns n!. 
+pub fn perm(n: u64, k: Option) -> BigUint { + let k = k.unwrap_or(n); + if k > n { + return BigUint::zero(); + } + + if k == 0 { + return BigUint::one(); + } + if k == 1 { + return BigUint::from(n); + } + + perm_comb_small(n, k, false) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_gcd() { + assert_eq!(gcd(&12.into(), &8.into()), 4.into()); + assert_eq!(gcd(&8.into(), &12.into()), 4.into()); + assert_eq!(gcd(&0.into(), &5.into()), 5.into()); + assert_eq!(gcd(&5.into(), &0.into()), 5.into()); + assert_eq!(gcd(&0.into(), &0.into()), 0.into()); + assert_eq!(gcd(&(-12).into(), &8.into()), 4.into()); + assert_eq!(gcd(&12.into(), &(-8).into()), 4.into()); + assert_eq!(gcd(&(-12).into(), &(-8).into()), 4.into()); + assert_eq!(gcd(&17.into(), &13.into()), 1.into()); + } + + #[test] + fn test_lcm() { + assert_eq!(lcm(&12.into(), &8.into()), 24.into()); + assert_eq!(lcm(&8.into(), &12.into()), 24.into()); + assert_eq!(lcm(&0.into(), &5.into()), 0.into()); + assert_eq!(lcm(&5.into(), &0.into()), 0.into()); + assert_eq!(lcm(&0.into(), &0.into()), 0.into()); + assert_eq!(lcm(&(-12).into(), &8.into()), 24.into()); + assert_eq!(lcm(&12.into(), &(-8).into()), 24.into()); + assert_eq!(lcm(&17.into(), &13.into()), 221.into()); + } + + #[test] + fn test_isqrt() { + assert_eq!(isqrt(&BigUint::from(0u32)), BigUint::from(0u32)); + assert_eq!(isqrt(&BigUint::from(1u32)), BigUint::from(1u32)); + assert_eq!(isqrt(&BigUint::from(4u32)), BigUint::from(2u32)); + assert_eq!(isqrt(&BigUint::from(9u32)), BigUint::from(3u32)); + assert_eq!(isqrt(&BigUint::from(10u32)), BigUint::from(3u32)); + assert_eq!(isqrt(&BigUint::from(15u32)), BigUint::from(3u32)); + assert_eq!(isqrt(&BigUint::from(16u32)), BigUint::from(4u32)); + assert_eq!(isqrt(&BigUint::from(100u32)), BigUint::from(10u32)); + assert_eq!(isqrt(&BigUint::from(1000000u32)), BigUint::from(1000u32)); + // Test large number + let large = BigUint::from(10u64).pow(40); + assert_eq!(isqrt(&large), 
BigUint::from(10u64).pow(20)); + } + + #[test] + fn test_factorial() { + assert_eq!(factorial(0), BigUint::from(1u32)); + assert_eq!(factorial(1), BigUint::from(1u32)); + assert_eq!(factorial(5), BigUint::from(120u32)); + assert_eq!(factorial(10), BigUint::from(3628800u32)); + assert_eq!(factorial(20), BigUint::from(2432902008176640000u64)); + } + + #[test] + fn test_comb() { + assert_eq!(comb(5, 0), BigUint::from(1u32)); + assert_eq!(comb(5, 5), BigUint::from(1u32)); + assert_eq!(comb(5, 2), BigUint::from(10u32)); + assert_eq!(comb(10, 3), BigUint::from(120u32)); + assert_eq!(comb(3, 5), BigUint::from(0u32)); // k > n + assert_eq!(comb(100, 30), comb(100, 70)); // symmetry: C(n, k) == C(n, n-k) + } + + #[test] + fn test_perm() { + assert_eq!(perm(5, Some(0)), BigUint::from(1u32)); + assert_eq!(perm(5, Some(5)), BigUint::from(120u32)); + assert_eq!(perm(5, Some(2)), BigUint::from(20u32)); + assert_eq!(perm(5, None), BigUint::from(120u32)); // 5! + assert_eq!(perm(3, Some(5)), BigUint::from(0u32)); // k > n + } +} diff --git a/src/math/misc.rs b/src/math/misc.rs new file mode 100644 index 0000000..9809058 --- /dev/null +++ b/src/math/misc.rs @@ -0,0 +1,450 @@ +//! Floating-point manipulation and validation functions. + +use crate::{Error, Result, m}; + +super::libm_simple!(@1 ceil, floor, trunc); +super::libm_simple!(@2 nextafter); + +/// Return the absolute value of x. +#[inline] +pub fn fabs(x: f64) -> Result { + super::math_1(x, crate::m::fabs, false) +} + +/// Return a float with the magnitude of x but the sign of y. +#[inline] +pub fn copysign(x: f64, y: f64) -> crate::Result { + super::math_2(x, y, crate::m::copysign) +} + +// Validation functions + +/// Return True if x is neither an infinity nor a NaN, False otherwise. +#[inline] +pub fn isfinite(x: f64) -> bool { + x.is_finite() +} + +/// Return True if x is a positive or negative infinity, False otherwise. 
+#[inline] +pub fn isinf(x: f64) -> bool { + x.is_infinite() +} + +/// Return True if x is a NaN, False otherwise. +#[inline] +pub fn isnan(x: f64) -> bool { + x.is_nan() +} + +/// Return True if a and b are close to each other. +/// +/// Whether or not two values are considered close is determined according to +/// given absolute and relative tolerances: +/// `abs(a-b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol)` +/// +/// Returns Err(EDOM) if rel_tol or abs_tol is negative. +#[inline] +pub fn isclose(a: f64, b: f64, rel_tol: f64, abs_tol: f64) -> Result { + // Tolerances must be non-negative + if rel_tol < 0.0 || abs_tol < 0.0 { + return Err(Error::EDOM); + } + if a == b { + return Ok(true); + } + if a.is_nan() || b.is_nan() { + return Ok(false); + } + if a.is_infinite() || b.is_infinite() { + return Ok(false); + } + let diff = (a - b).abs(); + Ok(diff <= abs_tol.max(rel_tol * a.abs().max(b.abs()))) +} + +/// Return the mantissa and exponent of x as (m, e). +/// +/// m is a float and e is an integer such that x == m * 2**e exactly. +#[inline] +pub fn frexp(x: f64) -> (f64, i32) { + // Handle special cases directly, to sidestep platform differences + if x.is_nan() || x.is_infinite() || x == 0.0 { + return (x, 0); + } + let mut exp: i32 = 0; + let mantissa = m::frexp(x, &mut exp); + (mantissa, exp) +} + +/// Return x * (2**i). +/// +/// Returns ERANGE if the result overflows. +#[inline] +pub fn ldexp(x: f64, i: i32) -> Result { + // NaNs, zeros and infinities are returned unchanged + if x == 0.0 || !x.is_finite() { + return Ok(x); + } + let r = m::ldexp(x, i); + if r.is_infinite() { + return Err(Error::ERANGE); + } + Ok(r) +} + +/// Return the fractional and integer parts of x. +/// +/// Returns (fractional_part, integer_part). +#[inline] +pub fn modf(x: f64) -> (f64, f64) { + // Some platforms don't do the right thing for NaNs and infinities, + // so we take care of special cases directly. 
+ if x.is_infinite() { + return (m::copysign(0.0, x), x); + } + if x.is_nan() { + return (x, x); + } + let mut int_part: f64 = 0.0; + let frac_part = m::modf(x, &mut int_part); + (frac_part, int_part) +} + +/// Return the remainder of x / y. +/// +/// Returns EDOM if y is zero or x is infinite. +#[inline] +pub fn fmod(x: f64, y: f64) -> Result { + // fmod(x, +/-Inf) returns x for finite x. + if y.is_infinite() && x.is_finite() { + return Ok(x); + } + let r = m::fmod(x, y); + if r.is_nan() && !x.is_nan() && !y.is_nan() { + return Err(Error::EDOM); + } + Ok(r) +} + +/// Return the IEEE 754-style remainder of x with respect to y. +#[inline] +pub fn remainder(x: f64, y: f64) -> Result { + super::math_2(x, y, crate::m::remainder) +} + +/// Return the value of the least significant bit of x. +#[inline] +pub fn ulp(x: f64) -> f64 { + if x.is_nan() { + return x; + } + let x = x.abs(); + if x.is_infinite() { + return x; + } + let x2 = super::nextafter(x, f64::INFINITY); + if x2.is_infinite() { + // Special case: x is the largest positive representable float + let x2 = super::nextafter(x, f64::NEG_INFINITY); + return x - x2; + } + x2 - x +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_ldexp(x: f64, i: i32) { + use pyo3::prelude::*; + + let rs_result = ldexp(x, i); + + pyo3::Python::attach(|py| { + let math = pyo3::types::PyModule::import(py, "math").unwrap(); + let py_func = math.getattr("ldexp").unwrap(); + let r = py_func.call1((x, i)); + + match (r, &rs_result) { + (Ok(py_val), Ok(rs_val)) => { + let py_f: f64 = py_val.extract().unwrap(); + assert_eq!( + py_f.to_bits(), + rs_val.to_bits(), + "ldexp({x}, {i}): py={py_f} vs rs={rs_val}" + ); + } + (Err(_), Err(_)) => {} + (Ok(py_val), Err(e)) => { + let py_f: f64 = py_val.extract().unwrap(); + panic!("ldexp({x}, {i}): py={py_f} but rs returned error {e:?}"); + } + (Err(e), Ok(rs_val)) => { + panic!("ldexp({x}, {i}): py raised {e} but rs={rs_val}"); + } + } + }); + } + + fn test_frexp(x: f64) { + use 
pyo3::prelude::*; + + let rs_result = frexp(x); + + pyo3::Python::attach(|py| { + let math = pyo3::types::PyModule::import(py, "math").unwrap(); + let py_func = math.getattr("frexp").unwrap(); + let r = py_func.call1((x,)); + + if let Ok(py_val) = r { + let (py_m, py_e): (f64, i32) = py_val.extract().unwrap(); + assert_eq!( + py_m.to_bits(), + rs_result.0.to_bits(), + "frexp({x}) mantissa: py={py_m} vs rs={}", + rs_result.0 + ); + assert_eq!( + py_e, rs_result.1, + "frexp({x}) exponent: py={py_e} vs rs={}", + rs_result.1 + ); + } + }); + } + + fn test_modf(x: f64) { + use pyo3::prelude::*; + + let rs_result = modf(x); + + pyo3::Python::attach(|py| { + let math = pyo3::types::PyModule::import(py, "math").unwrap(); + let py_func = math.getattr("modf").unwrap(); + let r = py_func.call1((x,)); + + if let Ok(py_val) = r { + let (py_frac, py_int): (f64, f64) = py_val.extract().unwrap(); + assert_eq!( + py_frac.to_bits(), + rs_result.0.to_bits(), + "modf({x}) frac: py={py_frac} vs rs={}", + rs_result.0 + ); + assert_eq!( + py_int.to_bits(), + rs_result.1.to_bits(), + "modf({x}) int: py={py_int} vs rs={}", + rs_result.1 + ); + } + }); + } + + fn test_fmod(x: f64, y: f64) { + crate::test::test_math_2(x, y, "fmod", fmod); + } + fn test_remainder(x: f64, y: f64) { + crate::test::test_math_2(x, y, "remainder", remainder); + } + fn test_copysign(x: f64, y: f64) { + crate::test::test_math_2(x, y, "copysign", copysign); + } + fn test_ulp(x: f64) { + crate::test::test_math_1(x, "ulp", |x| Ok(ulp(x))); + } + + #[test] + fn edgetest_frexp() { + for &x in &crate::test::EDGE_VALUES { + test_frexp(x); + } + } + + #[test] + fn edgetest_ldexp() { + for &x in &crate::test::EDGE_VALUES { + for &i in &crate::test::EDGE_INTS { + test_ldexp(x, i); + } + } + } + + #[test] + fn edgetest_modf() { + for &x in &crate::test::EDGE_VALUES { + test_modf(x); + } + } + + #[test] + fn edgetest_fmod() { + for &x in &crate::test::EDGE_VALUES { + for &y in &crate::test::EDGE_VALUES { + test_fmod(x, y); + } 
+        }
+    }
+
+    #[test]
+    fn edgetest_remainder() {
+        for &x in &crate::test::EDGE_VALUES {
+            for &y in &crate::test::EDGE_VALUES {
+                test_remainder(x, y);
+            }
+        }
+    }
+
+    #[test]
+    fn edgetest_copysign() {
+        for &x in &crate::test::EDGE_VALUES {
+            for &y in &crate::test::EDGE_VALUES {
+                test_copysign(x, y);
+            }
+        }
+    }
+
+    #[test]
+    fn edgetest_ulp() {
+        for &x in &crate::test::EDGE_VALUES {
+            test_ulp(x);
+        }
+    }
+
+    proptest::proptest! {
+        #[test]
+        fn proptest_frexp(x: f64) {
+            test_frexp(x);
+        }
+
+        #[test]
+        fn proptest_ldexp(x: f64, i: i32) {
+            test_ldexp(x, i);
+        }
+
+        #[test]
+        fn proptest_modf(x: f64) {
+            test_modf(x);
+        }
+
+        #[test]
+        fn proptest_fmod(x: f64, y: f64) {
+            test_fmod(x, y);
+        }
+
+        #[test]
+        fn proptest_remainder(x: f64, y: f64) {
+            test_remainder(x, y);
+        }
+
+        #[test]
+        fn proptest_copysign(x: f64, y: f64) {
+            test_copysign(x, y);
+        }
+
+        #[test]
+        fn proptest_ulp(x: f64) {
+            test_ulp(x);
+        }
+    }
+
+    #[test]
+    fn test_validation_functions() {
+        // isfinite
+        assert!(isfinite(0.0));
+        assert!(isfinite(1.0));
+        assert!(isfinite(-1.0));
+        assert!(!isfinite(f64::INFINITY));
+        assert!(!isfinite(f64::NEG_INFINITY));
+        assert!(!isfinite(f64::NAN));
+
+        // isinf
+        assert!(!isinf(0.0));
+        assert!(!isinf(1.0));
+        assert!(!isinf(f64::NAN));
+        assert!(isinf(f64::INFINITY));
+        assert!(isinf(f64::NEG_INFINITY));
+
+        // isnan
+        assert!(!isnan(0.0));
+        assert!(!isnan(1.0));
+        assert!(!isnan(f64::INFINITY));
+        assert!(!isnan(f64::NEG_INFINITY));
+        assert!(isnan(f64::NAN));
+    }
+
+    fn test_isclose_impl(a: f64, b: f64, rel_tol: f64, abs_tol: f64) {
+        use pyo3::prelude::*;
+
+        let rs_result = isclose(a, b, rel_tol, abs_tol);
+
+        pyo3::Python::attach(|py| {
+            let math = pyo3::types::PyModule::import(py, "math").unwrap();
+            let py_func = math.getattr("isclose").unwrap();
+            let kwargs = pyo3::types::PyDict::new(py);
+            kwargs.set_item("rel_tol", rel_tol).unwrap();
+            kwargs.set_item("abs_tol", abs_tol).unwrap();
+            let py_result = py_func.call((a, b), Some(&kwargs));
+
+            match py_result {
+                Ok(result) => {
+                    let py_bool: bool = result.extract().unwrap();
+                    let rs_bool = rs_result.unwrap();
+                    assert_eq!(
+                        py_bool, rs_bool,
+                        "a = {a}, b = {b}, rel_tol = {rel_tol}, abs_tol = {abs_tol}"
+                    );
+                }
+                Err(e) => {
+                    if e.is_instance_of::<pyo3::exceptions::PyValueError>(py) {
+                        assert_eq!(rs_result.err(), Some(Error::EDOM));
+                    } else {
+                        panic!("isclose({a}, {b}): py raised unexpected error {e}");
+                    }
+                }
+            }
+        });
+    }
+
+    #[test]
+    fn test_isclose() {
+        // Equal values
+        test_isclose_impl(1.0, 1.0, 1e-9, 0.0);
+        test_isclose_impl(0.0, 0.0, 1e-9, 0.0);
+        test_isclose_impl(-1.0, -1.0, 1e-9, 0.0);
+
+        // Nearly equal values: inside rel_tol (1e-10 apart) and just outside it (1e-8 apart)
+        test_isclose_impl(1.0, 1.0 + 1e-10, 1e-9, 0.0);
+        test_isclose_impl(1.0, 1.0 + 1e-8, 1e-9, 0.0);
+
+        // Not close values
+        test_isclose_impl(1.0, 2.0, 1e-9, 0.0);
+        test_isclose_impl(1.0, 1.1, 1e-9, 0.0);
+
+        // With abs_tol: inside (1e-10) and outside (1e-8) the absolute tolerance
+        test_isclose_impl(0.0, 1e-10, 1e-9, 1e-9);
+        test_isclose_impl(0.0, 1e-8, 1e-9, 1e-9);
+
+        // Infinities
+        test_isclose_impl(f64::INFINITY, f64::INFINITY, 1e-9, 0.0);
+        test_isclose_impl(f64::NEG_INFINITY, f64::NEG_INFINITY, 1e-9, 0.0);
+        test_isclose_impl(f64::INFINITY, f64::NEG_INFINITY, 1e-9, 0.0);
+        test_isclose_impl(f64::INFINITY, 1.0, 1e-9, 0.0);
+
+        // NaN
+        test_isclose_impl(f64::NAN, f64::NAN, 1e-9, 0.0);
+        test_isclose_impl(f64::NAN, 1.0, 1e-9, 0.0);
+
+        // Zero comparison
+        test_isclose_impl(0.0, 1e-10, 1e-9, 0.0);
+    }
+
+    proptest::proptest! {
+        #[test]
+        fn proptest_isclose(a: f64, b: f64) {
+            // Use default tolerances
+            test_isclose_impl(a, b, 1e-9, 0.0);
+        }
+    }
+}
diff --git a/src/math/trigonometric.rs b/src/math/trigonometric.rs
new file mode 100644
index 0000000..285b767
--- /dev/null
+++ b/src/math/trigonometric.rs
@@ -0,0 +1,298 @@
+//! Trigonometric and hyperbolic functions.
+
+use crate::Result;
+
+// Trigonometric functions
+
+/// Return the sine of x (in radians).
+#[inline]
+pub fn sin(x: f64) -> Result<f64> {
+    super::math_1(x, crate::m::sin, false)
+}
+
+/// Return the cosine of x (in radians).
+#[inline]
+pub fn cos(x: f64) -> Result<f64> {
+    super::math_1(x, crate::m::cos, false)
+}
+
+/// Return the tangent of x (in radians).
+#[inline]
+pub fn tan(x: f64) -> Result<f64> {
+    super::math_1(x, crate::m::tan, false)
+}
+
+/// Return the arc sine of x, in radians.
+/// Result is in the range [-pi/2, pi/2].
+#[inline]
+pub fn asin(x: f64) -> Result<f64> {
+    super::math_1(x, crate::m::asin, false)
+}
+
+/// Return the arc cosine of x, in radians.
+/// Result is in the range [0, pi].
+#[inline]
+pub fn acos(x: f64) -> Result<f64> {
+    super::math_1(x, crate::m::acos, false)
+}
+
+/// Return the arc tangent of x, in radians.
+/// Result is in the range [-pi/2, pi/2].
+#[inline]
+pub fn atan(x: f64) -> Result<f64> {
+    super::math_1(x, crate::m::atan, false)
+}
+
+/// Return the arc tangent of y/x, in radians.
+/// Result is in the range [-pi, pi].
+#[inline]
+pub fn atan2(y: f64, x: f64) -> Result<f64> {
+    super::math_2(y, x, crate::m::atan2)
+}
+
+// Hyperbolic functions
+
+/// Return the hyperbolic sine of x.
+#[inline]
+pub fn sinh(x: f64) -> Result<f64> {
+    super::math_1(x, crate::m::sinh, true)
+}
+
+/// Return the hyperbolic cosine of x.
+#[inline]
+pub fn cosh(x: f64) -> Result<f64> {
+    super::math_1(x, crate::m::cosh, true)
+}
+
+/// Return the hyperbolic tangent of x.
+#[inline]
+pub fn tanh(x: f64) -> Result<f64> {
+    super::math_1(x, crate::m::tanh, false)
+}
+
+/// Return the inverse hyperbolic sine of x.
+#[inline]
+pub fn asinh(x: f64) -> Result<f64> {
+    super::math_1(x, crate::m::asinh, false)
+}
+
+/// Return the inverse hyperbolic cosine of x.
+#[inline]
+pub fn acosh(x: f64) -> Result<f64> {
+    super::math_1(x, crate::m::acosh, false)
+}
+
+/// Return the inverse hyperbolic tangent of x.
+#[inline]
+pub fn atanh(x: f64) -> Result<f64> {
+    super::math_1(x, crate::m::atanh, false)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn test_sin(x: f64) {
+        crate::test::test_math_1(x, "sin", sin);
+    }
+    fn test_cos(x: f64) {
+        crate::test::test_math_1(x, "cos", cos);
+    }
+    fn test_tan(x: f64) {
+        crate::test::test_math_1(x, "tan", tan);
+    }
+    fn test_asin(x: f64) {
+        crate::test::test_math_1(x, "asin", asin);
+    }
+    fn test_acos(x: f64) {
+        crate::test::test_math_1(x, "acos", acos);
+    }
+    fn test_atan(x: f64) {
+        crate::test::test_math_1(x, "atan", atan);
+    }
+    fn test_atan2(y: f64, x: f64) {
+        crate::test::test_math_2(y, x, "atan2", atan2);
+    }
+    fn test_sinh(x: f64) {
+        crate::test::test_math_1(x, "sinh", sinh);
+    }
+    fn test_cosh(x: f64) {
+        crate::test::test_math_1(x, "cosh", cosh);
+    }
+    fn test_tanh(x: f64) {
+        crate::test::test_math_1(x, "tanh", tanh);
+    }
+    fn test_asinh(x: f64) {
+        crate::test::test_math_1(x, "asinh", asinh);
+    }
+    fn test_acosh(x: f64) {
+        crate::test::test_math_1(x, "acosh", acosh);
+    }
+    fn test_atanh(x: f64) {
+        crate::test::test_math_1(x, "atanh", atanh);
+    }
+
+    // Trigonometric edge tests
+    #[test]
+    fn edgetest_sin() {
+        for &x in &crate::test::EDGE_VALUES {
+            test_sin(x);
+        }
+    }
+
+    #[test]
+    fn edgetest_cos() {
+        for &x in &crate::test::EDGE_VALUES {
+            test_cos(x);
+        }
+    }
+
+    #[test]
+    fn edgetest_tan() {
+        for &x in &crate::test::EDGE_VALUES {
+            test_tan(x);
+        }
+    }
+
+    #[test]
+    fn edgetest_asin() {
+        for &x in &crate::test::EDGE_VALUES {
+            test_asin(x);
+        }
+    }
+
+    #[test]
+    fn edgetest_acos() {
+        for &x in &crate::test::EDGE_VALUES {
+            test_acos(x);
+        }
+    }
+
+    #[test]
+    fn edgetest_atan() {
+        for &x in &crate::test::EDGE_VALUES {
+            test_atan(x);
+        }
+    }
+
+    #[test]
+    fn edgetest_atan2() {
+        for &y in &crate::test::EDGE_VALUES {
+            for &x in &crate::test::EDGE_VALUES {
+                test_atan2(y, x);
+            }
+        }
+    }
+
+    // Hyperbolic edge tests
+    #[test]
+    fn edgetest_tanh() {
+        for &x in &crate::test::EDGE_VALUES {
+            test_tanh(x);
+        }
+    }
+
+    #[test]
+    fn edgetest_asinh() {
+        for &x in &crate::test::EDGE_VALUES {
+            test_asinh(x);
+        }
+    }
+
+    #[test]
+    fn edgetest_sinh() {
+        for &x in &crate::test::EDGE_VALUES {
+            test_sinh(x);
+        }
+    }
+
+    #[test]
+    fn edgetest_cosh() {
+        for &x in &crate::test::EDGE_VALUES {
+            test_cosh(x);
+        }
+    }
+
+    #[test]
+    fn edgetest_acosh() {
+        for &x in &crate::test::EDGE_VALUES {
+            test_acosh(x);
+        }
+    }
+
+    #[test]
+    fn edgetest_atanh() {
+        for &x in &crate::test::EDGE_VALUES {
+            test_atanh(x);
+        }
+    }
+
+    proptest::proptest! {
+        // Trigonometric proptests
+        #[test]
+        fn proptest_sin(x: f64) {
+            test_sin(x);
+        }
+
+        #[test]
+        fn proptest_cos(x: f64) {
+            test_cos(x);
+        }
+
+        #[test]
+        fn proptest_tan(x: f64) {
+            test_tan(x);
+        }
+
+        #[test]
+        fn proptest_atan2(y: f64, x: f64) {
+            test_atan2(y, x);
+        }
+
+        #[test]
+        fn proptest_asin(x: f64) {
+            test_asin(x);
+        }
+
+        #[test]
+        fn proptest_acos(x: f64) {
+            test_acos(x);
+        }
+
+        #[test]
+        fn proptest_atan(x: f64) {
+            test_atan(x);
+        }
+
+        // Hyperbolic proptests
+        #[test]
+        fn proptest_tanh(x: f64) {
+            test_tanh(x);
+        }
+
+        #[test]
+        fn proptest_asinh(x: f64) {
+            test_asinh(x);
+        }
+
+        #[test]
+        fn proptest_sinh(x: f64) {
+            test_sinh(x);
+        }
+
+        #[test]
+        fn proptest_cosh(x: f64) {
+            test_cosh(x);
+        }
+
+        #[test]
+        fn proptest_acosh(x: f64) {
+            test_acosh(x);
+        }
+
+        #[test]
+        fn proptest_atanh(x: f64) {
+            test_atanh(x);
+        }
+    }
+}
diff --git a/src/test.rs b/src/test.rs
index 7869f1e..2902845 100644
--- a/src/test.rs
+++ b/src/test.rs
@@ -1,17 +1,62 @@
 use crate::Error;
 use pyo3::{Python, prelude::*};
 
-pub(crate) fn unwrap<'a, T: 'a>(
-    py: Python,
-    py_v: PyResult<Bound<'a, PyAny>>,
-    v: Result<T, Error>,
-) -> Option<(T, T)>
-where
-    T: PartialEq + std::fmt::Debug + FromPyObject<'a>,
-{
+/// Edge values for testing floating-point functions.
+/// Includes: zeros, infinities, various NaNs, subnormals, and values at different scales.
+pub(crate) const EDGE_VALUES: [f64; 30] = [
+    // Zeros
+    0.0,
+    -0.0,
+    // Infinities
+    f64::INFINITY,
+    f64::NEG_INFINITY,
+    // Standard NaNs
+    f64::NAN,
+    -f64::NAN,
+    // Additional NaN with different payload (quiet NaN with payload 1)
+    f64::from_bits(0x7FF8_0000_0000_0001_u64),
+    // Subnormal (denormalized) values
+    f64::MIN_POSITIVE * 0.5, // subnormal 2^-1023 (not the smallest, which is 2^-1074)
+    -f64::MIN_POSITIVE * 0.5,
+    // Boundary values
+    f64::MIN_POSITIVE, // smallest positive normal
+    f64::MAX,          // largest finite
+    f64::MIN,          // most negative finite (not smallest!)
+    // Near-infinity large values
+    f64::MAX * 0.5,
+    -f64::MAX * 0.5,
+    1e308,
+    -1e308,
+    // Small scale
+    1e-10,
+    -1e-10,
+    1e-300,
+    // Normal scale
+    1.0,
+    -1.0,
+    0.5,
+    -0.5,
+    2.0,
+    // Trigonometric special values (where sin/cos/tan have exact or near-zero results)
+    std::f64::consts::PI, // sin(PI) ≈ 0
+    -std::f64::consts::PI,
+    std::f64::consts::FRAC_PI_2, // cos(PI/2) ≈ 0
+    -std::f64::consts::FRAC_PI_2,
+    std::f64::consts::FRAC_PI_4, // tan(PI/4) ≈ 1
+    std::f64::consts::TAU,       // sin(2*PI) ≈ 0, cos(2*PI) = 1
+];
+
+/// Edge integer values for testing functions like ldexp
+pub(crate) const EDGE_INTS: [i32; 9] = [0, 1, -1, 100, -100, 1024, -1024, i32::MAX, i32::MIN];
+
+pub(crate) fn unwrap<'py>(
+    py: Python<'py>,
+    py_v: PyResult<Bound<'py, PyAny>>,
+    v: Result<f64, Error>,
+) -> Option<(f64, f64)> {
     match py_v {
         Ok(py_v) => {
-            let py_v: T = py_v.extract().unwrap();
+            let py_v: f64 = py_v.extract().ok().expect("failed to extract");
             Some((py_v, v.unwrap()))
         }
         Err(e) => {
@@ -26,3 +71,77 @@ where
         }
     }
 }
+
+/// Compare a 1-argument function returning `crate::Result<f64>` with Python's `math.<func_name>`.
+pub(crate) fn test_math_1(x: f64, func_name: &str, rs_func: impl Fn(f64) -> crate::Result<f64>) {
+    let rs_result = rs_func(x);
+
+    pyo3::Python::attach(|py| {
+        let math = pyo3::types::PyModule::import(py, "math").unwrap();
+        let py_func = math.getattr(func_name).unwrap();
+        let r = py_func.call1((x,));
+        let Some((py_result, rs_result)) = unwrap(py, r, rs_result) else {
+            return;
+        };
+        if py_result.is_nan() && rs_result.is_nan() {
+            return;
+        }
+        assert_eq!(
+            py_result.to_bits(),
+            rs_result.to_bits(),
+            "{func_name}({x}): py={py_result} vs rs={rs_result}"
+        );
+    });
+}
+
+/// Compare a 2-argument function returning `crate::Result<f64>` with Python's `math.<func_name>`.
+pub(crate) fn test_math_2(
+    x: f64,
+    y: f64,
+    func_name: &str,
+    rs_func: impl Fn(f64, f64) -> crate::Result<f64>,
+) {
+    let rs_result = rs_func(x, y);
+
+    pyo3::Python::attach(|py| {
+        let math = pyo3::types::PyModule::import(py, "math").unwrap();
+        let py_func = math.getattr(func_name).unwrap();
+        let r = py_func.call1((x, y));
+
+        match r {
+            Ok(py_val) => {
+                let py_f: f64 = py_val.extract().unwrap();
+                let rs_val = rs_result.unwrap_or_else(|e| {
+                    panic!("{func_name}({x}, {y}): py={py_f} but rs returned error {e:?}")
+                });
+                if py_f.is_nan() && rs_val.is_nan() {
+                    return;
+                }
+                assert_eq!(
+                    py_f.to_bits(),
+                    rs_val.to_bits(),
+                    "{func_name}({x}, {y}): py={py_f} vs rs={rs_val}"
+                );
+            }
+            Err(e) => {
+                // Check error type matches
+                let rs_err = rs_result.as_ref().err();
+                if e.is_instance_of::<pyo3::exceptions::PyValueError>(py) {
+                    assert_eq!(
+                        rs_err,
+                        Some(&Error::EDOM),
+                        "{func_name}({x}, {y}): py raised ValueError but rs={rs_err:?}"
+                    );
+                } else if e.is_instance_of::<pyo3::exceptions::PyOverflowError>(py) {
+                    assert_eq!(
+                        rs_err,
+                        Some(&Error::ERANGE),
+                        "{func_name}({x}, {y}): py raised OverflowError but rs={rs_err:?}"
+                    );
+                } else {
+                    panic!("{func_name}({x}, {y}): py raised unexpected error {e}");
+                }
+            }
+        }
+    });
+}