|
1 | 1 | --- |
2 | | -title: NonlinearSolve.jl suite of interval root-finding algorithms |
| 2 | +title: Interval root-finding test suite |
3 | 3 | author: Fabian Gittins |
4 | 4 | --- |
5 | 5 |
|
6 | 6 | In this benchmark, we will examine how the interval root-finding algorithms |
7 | | -provided in `NonlinearSolve.jl` fare against one another for a selection of |
8 | | -examples. |
| 7 | +provided in `NonlinearSolve.jl` and `SimpleNonlinearSolve.jl` fare against one another for a selection of |
| 8 | +challenging test functions from the literature. |
9 | 9 |
|
10 | 10 | ## `Roots.jl` baseline |
11 | 11 |
|
@@ -155,6 +155,352 @@ than others. This is entirely to be expected as some of the algorithms, like |
155 | 155 | `Bisection`, bracket the root and thus will reliably converge to high accuracy. |
156 | 156 | Others, like `Muller`, are not bracketing methods, but can be extremely fast. |
157 | 157 |
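To make the distinction concrete, here is a minimal bisection sketch in plain Julia. It is not the `Bisection` implementation from `NonlinearSolve.jl`; the helper name `bisect_sketch`, its tolerance, and its iteration cap are assumptions chosen for illustration. It shows why a bracketing method cannot lose the root: every step keeps a sub-interval across which the function still changes sign.

```julia
# Minimal bisection sketch (hypothetical helper, not the NonlinearSolve.jl implementation).
# Assumes f is continuous on [a, b] and that f(a) and f(b) have opposite signs.
function bisect_sketch(f, a::Float64, b::Float64; abstol = 1e-12, maxiters = 200)
    fa = f(a)
    sign(fa) * sign(f(b)) < 0 || throw(ArgumentError("f(a) and f(b) must have opposite signs"))
    for _ in 1:maxiters
        m = (a + b) / 2
        fm = f(m)
        # Stop once the midpoint is an exact root or the bracket is narrow enough.
        if fm == 0 || (b - a) / 2 < abstol
            return m
        end
        # Keep the half-interval that still contains a sign change.
        if sign(fa) * sign(fm) < 0
            b = m
        else
            a, fa = m, fm
        end
    end
    return (a + b) / 2
end

root = bisect_sketch(u -> u^2 - 2, 0.0, 2.0)
```

Each iteration halves the bracket, so the cost is predictable but convergence is only linear, which is the trade-off against faster non-bracketing methods such as `Muller`.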
|
| 158 | +## Extended Test Suite with Challenging Functions |
| 159 | + |
| 160 | +Next, we test the algorithms on a suite of eight challenging test functions of the |
| 161 | +kind used in the interval root-finding literature. These functions exhibit difficulties |
| 162 | +such as multiple roots inside the interval, nearly flat regions, rapid oscillation, |
| 163 | +and a pole close to the interval. |
| 164 | + |
| 165 | +```julia |
| 166 | +using Statistics |
| 167 | + |
| 168 | +# Define challenging test functions |
| 169 | +test_functions = [ |
| 170 | + # Function 1: Polynomial with multiple roots |
| 171 | + (name = "Wilkinson-like polynomial", |
| 172 | + f = (u, p) -> (u - 1) * (u - 2) * (u - 3) * (u - 4) * (u - 5) - p, |
| 173 | + interval = (0.5, 5.5), |
| 174 | + p = 0.05), |
| 175 | + |
| 176 | + # Function 2: Trigonometric with multiple roots |
| 177 | + (name = "sin(x) - 0.5x", |
| 178 | + f = (u, p) -> sin(u) - 0.5*u - p, |
| 179 | + interval = (-10.0, 10.0), |
| 180 | + p = 0.3), |
| 181 | + |
| 182 | + # Function 3: Exponential function (sensitive near zero) |
| 183 | + (name = "exp(x) - 1 - x - x²/2", |
| 184 | + f = (u, p) -> exp(u) - 1 - u - u^2/2 - p, |
| 185 | + interval = (-2.0, 2.0), |
| 186 | + p = 0.005), |
| 187 | + |
| 188 | + # Function 4: Rational function with a pole at u = 0.5, just below the interval |
| 189 | + (name = "1/(x-0.5) - 2", |
| 190 | + f = (u, p) -> 1/(u - 0.5) - 2 - p, |
| 191 | + interval = (0.6, 2.0), |
| 192 | + p = 0.05), |
| 193 | + |
| 194 | + # Function 5: Logarithmic function |
| 195 | + (name = "log(x) - x + 2", |
| 196 | + f = (u, p) -> log(u) - u + 2 - p, |
| 197 | + interval = (0.1, 3.0), |
| 198 | + p = 0.05), |
| 199 | + |
| 200 | + # Function 6: High oscillation function |
| 201 | + (name = "sin(20x) + 0.1x", |
| 202 | + f = (u, p) -> sin(20*u) + 0.1*u - p, |
| 203 | + interval = (-5.0, 5.1), # upper end 5.1 rather than 5.0: f is negative at both -5.0 and 5.0 |
| 204 | + p = 0.1), |
| 205 | + |
| 206 | + # Function 7: Function with very flat region |
| 207 | + (name = "x³ - 2x² + x", |
| 208 | + f = (u, p) -> u^3 - 2*u^2 + u - p, |
| 209 | + interval = (-1.0, 2.0), |
| 210 | + p = 0.025), |
| 211 | + |
| 212 | + # Function 8: Function oscillating rapidly near zero |
| 213 | + (name = "x·sin(1/x) - 0.1", |
| 214 | + f = (u, p) -> u * sin(1/u) - 0.1 - p, |
| 215 | + interval = (0.01, 1.0), |
| 216 | + p = 0.01), |
| 217 | +] |
| 218 | + |
| 219 | +# Add SimpleNonlinearSolve algorithms |
| 220 | +using SimpleNonlinearSolve |
| 221 | + |
| 222 | +# Combined algorithm list from both packages |
| 223 | +all_algorithms = [ |
| 224 | + (name = "Alefeld (BNS)", alg = () -> Alefeld(), package = "BracketingNonlinearSolve"), |
| 225 | + (name = "Bisection (BNS)", alg = () -> NonlinearSolve.Bisection(), package = "BracketingNonlinearSolve"), |
| 226 | + (name = "Brent (BNS)", alg = () -> Brent(), package = "BracketingNonlinearSolve"), |
| 227 | + (name = "Falsi (BNS)", alg = () -> Falsi(), package = "BracketingNonlinearSolve"), |
| 228 | + (name = "ITP (BNS)", alg = () -> ITP(), package = "BracketingNonlinearSolve"), |
| 229 | + (name = "Ridder (BNS)", alg = () -> Ridder(), package = "BracketingNonlinearSolve"), |
| 230 | + (name = "Bisection (SNS)", alg = () -> SimpleNonlinearSolve.Bisection(), package = "SimpleNonlinearSolve"), |
| 231 | + (name = "Brent (SNS)", alg = () -> SimpleNonlinearSolve.Brent(), package = "SimpleNonlinearSolve"), |
| 232 | + (name = "Falsi (SNS)", alg = () -> SimpleNonlinearSolve.Falsi(), package = "SimpleNonlinearSolve"), |
| 233 | + (name = "Ridder (SNS)", alg = () -> SimpleNonlinearSolve.Ridder(), package = "SimpleNonlinearSolve") |
| 234 | +] |
| 235 | + |
| 236 | +# Benchmark function for testing all algorithms on a given function |
| 237 | +function benchmark_function(test_func, N_samples=10000) |
| 238 | + println("\n=== Testing: $(test_func.name) ===") |
| 239 | + println("Interval: $(test_func.interval)") |
| 240 | + println("Parameter: $(test_func.p)") |
| 241 | + |
| 242 | + results = [] |
| 243 | + |
| 244 | + # Test Roots.jl baseline |
| 245 | + try |
| 246 | + # Cache the function for Roots.jl |
| 247 | + roots_func = u -> test_func.f(u, test_func.p) |
| 248 | + |
| 249 | + # Warmup run to exclude compilation time |
| 250 | + find_zero(roots_func, test_func.interval) |
| 251 | + |
| 252 | + # Actual timing |
| 253 | + time_roots = @elapsed begin |
| 254 | + for i in 1:N_samples |
| 255 | + root = find_zero(roots_func, test_func.interval) |
| 256 | + end |
| 257 | + end |
| 258 | + |
| 259 | + # Calculate error using one solve |
| 260 | + final_root = find_zero(roots_func, test_func.interval) |
| 261 | + error_roots = abs(test_func.f(final_root, test_func.p)) |
| 262 | + |
| 263 | + println("Roots.jl: $(round(time_roots*1000, digits=2)) ms, Error: $(round(error_roots, sigdigits=3))") |
| 264 | + push!(results, (name="Roots.jl", time=time_roots, error=error_roots, success=true)) |
| 265 | + catch e |
| 266 | + println("Roots.jl: FAILED - $e") |
| 267 | + push!(results, (name="Roots.jl", time=Inf, error=Inf, success=false)) |
| 268 | + end |
| 269 | + |
| 270 | + # Test all algorithms |
| 271 | + for alg_info in all_algorithms |
| 272 | + try |
| 273 | + # Warmup run to exclude compilation time |
| 274 | + prob_warmup = IntervalNonlinearProblem{false}( |
| 275 | + IntervalNonlinearFunction{false}(test_func.f), |
| 276 | + test_func.interval, test_func.p) |
| 277 | + solve(prob_warmup, alg_info.alg()) |
| 278 | + |
| 279 | + # Actual timing |
| 280 | + time_taken = @elapsed begin |
| 281 | + for i in 1:N_samples |
| 282 | + prob = IntervalNonlinearProblem{false}( |
| 283 | + IntervalNonlinearFunction{false}(test_func.f), |
| 284 | + test_func.interval, test_func.p) |
| 285 | + sol = solve(prob, alg_info.alg()) |
| 286 | + end |
| 287 | + end |
| 288 | + |
| 289 | + # Calculate error using one solve |
| 290 | + prob_final = IntervalNonlinearProblem{false}( |
| 291 | + IntervalNonlinearFunction{false}(test_func.f), |
| 292 | + test_func.interval, test_func.p) |
| 293 | + sol_final = solve(prob_final, alg_info.alg()) |
| 294 | + error_val = abs(test_func.f(sol_final.u, test_func.p)) |
| 295 | + |
| 296 | + println("$(alg_info.name): $(round(time_taken*1000, digits=2)) ms, Error: $(round(error_val, sigdigits=3))") |
| 297 | + push!(results, (name=alg_info.name, time=time_taken, error=error_val, success=true)) |
| 298 | + catch e |
| 299 | + println("$(alg_info.name): FAILED - $e") |
| 300 | + push!(results, (name=alg_info.name, time=Inf, error=Inf, success=false)) |
| 301 | + end |
| 302 | + end |
| 303 | + |
| 304 | + return results |
| 305 | +end |
| 306 | + |
| 307 | +# Run benchmarks on all test functions |
| 308 | +all_results = [] |
| 309 | +for test_func in test_functions |
| 310 | + results = benchmark_function(test_func, 10000) # Increased N since we're using fixed parameters |
| 311 | + push!(all_results, (func_name=test_func.name, results=results)) |
| 312 | +end |
| 313 | +``` |
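Bracketing solvers need the function to take opposite signs at the two ends of the interval, so it is worth checking each entry before benchmarking it. The sketch below uses a hypothetical helper, `has_sign_change`, written in plain Julia; it is not part of `NonlinearSolve.jl`. It is applied to the oscillatory function `sin(20u) + 0.1u - p` with `p = 0.1`, which is negative at both `-5.0` and `5.0` but changes sign across `(-5.0, 5.1)`.

```julia
# Hypothetical pre-check: does f(·, p) take opposite signs at the interval endpoints?
has_sign_change(f, interval, p) = sign(f(interval[1], p)) * sign(f(interval[2], p)) < 0

# Same form as the oscillatory test function, with the shift passed as p.
f_osc = (u, p) -> sin(20 * u) + 0.1 * u - p

ok_narrow = has_sign_change(f_osc, (-5.0, 5.0), 0.1)  # both endpoint values are negative
ok_wide = has_sign_change(f_osc, (-5.0, 5.1), 0.1)    # endpoint values have opposite signs
```

Running such a check over `test_functions` before timing anything separates "the solver struggled" from "the problem was not a valid bracket in the first place".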
| 314 | + |
| 315 | +## Performance Summary |
| 316 | + |
| 317 | +Let's create a summary table of the results: |
| 318 | + |
| 319 | +```julia |
| 320 | +using Printf |
| 321 | + |
| 322 | +function print_summary_table(all_results) |
| 323 | + println("\n" * "="^80) |
| 324 | + println("COMPREHENSIVE BENCHMARK SUMMARY") |
| 325 | + println("="^80) |
| 326 | + |
| 327 | + # Get all algorithm names |
| 328 | + alg_names = unique([r.name for func_results in all_results for r in func_results.results]) |
| 329 | + |
| 330 | + # Print header |
| 331 | + @printf "%-25s" "Function" |
| 332 | + for alg in alg_names |
| 333 | + @printf "%-15s" alg[1:min(14, length(alg))] |
| 334 | + end |
| 335 | + println() |
| 336 | + println("-"^(25 + 15*length(alg_names))) |
| 337 | + |
| 338 | + # Print results for each function |
| 339 | + for func_result in all_results |
| 340 | + @printf "%-25s" first(func_result.func_name, 24) |
| 341 | + |
| 342 | + for alg in alg_names |
| 343 | + # Find result for this algorithm |
| 344 | + alg_result = findfirst(r -> r.name == alg, func_result.results) |
| 345 | + if alg_result !== nothing |
| 346 | + result = func_result.results[alg_result] |
| 347 | + if result.success && result.time < 1.0 # Reasonable time limit |
| 348 | + @printf "%-15s" "$(round(result.time*1000, digits=1))ms" |
| 349 | + else |
| 350 | + @printf "%-15s" "FAIL" |
| 351 | + end |
| 352 | + else |
| 353 | + @printf "%-15s" "N/A" |
| 354 | + end |
| 355 | + end |
| 356 | + println() |
| 357 | + end |
| 358 | + |
| 359 | + println("\n" * "="^80) |
| 360 | + println("Notes:") |
| 361 | + println("- Times shown are total milliseconds for 10000 solves per algorithm") |
| 362 | + println("- BNS = BracketingNonlinearSolve.jl, SNS = SimpleNonlinearSolve.jl") |
| 363 | + println("- FAIL indicates algorithm failed or took excessive time") |
| 364 | + println("- Compilation time excluded via warmup runs") |
| 365 | + println("="^80) |
| 366 | +end |
| 367 | + |
| 368 | +print_summary_table(all_results) |
| 369 | +``` |
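One detail of the table code deserves a note: several function names contain non-ASCII characters such as `²` and `·`. In Julia, `s[1:n]` indexes by byte while `length(s)` counts characters, so mixing the two can cut a name short, and it throws a `StringIndexError` if the bound lands inside a multi-byte character. The sketch below compares that pattern with `first(s, n)`, which truncates by characters; the variable names are illustrative only.

```julia
label = "x³ - 2x² + x"   # 12 characters, but 14 bytes in UTF-8

# Byte-range indexing with a character count as the upper bound:
# here the bound happens to be a valid index, but the last two characters are lost.
by_bytes = label[1:min(24, length(label))]

# Character-based truncation: returns at most the first 24 characters.
by_chars = first(label, 24)
```

Using `first` keeps the full label whenever it fits in the column and never depends on where the byte boundaries fall.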
| 370 | + |
| 371 | +## Accuracy Analysis |
| 372 | + |
| 373 | +Now let's examine the accuracy of each method, measured as the absolute residual `|f(u)|` at the returned root: |
| 374 | + |
| 375 | +```julia |
| 376 | +function print_accuracy_table(all_results) |
| 377 | + println("\n" * "="^80) |
| 378 | + println("ACCURACY ANALYSIS (Absolute Residual |f(u)|)") |
| 379 | + println("="^80) |
| 380 | + |
| 381 | + alg_names = unique([r.name for func_results in all_results for r in func_results.results]) |
| 382 | + |
| 383 | + # Print header |
| 384 | + @printf "%-25s" "Function" |
| 385 | + for alg in alg_names |
| 386 | + @printf "%-15s" alg[1:min(14, length(alg))] |
| 387 | + end |
| 388 | + println() |
| 389 | + println("-"^(25 + 15*length(alg_names))) |
| 390 | + |
| 391 | + # Print results for each function |
| 392 | + for func_result in all_results |
| 393 | + @printf "%-25s" first(func_result.func_name, 24) |
| 394 | + |
| 395 | + for alg in alg_names |
| 396 | + alg_result = findfirst(r -> r.name == alg, func_result.results) |
| 397 | + if alg_result !== nothing |
| 398 | + result = func_result.results[alg_result] |
| 399 | + if result.success && result.error < 1e10 |
| 400 | + @printf "%-15s" "$(round(result.error, sigdigits=2))" |
| 401 | + else |
| 402 | + @printf "%-15s" "FAIL" |
| 403 | + end |
| 404 | + else |
| 405 | + @printf "%-15s" "N/A" |
| 406 | + end |
| 407 | + end |
| 408 | + println() |
| 409 | + end |
| 410 | + |
| 411 | + println("="^80) |
| 412 | +end |
| 413 | + |
| 414 | +print_accuracy_table(all_results) |
| 415 | +``` |
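The values in this table are residuals, `abs(f(u, p))`, not distances to the true root. The two can differ by many orders of magnitude where a function is flat near its root. The sketch below uses a made-up cubic with a triple root at `u = 1` (an assumption for illustration, not one of the benchmark functions) to show a tiny residual alongside a much larger error in the root itself.

```julia
# Illustrative function with a triple root at u = 1, so it is very flat nearby.
f_flat = u -> (u - 1)^3

u_approx = 1.001                      # a candidate root that is off by 1e-3
residual = abs(f_flat(u_approx))      # on the order of 1e-9
root_error = abs(u_approx - 1.0)      # on the order of 1e-3
```

For this reason, a small entry in the residual table for a flat function should be read with some caution.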
| 416 | + |
| 417 | +## Algorithm Rankings |
| 418 | + |
| 419 | +Finally, let's rank the algorithms by overall performance: |
| 420 | + |
| 421 | +```julia |
| 422 | +function rank_algorithms(all_results) |
| 423 | + println("\n" * "="^60) |
| 424 | + println("ALGORITHM RANKINGS") |
| 425 | + println("="^60) |
| 426 | + |
| 427 | + # Calculate scores for each algorithm |
| 428 | + alg_scores = Dict() |
| 429 | + |
| 430 | + for func_result in all_results |
| 431 | + for result in func_result.results |
| 432 | + if !haskey(alg_scores, result.name) |
| 433 | + alg_scores[result.name] = Dict(:time_score => 0.0, :accuracy_score => 0.0, :success_count => 0) |
| 434 | + end |
| 435 | + |
| 436 | + if result.success |
| 437 | + alg_scores[result.name][:success_count] += 1 |
| 438 | + # Lower time is better (inverse score) |
| 439 | + alg_scores[result.name][:time_score] += result.time < 1.0 ? 1.0 / result.time : 0.0 |
| 440 | + # Lower error is better (inverse score) |
| 441 | + alg_scores[result.name][:accuracy_score] += result.error < 1e10 ? 1.0 / (result.error + 1e-15) : 0.0 |
| 442 | + end |
| 443 | + end |
| 444 | + end |
| 445 | + |
| 446 | + # Normalize and combine scores |
| 447 | + total_functions = length(all_results) |
| 448 | + algorithm_rankings = [] |
| 449 | + |
| 450 | + for (alg, scores) in alg_scores |
| 451 | + success_rate = scores[:success_count] / total_functions |
| 452 | + avg_speed_score = scores[:time_score] / total_functions |
| 453 | + avg_accuracy_score = scores[:accuracy_score] / total_functions |
| 454 | + |
| 455 | + # Combined score with nominal weights 0.4 / 0.3 / 0.3; the speed and accuracy terms are rescaled by fixed constants (1000 and 1e12), not normalized to [0, 1] |
| 456 | + combined_score = 0.4 * success_rate + 0.3 * (avg_speed_score / 1000) + 0.3 * (avg_accuracy_score / 1e12) |
| 457 | + |
| 458 | + push!(algorithm_rankings, ( |
| 459 | + name = alg, |
| 460 | + success_rate = success_rate, |
| 461 | + speed_score = avg_speed_score, |
| 462 | + accuracy_score = avg_accuracy_score, |
| 463 | + combined_score = combined_score |
| 464 | + )) |
| 465 | + end |
| 466 | + |
| 467 | + # Sort by combined score |
| 468 | + sort!(algorithm_rankings, by = x -> x.combined_score, rev = true) |
| 469 | + |
| 470 | + println("Rank | Algorithm | Success Rate | Combined Score") |
| 471 | + println("-"^60) |
| 472 | + for (i, alg) in enumerate(algorithm_rankings) |
| 473 | + @printf "%-4d | %-18s | %-11.1f%% | %-12.3f\n" i alg.name[1:min(18, length(alg.name))] (alg.success_rate*100) alg.combined_score |
| 474 | + end |
| 475 | + |
| 476 | + println("="^60) |
| 477 | + println("Note: Combined score uses nominal weights for success rate (0.4), speed (0.3), and accuracy (0.3); the speed and accuracy terms are not normalized") |
| 478 | +end |
| 479 | + |
| 480 | +rank_algorithms(all_results) |
| 481 | +``` |
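Because the speed and accuracy terms above are inverse times and inverse residuals divided by fixed constants, a single near-zero residual can dominate the combined score regardless of the nominal weights. One possible alternative, sketched below, maps each metric to `[0, 1]` on a log scale before weighting. The helper `normalized_score` and all numbers are hypothetical and are not benchmark output.

```julia
# Hypothetical per-algorithm aggregates; illustrative numbers only.
alg_labels = ["A", "B", "C"]
times = [0.002, 0.004, 0.010]      # seconds, lower is better
residuals = [1e-12, 1e-15, 1e-9]   # |f(u)|, lower is better
success = [1.0, 1.0, 0.5]          # fraction of problems solved

# Map positive values to [0, 1] on a log10 scale, with 1 for the smallest value.
function normalized_score(x)
    lx = log10.(x)
    lo, hi = extrema(lx)
    hi == lo && return ones(length(x))
    return (hi .- lx) ./ (hi - lo)
end

combined = 0.4 .* success .+ 0.3 .* normalized_score(times) .+ 0.3 .* normalized_score(residuals)
ranking = alg_labels[sortperm(combined, rev = true)]
```

With every component bounded by 1, the weights describe the actual contribution of each term. A residual of exactly zero would need a small floor before taking the logarithm.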
| 482 | + |
| 483 | +## Conclusion |
| 484 | + |
| 485 | +This extended benchmark suite demonstrates the performance and accuracy characteristics of interval root-finding algorithms across a diverse set of challenging test functions. The test functions include: |
| 486 | + |
| 487 | +1. **Polynomial functions** with multiple roots |
| 488 | +2. **Trigonometric functions** with oscillatory behavior |
| 489 | +3. **Exponential functions** with high sensitivity |
| 490 | +4. **Rational functions** with singularities |
| 491 | +5. **Logarithmic functions** with domain restrictions |
| 492 | +6. **Highly oscillatory functions** testing robustness |
| 493 | +7. **Functions with flat regions** challenging convergence |
| 494 | +8. **Functions oscillating rapidly near zero**, such as `x·sin(1/x)` |
| 495 | + |
| 496 | +The benchmark compares algorithms from both `BracketingNonlinearSolve.jl` and `SimpleNonlinearSolve.jl`, providing insights into: |
| 497 | +- **Robustness**: Which algorithms handle challenging functions |
| 498 | +- **Speed**: Computational efficiency across different problem types |
| 499 | +- **Accuracy**: Precision of the found roots |
| 500 | +- **Reliability**: Success rates across diverse test cases |
| 501 | + |
| 502 | +This evaluation is intended to help users choose an appropriate interval root-finding algorithm for their application. |
| 503 | + |
158 | 504 | ```julia, echo = false |
159 | 505 | using SciMLBenchmarks |
160 | 506 | SciMLBenchmarks.bench_footer(WEAVE_ARGS[:folder], WEAVE_ARGS[:file]) |
|