Skip to content

Commit 3011af1

Browse files
fix up compile times
1 parent 9df0c5e commit 3011af1

File tree

1 file changed

+45
-25
lines changed

1 file changed

+45
-25
lines changed

benchmarks/IntervalNonlinearProblem/suite.jmd

Lines changed: 45 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -171,49 +171,49 @@ test_functions = [
171171
(name = "Wilkinson-like polynomial",
172172
f = (u, p) -> (u - 1) * (u - 2) * (u - 3) * (u - 4) * (u - 5) - p,
173173
interval = (0.5, 5.5),
174-
p_range = (-0.1, 0.1)),
174+
p = 0.05),
175175

176176
# Function 2: Trigonometric with multiple roots
177177
(name = "sin(x) - 0.5x",
178178
f = (u, p) -> sin(u) - 0.5*u - p,
179179
interval = (-10.0, 10.0),
180-
p_range = (-0.5, 0.5)),
180+
p = 0.3),
181181

182182
# Function 3: Exponential function (sensitive near zero)
183183
(name = "exp(x) - 1 - x - x²/2",
184184
f = (u, p) -> exp(u) - 1 - u - u^2/2 - p,
185185
interval = (-2.0, 2.0),
186-
p_range = (-0.01, 0.01)),
186+
p = 0.005),
187187

188188
# Function 4: Rational function with pole
189189
(name = "1/(x-0.5) - 2",
190190
f = (u, p) -> 1/(u - 0.5) - 2 - p,
191191
interval = (0.6, 2.0),
192-
p_range = (-0.1, 0.1)),
192+
p = 0.05),
193193

194194
# Function 5: Logarithmic function
195195
(name = "log(x) - x + 2",
196196
f = (u, p) -> log(u) - u + 2 - p,
197197
interval = (0.1, 3.0),
198-
p_range = (-0.1, 0.1)),
198+
p = 0.05),
199199

200200
# Function 6: High oscillation function
201201
(name = "sin(20x) + 0.1x",
202202
f = (u, p) -> sin(20*u) + 0.1*u - p,
203203
interval = (-5.0, 5.0),
204-
p_range = (-0.2, 0.2)),
204+
p = 0.1),
205205

206206
# Function 7: Function with very flat region
207207
(name = "x³ - 2x² + x",
208208
f = (u, p) -> u^3 - 2*u^2 + u - p,
209209
interval = (-1.0, 2.0),
210-
p_range = (-0.05, 0.05)),
210+
p = 0.025),
211211

212212
# Function 8: Bessel-like function
213213
(name = "x·sin(1/x) - 0.1",
214214
f = (u, p) -> u * sin(1/u) - 0.1 - p,
215215
interval = (0.01, 1.0),
216-
p_range = (-0.02, 0.02)),
216+
p = 0.01),
217217
]
218218

