-
Notifications
You must be signed in to change notification settings - Fork 241
More efficient bisection for 1D Newton root finder #1012
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from 3 commits
6bd1d9c
eef5fe0
af08a6c
16b1033
c2d2b76
21dc82b
21d9374
2f60168
a803be2
7acf90c
0708745
19bbf76
752a1c5
777d64a
8c46b6d
f28048d
a14567f
0a64a5f
57c5512
3957093
6d27f3a
66a61a8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -213,6 +213,245 @@ inline std::pair<T, T> bisect(F f, T min, T max, Tol tol) noexcept(policies::is_ | |
} | ||
|
||
|
||
////// Motivation for the Bisection namespace. ////// | ||
// | ||
// What's the best way to bisect between a lower bound (lb) and an upper | ||
// bound (ub) during root finding? Let's consider options... | ||
// | ||
// Arithmetic bisection: | ||
// - The natural choice, but it doesn't always work well. For example, if | ||
// lb = 1.0 and ub = std::numeric_limits<float>::max(), many bisections | ||
// may be needed to converge if the root is near 1. | ||
// | ||
// Geometric bisection: | ||
// - This approach performs much better for the example above, but it | ||
// too has issues. For example, if lb = 0.0, it gets stuck at 0.0. | ||
// It also fails if lb and ub have different signs. | ||
// | ||
// In addition to the limitations outlined above, neither of these approaches | ||
// works if ub is infinity. We want a more robust way to handle bisection | ||
// for general root finding problems. That's what this namespace is for. | ||
// | ||
namespace detail { | ||
namespace Bisection { | ||
|
||
////// The Midpoint754 class ////// | ||
// | ||
// On a conceptual level, this class is designed to solve the following root | ||
// finding problem. | ||
// - A function f(x) has a single root x_solution somewhere in the interval | ||
// [-infinity, +infinity]. For all values below x_solution f(x) is -1. | ||
// For all values above x_solution f(x) is +1. The best way to root find | ||
// this problem is to bisect in bit space. | ||
// | ||
// Efficient bit space bisection is possible because of the IEEE 754 standard. | ||
// According to the standard, the bits in floating point numbers are partitioned | ||
// into three parts: sign, exponent, and mantissa. As long as the sign of the | ||
// number stays the same, increasing numbers in bit space have increasing | ||
// floating point values starting at zero, and ending at infinity! The table | ||
// below shows select numbers for float (single precision). | ||
// | ||
// 0 | 0 00000000 00000000000000000000000 | positive zero | ||
// 1.4013e-45 | 0 00000000 00000000000000000000001 | std::numeric_limits<float>::denorm_min() | ||
// 1.17549e-38 | 0 00000001 00000000000000000000000 | std::numeric_limits<float>::min() | ||
// 1.19209e-07 | 0 01101000 00000000000000000000000 | std::numeric_limits<float>::epsilon() | ||
// 1 | 0 01111111 00000000000000000000000 | positive one | ||
// 3.40282e+38 | 0 11111110 11111111111111111111111 | std::numeric_limits<float>::max() | ||
// inf | 0 11111111 00000000000000000000000 | std::numeric_limits<float>::infinity() | ||
// nan | 0 11111111 10000000000000000000000 | std::numeric_limits<float>::quiet_NaN() | ||
// | ||
// Negative values are similar, but the sign bit is set to 1. By keeping track of the possible | ||
// sign flip, this class can bisect numbers with different signs. | ||
// | ||
template <typename T, typename U>
class Midpoint754 {
private:
   // Bisects in bit space for IEEE 754 floating point numbers. Infinite
   // arguments are accepted; neither argument is expected to be NaN.
   static_assert(std::numeric_limits<T>::is_iec559, "Type must be IEEE 754 floating point.");
   static_assert(std::is_unsigned<U>::value, "U must be an unsigned integer type.");
   static_assert(sizeof(T) == sizeof(U), "Type and uint size must be the same.");

