@@ -171,49 +171,49 @@ test_functions = [
171171 (name = "Wilkinson-like polynomial",
172172 f = (u, p) -> (u - 1) * (u - 2) * (u - 3) * (u - 4) * (u - 5) - p,
173173 interval = (0.5, 5.5),
174- p_range = (-0.1, 0.1) ),
174+ p = 0.05 ),
175175
176176 # Function 2: Trigonometric with multiple roots
177177 (name = "sin(x) - 0.5x",
178178 f = (u, p) -> sin(u) - 0.5*u - p,
179179 interval = (-10.0, 10.0),
180- p_range = (-0.5, 0.5) ),
180+ p = 0.3 ),
181181
182182 # Function 3: Exponential function (sensitive near zero)
183183 (name = "exp(x) - 1 - x - x²/2",
184184 f = (u, p) -> exp(u) - 1 - u - u^2/2 - p,
185185 interval = (-2.0, 2.0),
186- p_range = (-0.01, 0.01) ),
186+ p = 0.005 ),
187187
188188 # Function 4: Rational function with pole
189189 (name = "1/(x-0.5) - 2",
190190 f = (u, p) -> 1/(u - 0.5) - 2 - p,
191191 interval = (0.6, 2.0),
192- p_range = (-0.1, 0.1) ),
192+ p = 0.05 ),
193193
194194 # Function 5: Logarithmic function
195195 (name = "log(x) - x + 2",
196196 f = (u, p) -> log(u) - u + 2 - p,
197197 interval = (0.1, 3.0),
198- p_range = (-0.1, 0.1) ),
198+ p = 0.05 ),
199199
200200 # Function 6: High oscillation function
201201 (name = "sin(20x) + 0.1x",
202202 f = (u, p) -> sin(20*u) + 0.1*u - p,
203203 interval = (-5.0, 5.0),
204- p_range = (-0.2, 0.2) ),
204+ p = 0.1 ),
205205
206206 # Function 7: Function with very flat region
207207 (name = "x³ - 2x² + x",
208208 f = (u, p) -> u^3 - 2*u^2 + u - p,
209209 interval = (-1.0, 2.0),
210- p_range = (-0.05, 0.05) ),
210+ p = 0.025 ),
211211
212212 # Function 8: Bessel-like function
213213 (name = "x·sin(1/x) - 0.1",
214214 f = (u, p) -> u * sin(1/u) - 0.1 - p,
215215 interval = (0.01, 1.0),
216- p_range = (-0.02, 0.02) ),
216+ p = 0.01 ),
217217]
218218
219219# Add SimpleNonlinearSolve algorithms
@@ -235,25 +235,32 @@ all_algorithms = [
235235
236236# Benchmark function for testing all algorithms on a given function
237237function benchmark_function(test_func, N_samples=10000)
238- Random.seed!(42)
239- ps = test_func.p_range[1] .+ (test_func.p_range[2] - test_func.p_range[1]) .* rand(N_samples)
240-
241238 println("\\n=== Testing: $(test_func.name) ===")
242239 println("Interval: $(test_func.interval)")
243- println("Parameter range : $(test_func.p_range )")
240+ println("Parameter: $(test_func.p )")
244241
245242 results = []
246243
247244 # Test Roots.jl baseline
248245 try
249- out_roots = zeros(N_samples)
246+ # Cache the function for Roots.jl
247+ roots_func = u -> test_func.f(u, test_func.p)
248+
249+ # Warmup run to exclude compilation time
250+ find_zero(roots_func, test_func.interval)
251+
252+ # Actual timing
250253 time_roots = @elapsed begin
251254 for i in 1:N_samples
252- out_roots[i] = find_zero(u -> test_func.f(u, ps[i]) , test_func.interval)
255+ root = find_zero(roots_func , test_func.interval)
253256 end
254257 end
255- error_roots = mean(abs.(test_func.f.(out_roots, ps)))
256- println("Roots.jl: $(round(time_roots*1000, digits=2)) ms, MAE: $(round(error_roots, sigdigits=3))")
258+
259+ # Calculate error using one solve
260+ final_root = find_zero(roots_func, test_func.interval)
261+ error_roots = abs(test_func.f(final_root, test_func.p))
262+
263+ println("Roots.jl: $(round(time_roots*1000, digits=2)) ms, Error: $(round(error_roots, sigdigits=3))")
257264 push!(results, (name="Roots.jl", time=time_roots, error=error_roots, success=true))
258265 catch e
259266 println("Roots.jl: FAILED - $e")
@@ -263,18 +270,30 @@ function benchmark_function(test_func, N_samples=10000)
263270 # Test all algorithms
264271 for alg_info in all_algorithms
265272 try
266- out = zeros(N_samples)
273+ # Warmup run to exclude compilation time
274+ prob_warmup = IntervalNonlinearProblem{false}(
275+ IntervalNonlinearFunction{false}(test_func.f),
276+ test_func.interval, test_func.p)
277+ solve(prob_warmup, alg_info.alg())
278+
279+ # Actual timing
267280 time_taken = @elapsed begin
268281 for i in 1:N_samples
269282 prob = IntervalNonlinearProblem{false}(
270- IntervalNonlinearFunction{false}((u, p) -> test_func.f(u, p) ),
271- test_func.interval, ps[i] )
283+ IntervalNonlinearFunction{false}(test_func.f),
284+ test_func.interval, test_func.p )
272285 sol = solve(prob, alg_info.alg())
273- out[i] = sol.u
274286 end
275287 end
276- error_val = mean(abs.(test_func.f.(out, ps)))
277- println("$(alg_info.name): $(round(time_taken*1000, digits=2)) ms, MAE: $(round(error_val, sigdigits=3))")
288+
289+ # Calculate error using one solve
290+ prob_final = IntervalNonlinearProblem{false}(
291+ IntervalNonlinearFunction{false}(test_func.f),
292+ test_func.interval, test_func.p)
293+ sol_final = solve(prob_final, alg_info.alg())
294+ error_val = abs(test_func.f(sol_final.u, test_func.p))
295+
296+ println("$(alg_info.name): $(round(time_taken*1000, digits=2)) ms, Error: $(round(error_val, sigdigits=3))")
278297 push!(results, (name=alg_info.name, time=time_taken, error=error_val, success=true))
279298 catch e
280299 println("$(alg_info.name): FAILED - $e")
288307# Run benchmarks on all test functions
289308all_results = []
290309for test_func in test_functions
291- results = benchmark_function(test_func, 5000 ) # Use smaller N for comprehensive testing
310+ results = benchmark_function(test_func, 10000 ) # Increased N since we're using fixed parameters
292311 push!(all_results, (func_name=test_func.name, results=results))
293312end
294313```
@@ -339,9 +358,10 @@ function print_summary_table(all_results)
339358
340359 println("\\n" * "="^80)
341360 println("Notes:")
342- println("- Times shown in milliseconds for 5000 function evaluations")
361+ println("- Times shown in milliseconds for 10000 function evaluations")
343362 println("- BNS = BracketingNonlinearSolve.jl, SNS = SimpleNonlinearSolve.jl")
344363 println("- FAIL indicates algorithm failed or took excessive time")
364+ println("- Compilation time excluded via warmup runs")
345365 println("="^80)
346366end
347367
@@ -355,7 +375,7 @@ Now let's examine the accuracy of each method:
355375```julia
356376function print_accuracy_table(all_results)
357377 println("\\n" * "="^80)
358- println("ACCURACY ANALYSIS (Mean Absolute Error)")
378+ println("ACCURACY ANALYSIS (Absolute Error)")
359379 println("="^80)
360380
361381 alg_names = unique([r.name for func_results in all_results for r in func_results.results])
0 commit comments