Skip to content

Commit d6ec431

Browse files
benchmark added
1 parent 966bed3 commit d6ec431

File tree

2 files changed

+315
-0
lines changed

2 files changed

+315
-0
lines changed

benchmarks/Jumps/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
33
Catalyst = "479239e8-5488-4da2-87a7-35f2df7eef83"
44
Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
55
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
6+
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
67
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
78
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
89
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
910
JumpProblemLibrary = "faf0f6d7-8cee-47cb-b27c-1eb80cef534e"
1011
JumpProcesses = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5"
1112
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
13+
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
1214
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
1315
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
1416
PiecewiseDeterministicMarkovProcesses = "86206cdf-4603-54e0-bd58-22a2dcbf57aa"
@@ -20,6 +22,7 @@ SciMLBenchmarks = "31c91b34-3c75-11e9-0341-95557aab0344"
2022
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
2123
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2224
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
25+
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
2326
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
2427

2528
[compat]
Lines changed: 312 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,312 @@
1+
---
2+
title: Benchmarking Variable Rate Aggregator
3+
author: Siva Sathyaseelan D N, Chris Rackauckas, Samuel Isaacson
4+
weave_options:
5+
fig_ext : ".png"
6+
---
7+
8+
```julia
9+
using DiffEqBase, JumpProcesses, OrdinaryDiffEq, StochasticDiffEq
10+
using Random, LinearSolve, StableRNGs, BenchmarkTools, Plots, LinearAlgebra
11+
fmt = :png
12+
width_px, height_px = default(:size)
13+
rng = StableRNG(12345)
14+
```
15+
16+
# Introduction
17+
18+
This document benchmarks the performance of variable rate jumps in `JumpProcesses.jl` and visualizes example solution trajectories for the test cases from `variable_rate_test.jl`. The benchmark compares `VRDirectCB` and `VRFRMODE` aggregators, while the visualization shows state variables vs. time to verify simulation behavior.
19+
20+
**Note**: If you encounter a precompilation error due to method overwriting in `JumpProcesses.jl`, add `__precompile__(false)` to `/home/siva/Desktop/julia/JumpProcesses.jl/src/JumpProcesses.jl` and clear the compilation cache (`rm -rf ~/.julia/compiled`).
21+
22+
The test cases are:
23+
1. **Scalar ODE with Variable Rate Jumps**: Solved with `Tsit5` and `Rosenbrock23` (with/without autodiff).
24+
2. **Scalar SDE with Variable Rate Jumps**: Solved with `SRIW1`.
25+
3. **SDE with Parameter-Switching Jump**: Solved with `SRA1`.
26+
4. **ODE with Constant Rate Jump**: Solved with `Tsit5`.
27+
5. **ODE with Variable Rate Jumps (Alternative Rate)**: Solved with `Tsit5`.
28+
6. **SDE with Variable Rate Jumps (Alternative Rate)**: Solved with `SRIW1`.
29+
7. **Matrix ODE with Variable Rate Jump**: Solved with `Tsit5`.
30+
8. **Complex ODE with Variable Rate Jump**: Solved with `Tsit5`.
31+
32+
For visualization, we solve one trajectory per test case with 2 jumps (2x2 matrix for Test 7). For benchmarking, we vary jumps from 1 to 20 (2x2 to 10x10 for Test 7), running 100 trajectories per configuration.
33+
34+
# Benchmark and Visualization Setup
35+
36+
We define factories for each test case to create problems with a variable number of jumps (or matrix size for Test 7).
37+
38+
```julia
39+
algorithms = Tuple{Any, Any, String, String}[
40+
(VRDirectCB(), Tsit5(), "VRDirectCB", "Test 1 Tsit5 (VRDirectCB)"),
41+
(VRFRMODE(), Tsit5(), "VRFRMODE", "Test 1 Tsit5 (VRFRMODE)"),
42+
(VRDirectCB(), Rosenbrock23(autodiff=false), "VRDirectCB", "Test 1 Rosenbrock23 (no autodiff, VRDirectCB)"),
43+
(VRFRMODE(), Rosenbrock23(autodiff=false), "VRFRMODE", "Test 1 Rosenbrock23 (no autodiff, VRFRMODE)"),
44+
(VRDirectCB(), Rosenbrock23(), "VRDirectCB", "Test 1 Rosenbrock23 (autodiff, VRDirectCB)"),
45+
(VRFRMODE(), Rosenbrock23(), "VRFRMODE", "Test 1 Rosenbrock23 (autodiff, VRFRMODE)"),
46+
(VRDirectCB(), SRIW1(), "VRDirectCB", "Test 2 SRIW1 (VRDirectCB)"),
47+
(VRFRMODE(), SRIW1(), "VRFRMODE", "Test 2 SRIW1 (VRFRMODE)"),
48+
(VRDirectCB(), SRA1(), "VRDirectCB", "Test 3 SRA1 (VRDirectCB)"),
49+
(VRFRMODE(), SRA1(), "VRFRMODE", "Test 3 SRA1 (VRFRMODE)"),
50+
(VRDirectCB(), Tsit5(), "VRDirectCB", "Test 4 Tsit5 (VRDirectCB, ConstantRateJump)"),
51+
(VRFRMODE(), Tsit5(), "VRFRMODE", "Test 4 Tsit5 (VRFRMODE, ConstantRateJump)"),
52+
(VRDirectCB(), Tsit5(), "VRDirectCB", "Test 5 Tsit5 (VRDirectCB)"),
53+
(VRFRMODE(), Tsit5(), "VRFRMODE", "Test 5 Tsit5 (VRFRMODE)"),
54+
(VRDirectCB(), SRIW1(), "VRDirectCB", "Test 6 SRIW1 (VRDirectCB)"),
55+
(VRFRMODE(), SRIW1(), "VRFRMODE", "Test 6 SRIW1 (VRFRMODE)"),
56+
(VRDirectCB(), Tsit5(), "VRDirectCB", "Test 7 Tsit5 (VRDirectCB)"),
57+
(VRFRMODE(), Tsit5(), "VRFRMODE", "Test 7 Tsit5 (VRFRMODE)"),
58+
(VRDirectCB(), Tsit5(), "VRDirectCB", "Test 8 Tsit5 (VRDirectCB)"),
59+
(VRFRMODE(), Tsit5(), "VRFRMODE", "Test 8 Tsit5 (VRFRMODE)"),
60+
]
61+
62+
function create_test1_problem(num_jumps, vr_aggregator, solver)
63+
f = (du, u, p, t) -> (du[1] = u[1])
64+
prob = ODEProblem(f, [0.2], (0.0, 10.0))
65+
jumps = [VariableRateJump((u, p, t) -> u[1], (integrator) -> (integrator.u[1] = integrator.u[1] / 2); interp_points=100) for _ in 1:num_jumps]
66+
jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator=vr_aggregator, rng=rng)
67+
ensemble_prob = EnsembleProblem(prob)
68+
return ensemble_prob, jump_prob
69+
end
70+
71+
function create_test2_problem(num_jumps, vr_aggregator, solver)
72+
f = (du, u, p, t) -> (du[1] = u[1])
73+
g = (du, u, p, t) -> (du[1] = u[1])
74+
prob = SDEProblem(f, g, [0.2], (0.0, 10.0))
75+
jumps = [VariableRateJump((u, p, t) -> u[1], (integrator) -> (integrator.u[1] = integrator.u[1] / 2); interp_points=100) for _ in 1:num_jumps]
76+
jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator=vr_aggregator, rng=rng)
77+
ensemble_prob = EnsembleProblem(prob)
78+
return ensemble_prob, jump_prob
79+
end
80+
81+
function create_test3_problem(num_jumps, vr_aggregator, solver)
82+
ff = (du, u, p, t) -> (du .= p == 0 ? 1.01u : 2.01u)
83+
gg = (du, u, p, t) -> begin
84+
du[1, 1] = 0.3u[1]; du[1, 2] = 0.6u[1]
85+
du[2, 1] = 1.2u[1]; du[2, 2] = 0.2u[2]
86+
end
87+
prob = SDEProblem(ff, gg, ones(2), (0.0, 1.0), 0, noise_rate_prototype=zeros(2, 2))
88+
jumps = [VariableRateJump((u, p, t) -> u[1] * 1.0, (integrator) -> (integrator.p = 1)) for _ in 1:num_jumps]
89+
jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator=vr_aggregator, rng=rng)
90+
ensemble_prob = EnsembleProblem(prob)
91+
return ensemble_prob, jump_prob
92+
end
93+
94+
function create_test4_problem(num_jumps, vr_aggregator, solver)
95+
f2 = (du, u, p, t) -> (du[1] = u[1])
96+
prob = ODEProblem(f2, [0.2], (0.0, 10.0))
97+
jumps = [ConstantRateJump((u, p, t) -> 2, (integrator) -> (integrator.u[1] = integrator.u[1] / 2)) for _ in 1:num_jumps]
98+
jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator=vr_aggregator, rng=rng)
99+
ensemble_prob = EnsembleProblem(prob)
100+
return ensemble_prob, jump_prob
101+
end
102+
103+
function create_test5_problem(num_jumps, vr_aggregator, solver)
104+
f2 = (du, u, p, t) -> (du[1] = u[1])
105+
prob = ODEProblem(f2, [0.2], (0.0, 10.0))
106+
jumps = [VariableRateJump((u, p, t) -> u[1], (integrator) -> (integrator.u[1] = integrator.u[1] / 2); interp_points=100) for _ in 1:num_jumps]
107+
jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator=vr_aggregator, rng=rng)
108+
ensemble_prob = EnsembleProblem(prob)
109+
return ensemble_prob, jump_prob
110+
end
111+
112+
function create_test6_problem(num_jumps, vr_aggregator, solver)
113+
f2 = (du, u, p, t) -> (du[1] = u[1])
114+
g2 = (du, u, p, t) -> (du[1] = u[1])
115+
prob = SDEProblem(f2, g2, [0.2], (0.0, 10.0))
116+
jumps = [VariableRateJump((u, p, t) -> u[1], (integrator) -> (integrator.u[1] = integrator.u[1] / 2); interp_points=100) for _ in 1:num_jumps]
117+
jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator=vr_aggregator, rng=rng)
118+
ensemble_prob = EnsembleProblem(prob)
119+
return ensemble_prob, jump_prob
120+
end
121+
122+
function create_test7_problem(num_jumps, vr_aggregator, solver, matrix_size=2)
123+
f3 = (du, u, p, t) -> (du .= u)
124+
u0 = ones(matrix_size, matrix_size)
125+
prob = ODEProblem(f3, u0, (0.0, 1.0))
126+
rate3 = (u, p, t) -> sum(u[1, :])
127+
affect3! = (integrator) -> (integrator.u .= range(0.25, 1.0, length=matrix_size^2))
128+
jumps = [VariableRateJump(rate3, affect3!) for _ in 1:num_jumps]
129+
jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator=vr_aggregator, rng=rng)
130+
ensemble_prob = EnsembleProblem(prob)
131+
return ensemble_prob, jump_prob
132+
end
133+
134+
function create_test8_problem(num_jumps, vr_aggregator, solver)
135+
f4 = (dx, x, p, t) -> (dx[1] = x[1])
136+
rate4 = (x, p, t) -> t
137+
affect4! = (integrator) -> (integrator.u[1] = integrator.u[1] * 0.5)
138+
prob = ODEProblem(f4, [1.0 + 0.0im], (0.0, 6.0))
139+
jumps = [VariableRateJump(rate4, affect4!) for _ in 1:num_jumps]
140+
jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator=vr_aggregator, rng=rng)
141+
ensemble_prob = EnsembleProblem(prob)
142+
return ensemble_prob, jump_prob
143+
end
144+
```
145+
146+
# Solution Visualization
147+
148+
We solve one trajectory for each test case with 2 jumps (2x2 matrix for Test 7) using `VRDirectCB` and plot the state variables vs. time.
149+
150+
```julia
151+
let figs = []
152+
for test_num in 1:8
153+
# Select a representative solver for each test
154+
algo, stepper = if test_num == 1
155+
VRDirectCB(), Tsit5()
156+
elseif test_num == 2 || test_num == 6
157+
VRDirectCB(), SRIW1()
158+
elseif test_num == 3
159+
VRDirectCB(), SRA1()
160+
elseif test_num in [4, 5, 7, 8]
161+
VRDirectCB(), Tsit5()
162+
end
163+
label = "Test $test_num"
164+
165+
# Create problem with 2 jumps (or 2x2 matrix)
166+
ensemble_prob, jump_prob = if test_num == 1
167+
create_test1_problem(2, algo, stepper)
168+
elseif test_num == 2
169+
create_test2_problem(2, algo, stepper)
170+
elseif test_num == 3
171+
create_test3_problem(2, algo, stepper)
172+
elseif test_num == 4
173+
create_test4_problem(2, algo, stepper)
174+
elseif test_num == 5
175+
create_test5_problem(2, algo, stepper)
176+
elseif test_num == 6
177+
create_test6_problem(2, algo, stepper)
178+
elseif test_num == 7
179+
create_test7_problem(2, algo, stepper, 2)
180+
elseif test_num == 8
181+
create_test8_problem(2, algo, stepper)
182+
end
183+
184+
# Solve one trajectory
185+
solver_kwargs = test_num == 3 ? (dt=1.0,) : ()
186+
try
187+
sol = solve(jump_prob, stepper; saveat=0.01, solver_kwargs...)
188+
189+
# Plot solution
190+
fig = plot(title="Test $test_num: Solution Trajectory", xlabel="Time", ylabel="State")
191+
if test_num == 7
192+
# For matrix ODE, plot sum of elements
193+
plot!(sol.t, [sum(sol.u[i]) for i in 1:length(sol.u)], label="Sum of Matrix Elements")
194+
elseif test_num == 8
195+
# For complex ODE, plot real part
196+
plot!(sol.t, real.(sol[1,:]), label="Real Part")
197+
elseif test_num == 3
198+
# For 2D SDE, plot both components
199+
plot!(sol.t, sol[1,:], label="u[1]")
200+
plot!(sol.t, sol[2,:], label="u[2]")
201+
else
202+
# For scalar problems, plot state
203+
plot!(sol.t, sol[1,:], label="u[1]")
204+
end
205+
push!(figs, fig)
206+
catch e
207+
@warn "Failed to solve Test $test_num: $(sprint(showerror, e))"
208+
end
209+
end
210+
plot(figs..., layout=(4, 2), format=fmt, size=(width_px, 4*height_px/2))
211+
end
212+
```
213+
214+
# Benchmark Execution
215+
216+
We benchmark each test case for 1 to 20 jumps (2x2 to 10x10 for Test 7), running 100 trajectories. Errors are logged to diagnose failures.
217+
218+
```julia
219+
num_jumps_range = append!([1], 5:5:20)
220+
matrix_sizes = [2, 4, 6, 8, 10]
221+
bs = Vector{Vector{BenchmarkTools.Trial}}()
222+
errors = Dict{String, Vector{String}}()
223+
224+
for (algo, stepper, agg_name, label) in algorithms
225+
@info label
226+
push!(bs, Vector{BenchmarkTools.Trial}())
227+
errors[label] = String[]
228+
_bs = bs[end]
229+
test_num = parse(Int, match(r"Test (\d+)", label).captures[1])
230+
is_matrix_test = test_num == 7
231+
range_var = is_matrix_test ? matrix_sizes : num_jumps_range
232+
for (i, var) in enumerate(range_var)
233+
if test_num == 1
234+
ensemble_prob, jump_prob = create_test1_problem(is_matrix_test ? 2 : var, algo, stepper)
235+
elseif test_num == 2
236+
ensemble_prob, jump_prob = create_test2_problem(is_matrix_test ? 2 : var, algo, stepper)
237+
elseif test_num == 3
238+
ensemble_prob, jump_prob = create_test3_problem(is_matrix_test ? 2 : var, algo, stepper)
239+
elseif test_num == 4
240+
ensemble_prob, jump_prob = create_test4_problem(is_matrix_test ? 2 : var, algo, stepper)
241+
elseif test_num == 5
242+
ensemble_prob, jump_prob = create_test5_problem(is_matrix_test ? 2 : var, algo, stepper)
243+
elseif test_num == 6
244+
ensemble_prob, jump_prob = create_test6_problem(is_matrix_test ? 2 : var, algo, stepper)
245+
elseif test_num == 7
246+
ensemble_prob, jump_prob = create_test7_problem(2, algo, stepper, var)
247+
elseif test_num == 8
248+
ensemble_prob, jump_prob = create_test8_problem(is_matrix_test ? 2 : var, algo, stepper)
249+
end
250+
solver_kwargs = test_num == 3 ? (dt=1.0,) : ""
251+
trial = try
252+
@benchmark solve($ensemble_prob, $stepper, EnsembleSerial(), trajectories=100, jump_prob=$jump_prob; $solver_kwargs...) samples=50 evals=1 seconds=10
253+
catch e
254+
push!(errors[label], "Error at $(is_matrix_test ? "Matrix Size" : "Num Jumps") = $var: $(sprint(showerror, e))")
255+
BenchmarkTools.Trial(BenchmarkTools.Parameters(samples=50, evals=1, seconds=10))
256+
end
257+
push!(_bs, trial)
258+
if (var == 1 || var % (is_matrix_test ? 2 : 5) == 0)
259+
median_time = length(trial) > 0 ? "$(BenchmarkTools.prettytime(median(trial.times)))" : "nan"
260+
println("algo=$label, $(is_matrix_test ? "Matrix Size" : "Num Jumps") = $var, length = $(length(trial.times)), median time = $median_time")
261+
end
262+
end
263+
end
264+
265+
# Log errors
266+
for (label, err_list) in errors
267+
if !isempty(err_list)
268+
@warn "Errors for $label:"
269+
for err in err_list
270+
println(err)
271+
end
272+
end
273+
end
274+
```
275+
276+
# Benchmark Results
277+
278+
We plot the median execution times for each test case, comparing `VRDirectCB` and `VRFRMODE`.
279+
280+
```julia
281+
let figs = []
282+
for test_num in 1:8
283+
test_algorithms = filter(a -> parse(Int, match(r"Test (\d+)", a[4]).captures[1]) == test_num, algorithms)
284+
is_matrix_test = test_num == 7
285+
range_var = is_matrix_test ? matrix_sizes : num_jumps_range
286+
fig = plot(
287+
yscale=:log10,
288+
xlabel=is_matrix_test ? "Matrix Size" : "Number of Jumps",
289+
ylabel="Time (ns)",
290+
legend_position=:outertopright,
291+
title="Test $test_num: Simulations, 50 samples"
292+
)
293+
for (i, (algo, stepper, agg_name, label)) in enumerate(test_algorithms)
294+
algo_idx = findfirst(a -> a[4] == label, algorithms)
295+
_bs, _vars = [], []
296+
for (j, b) in enumerate(bs[algo_idx])
297+
if length(b) == 50
298+
push!(_bs, median(b.times))
299+
push!(_vars, range_var[j])
300+
end
301+
end
302+
if !isempty(_bs)
303+
plot!(_vars, _bs, label=label)
304+
else
305+
@warn "No valid data for $label in Test $test_num"
306+
end
307+
end
308+
push!(figs, fig)
309+
end
310+
plot(figs..., layout=(4, 2), format=fmt, size=(width_px, 4*height_px/2))
311+
end
312+
```

0 commit comments

Comments
 (0)