   // Copy the object representation of a floating point value into an
   // unsigned integer of the same size.
   static U float_to_uint(T value) {
      U pattern;
      std::memcpy(&pattern, &value, sizeof(U));
      return pattern;
   }

   // Copy an unsigned bit pattern back into a floating point value.
   static T uint_to_float(U pattern) {
      T value;
      std::memcpy(&value, &pattern, sizeof(T));
      return value;
   }

public:
   // Returns a point between x and X computed on the bit patterns of their
   // magnitudes. When both arguments report the same sign the two magnitude
   // patterns are averaged; otherwise half of their difference is used. The
   // result carries the sign of whichever argument has the larger magnitude.
   static T solve(T x, T X) {
      using std::fabs;

      // Reorder so that X is the argument with the larger magnitude.
      if (fabs(x) > fabs(X)) {
         std::swap(x, X);
      }

      const T small_mag = std::fabs(x);
      const T large_mag = std::fabs(X);
      const T small_sign = sign(x);
      const T large_sign = sign(X);

      // Bit patterns of the two magnitudes.
      const U small_bits = float_to_uint(small_mag);
      const U large_bits = float_to_uint(large_mag);

      // Combine the patterns (sum for equal signs, difference otherwise) ...
      U mid_bits;
      if (small_sign == large_sign) {
         mid_bits = large_bits + small_bits;
      } else {
         mid_bits = large_bits - small_bits;
      }
      // ... and halve the combination.
      mid_bits = mid_bits >> 1;

      // Rebuild the floating point midpoint and apply the sign of X.
      return uint_to_float(mid_bits) * large_sign;
   }
};  // class Midpoint754
|
||
|
||
template <typename T> | ||
class MidpointNon754 { | ||
private: | ||
static_assert(!std::is_same<T, float>::value, "Need to use Midpoint754 solver when T is float"); | ||
static_assert(!std::is_same<T, double>::value, "Need to use Midpoint754 solver when T is double"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should there be an assertion for long double as well? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added an explanation as to why There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A guess there are a couple cases:
Here is where I have defined macros before for determining the size of long double. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
These systems will use the specialization for
I tried to make this work. Starting with
I got this to work for both |
||
|
||
public: | ||
// Entry point: returns a bisection point between x and X.
static T solve(T x, T X) {
   const T sign_of_x = sign(x);
   const T sign_of_X = sign(X);

   // A sign flip between the two arguments: bisect at zero.
   if (sign_of_x * sign_of_X == -1) {
      return T(0.0);
   }

   // The signs do not sum to a positive value (at least one argument is
   // negative, or both are zero): solve the mirrored problem and negate.
   if (!(0 < sign_of_x + sign_of_X)) {
      return -do_solve(-x, -X);
   }

   // At least one argument is positive and none is negative.
   return do_solve(x, X);
}
|
||
private: | ||
struct EqZero { | ||
EqZero(T x) { BOOST_MATH_ASSERT(x == 0 && "x must be zero."); } | ||
}; | ||
|
||
struct EqInf { | ||
EqInf(T x) { BOOST_MATH_ASSERT(x == static_cast<T>(std::numeric_limits<double>::infinity()) && "x must be infinity."); } | ||
mborland marked this conversation as resolved.
Show resolved
Hide resolved
|
||
}; | ||
|
||
class PosFinite { | ||
public: | ||
PosFinite(T x) : x_(x) { | ||
BOOST_MATH_ASSERT(0 < x && "x must be positive."); | ||
BOOST_MATH_ASSERT(x < std::numeric_limits<float>::infinity() && "x must be less than infinity."); | ||
ryanelandt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
T value() const { return x_; } | ||
|
||
private: | ||
T x_; | ||
}; | ||
|
||
// Two unknowns | ||
// Classifies the smaller argument (zero, infinite, or positive finite) and
// forwards to the overload that classifies the larger one.
static T do_solve(T x, T X) {
   // Put the arguments in ascending order first.
   if (X < x) {
      return do_solve(X, x);
   }

   if (x == 0) {
      return do_solve(EqZero(x), X);
   }

   // x is the smaller argument, so if it is infinite the answer is infinity.
   if (x == static_cast<T>(std::numeric_limits<double>::infinity())) {
      return static_cast<T>(std::numeric_limits<double>::infinity());
   }

   return do_solve(PosFinite(x), X);
}
|
||
// One unknown | ||
// Lower argument is known to be zero; classify the upper argument.
static T do_solve(EqZero x, T X) {
   // Both arguments are zero.
   if (X == 0) {
      return T(0.0);
   }

   // Bisecting between zero and infinity yields one.
   if (X == static_cast<T>(std::numeric_limits<double>::infinity())) {
      return T(1.0);
   }

   return do_solve(x, PosFinite(X));
}
// Lower argument is known to be positive and finite; classify the upper one.
static T do_solve(PosFinite x, T X) {
   const bool upper_is_infinite = (X == static_cast<T>(std::numeric_limits<double>::infinity()));
   if (!upper_is_infinite) {
      return do_solve(x, PosFinite(X));
   }
   return do_solve(x, EqInf(X));
}
|
||
// Zero unknowns | ||
template <typename U = T> | ||
static typename std::enable_if<std::numeric_limits<U>::is_specialized, T>::type | ||
mborland marked this conversation as resolved.
Show resolved
Hide resolved
|
||
do_solve(PosFinite x, EqInf X) { | ||
return do_solve(x, PosFinite((std::numeric_limits<U>::max)())); | ||
} | ||
template <typename U = T> | ||
static typename std::enable_if<!std::numeric_limits<U>::is_specialized, T>::type | ||
mborland marked this conversation as resolved.
Show resolved
Hide resolved
|
||
do_solve(PosFinite x, EqInf X) { | ||
BOOST_MATH_ASSERT(false && "infinite bounds support requires specialization."); | ||
ryanelandt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return static_cast<T>(std::numeric_limits<T>::signaling_NaN()); | ||
} | ||
|
||
// Zero lower argument, positive finite upper argument, for types with a | ||
// std::numeric_limits specialization. Zero is replaced by the smallest | ||
// positive value the type reports, and the work is forwarded to the | ||
// PosFinite/PosFinite overload. | ||
template <typename U = T> | ||
static typename std::enable_if<std::numeric_limits<U>::is_specialized, U>::type | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You may be able to get away without checking There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We want to be able to bisect with denormals because the solution to the root finding problem could be a denormal number. Reading the depreciation document, it seems like the depreciation of I think we can actually delete There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I think so. Since you are already check for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
do_solve(EqZero x, PosFinite X) { | ||
// Prefer denorm_min(); fall back to min() when denorm_min() is zero. | ||
const auto get_smallest_value = []() { | ||
const U denorm_min = std::numeric_limits<U>::denorm_min(); | ||
if (denorm_min != 0) { return denorm_min; } | ||
|
||
const U min = (std::numeric_limits<U>::min)(); | ||
if (min != 0) { return min; } | ||
|
||
// Neither candidate is usable: assert and return signaling_NaN(). | ||
BOOST_MATH_ASSERT(false && "denorm_min and min are both zero."); | ||
ryanelandt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return static_cast<T>(std::numeric_limits<T>::signaling_NaN()); | ||
}; | ||
|
||
// Bisect between the smallest positive value and X. | ||
return do_solve(PosFinite(get_smallest_value()), X); | ||
} | ||
template <typename U = T> | ||
static typename std::enable_if<!std::numeric_limits<U>::is_specialized, U>::type | ||
do_solve(EqZero x, PosFinite X) { return X.value() / U(2); } | ||
|
||
// Both arguments are positive and finite, with x <= X. Returns the
// arithmetic mean when X is less than eight times x, and the geometric mean
// otherwise.
static T do_solve(PosFinite x, PosFinite X) {
   BOOST_MATH_ASSERT(x.value() <= X.value() && "x must be less than or equal to X.");

   const T lo = x.value();
   const T hi = X.value();

   const bool far_apart = !(hi < lo * 8);
   if (far_apart) {
      // Geometric mean; multiplying the two roots avoids overflowing lo * hi.
      using std::sqrt;
      return sqrt(lo) * sqrt(hi);
   }

   // Arithmetic mean; written as an offset from lo to avoid overflowing lo + hi.
   return (hi - lo) / 2 + lo;
}
}; // class MidpointNon754 | ||
|
||
template <typename T> | ||
static T calc_midpoint(T x, T X) { | ||
return MidpointNon754<T>::solve(x, X); | ||
} | ||
static float calc_midpoint(float x, float X) { | ||
return Midpoint754<float, std::uint32_t>::solve(x, X); | ||
} | ||
static double calc_midpoint(double x, double X) { | ||
return Midpoint754<double, std::uint64_t>::solve(x, X); | ||
} | ||
|
||
} // namespace Bisection | ||
} // namespace detail | ||
|
||
template <class F, class T> | ||
T newton_raphson_iterate(F f, T guess, T min, T max, int digits, std::uintmax_t& max_iter) noexcept(policies::is_noexcept_error_policy<policies::policy<> >::value&& BOOST_MATH_IS_FLOAT(T) && noexcept(std::declval<F>()(std::declval<T>()))) | ||
{ | ||
|
@@ -256,8 +495,13 @@ T newton_raphson_iterate(F f, T guess, T min, T max, int digits, std::uintmax_t& | |
last_f0 = f0; | ||
delta2 = delta1; | ||
delta1 = delta; | ||
if (count == 0) { | ||
return policies::raise_evaluation_error(function, "Ran out of iterations in boost::math::tools::newton_raphson_iterate, guess: %1%", guess, boost::math::policies::policy<>()); | ||
} else { | ||
--count; | ||
} | ||
detail::unpack_tuple(f(result), f0, f1); | ||
--count; | ||
|
||
if (0 == f0) | ||
break; | ||
if (f1 == 0) | ||
|
@@ -275,7 +519,8 @@ T newton_raphson_iterate(F f, T guess, T min, T max, int digits, std::uintmax_t& | |
if (fabs(delta * 2) > fabs(delta2)) | ||
{ | ||
// Last two steps haven't converged. | ||
delta = (delta > 0) ? (result - min) / 2 : (result - max) / 2; | ||
const T x_other = (delta > 0) ? min : max; | ||
delta = result - detail::Bisection::calc_midpoint(result, x_other); | ||
// reset delta1/2 so we don't take this branch next time round: | ||
delta1 = 3 * delta; | ||
delta2 = 3 * delta; | ||
|
@@ -302,7 +547,7 @@ T newton_raphson_iterate(F f, T guess, T min, T max, int digits, std::uintmax_t& | |
max = guess; | ||
max_range_f = f0; | ||
} | ||
else | ||
else if (delta < 0) // Cannot have "else" here, as delta being zero is not indicative of failure | ||
{ | ||
min = guess; | ||
min_range_f = f0; | ||
|
@@ -314,7 +559,7 @@ T newton_raphson_iterate(F f, T guess, T min, T max, int digits, std::uintmax_t& | |
{ | ||
return policies::raise_evaluation_error(function, "There appears to be no root to be found in boost::math::tools::newton_raphson_iterate, perhaps we have a local minima near current best guess of %1%", guess, boost::math::policies::policy<>()); | ||
} | ||
}while(count && (fabs(result * factor) < fabs(delta))); | ||
} while (fabs(result * factor) < fabs(delta) || result == 0); | ||
|
||
max_iter -= count; | ||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.