219219
# Add SimpleNonlinearSolve algorithms
@@ -235,25 +235,32 @@ all_algorithms = [
235235

236236
# Benchmark function for testing all algorithms on a given function
237237
function benchmark_function(test_func, N_samples=10000)
238-
Random.seed!(42)
239-
ps = test_func.p_range[1] .+ (test_func.p_range[2] - test_func.p_range[1]) .* rand(N_samples)
240-
241238
println("\\n=== Testing: $(test_func.name) ===")
242239
println("Interval: $(test_func.interval)")
243-
println("Parameter range: $(test_func.p_range)")
240+
println("Parameter: $(test_func.p)")
244241

245242
results = []
246243

247244
# Test Roots.jl baseline
248245
try
249-
out_roots = zeros(N_samples)
246+
# Cache the function for Roots.jl
247+
roots_func = u -> test_func.f(u, test_func.p)
248+
249+
# Warmup run to exclude compilation time
250+
find_zero(roots_func, test_func.interval)
251+
252+
# Actual timing
250253
time_roots = @elapsed begin
251254
for i in 1:N_samples
252-
out_roots[i] = find_zero(u -> test_func.f(u, ps[i]), test_func.interval)
255+
root = find_zero(roots_func, test_func.interval)
253256
end
254257
end
255-
error_roots = mean(abs.(test_func.f.(out_roots, ps)))
256-
println("Roots.jl: $(round(time_roots*1000, digits=2)) ms, MAE: $(round(error_roots, sigdigits=3))")
258+
259+
# Calculate error using one solve
260+
final_root = find_zero(roots_func, test_func.interval)
261+
error_roots = abs(test_func.f(final_root, test_func.p))
262+
263+
println("Roots.jl: $(round(time_roots*1000, digits=2)) ms, Error: $(round(error_roots, sigdigits=3))")
257264
push!(results, (name="Roots.jl", time=time_roots, error=error_roots, success=true))
258265
catch e
259266
println("Roots.jl: FAILED - $e")
@@ -263,18 +270,30 @@ function benchmark_function(test_func, N_samples=10000)
263270
# Test all algorithms
264271
for alg_info in all_algorithms
265272
try
266-
out = zeros(N_samples)
273+
# Warmup run to exclude compilation time
274+
prob_warmup = IntervalNonlinearProblem{false}(
275+
IntervalNonlinearFunction{false}(test_func.f),
276+
test_func.interval, test_func.p)
277+
solve(prob_warmup, alg_info.alg())
278+
279+
# Actual timing
267280
time_taken = @elapsed begin
268281
for i in 1:N_samples
269282
prob = IntervalNonlinearProblem{false}(
270-
IntervalNonlinearFunction{false}((u, p) -> test_func.f(u, p)),
271-
test_func.interval, ps[i])
283+
IntervalNonlinearFunction{false}(test_func.f),
284+
test_func.interval, test_func.p)
272285
sol = solve(prob, alg_info.alg())
273-
out[i] = sol.u
274286
end
275287
end
276-
error_val = mean(abs.(test_func.f.(out, ps)))
277-
println("$(alg_info.name): $(round(time_taken*1000, digits=2)) ms, MAE: $(round(error_val, sigdigits=3))")
288+
289+
# Calculate error using one solve
290+
prob_final = IntervalNonlinearProblem{false}(
291+
IntervalNonlinearFunction{false}(test_func.f),
292+
test_func.interval, test_func.p)
293+
sol_final = solve(prob_final, alg_info.alg())
294+
error_val = abs(test_func.f(sol_final.u, test_func.p))
295+
296+
println("$(alg_info.name): $(round(time_taken*1000, digits=2)) ms, Error: $(round(error_val, sigdigits=3))")
278297
push!(results, (name=alg_info.name, time=time_taken, error=error_val, success=true))
279298
catch e
280299
println("$(alg_info.name): FAILED - $e")
@@ -288,7 +307,7 @@ end
288307
# Run benchmarks on all test functions
289308
all_results = []
290309
for test_func in test_functions
291-
results = benchmark_function(test_func, 5000) # Use smaller N for comprehensive testing
310+
results = benchmark_function(test_func, 10000) # Increased N since we're using fixed parameters
292311
push!(all_results, (func_name=test_func.name, results=results))
293312
end
294313
```
@@ -339,9 +358,10 @@ function print_summary_table(all_results)
339358

340359
println("\\n" * "="^80)
341360
println("Notes:")
342-
println("- Times shown in milliseconds for 5000 function evaluations")
361+
println("- Times shown in milliseconds for 10000 function evaluations")
343362
println("- BNS = BracketingNonlinearSolve.jl, SNS = SimpleNonlinearSolve.jl")
344363
println("- FAIL indicates algorithm failed or took excessive time")
364+
println("- Compilation time excluded via warmup runs")
345365
println("="^80)
346366
end
347367

@@ -355,7 +375,7 @@ Now let's examine the accuracy of each method:
355375
```julia
356376
function print_accuracy_table(all_results)
357377
println("\\n" * "="^80)
358-
println("ACCURACY ANALYSIS (Mean Absolute Error)")
378+
println("ACCURACY ANALYSIS (Absolute Error)")
359379
println("="^80)
360380

361381
alg_names = unique([r.name for func_results in all_results for r in func_results.results])

0 commit comments

Comments
 (0)