@@ -29,11 +29,11 @@ The test cases are:
29297. **Matrix ODE with Variable Rate Jump**: Solved with `Tsit5`.
30308. **Complex ODE with Variable Rate Jump**: Solved with `Tsit5`.
3131
32- For visualization, we solve one trajectory per test case with 2 jumps (2x2 matrix for Test 7) . For benchmarking, we vary jumps from 1 to 20 (2x2 to 10x10 for Test 7) , running 100 trajectories per configuration.
32+ For visualization, we solve one trajectory per test case with 2 jumps. For benchmarking, we vary jumps from 1 to 20, running 100 trajectories per configuration.
3333
3434# Benchmark and Visualization Setup
3535
36- We define factories for each test case to create problems with a variable number of jumps (or matrix size for Test 7) .
36+ We define factories for each test case to create problems with a variable number of jumps.
3737
3838```julia
3939algorithms = Tuple{Any, Any, String, String}[
@@ -45,18 +45,14 @@ algorithms = Tuple{Any, Any, String, String}[
4545 (VR_FRM(), Rosenbrock23(), "VR_FRM", "Test 1 Rosenbrock23 (autodiff, VR_FRM)"),
4646 (VR_Direct(), SRIW1(), "VR_Direct", "Test 2 SRIW1 (VR_Direct)"),
4747 (VR_FRM(), SRIW1(), "VR_FRM", "Test 2 SRIW1 (VR_FRM)"),
48- (VR_Direct(), SRA1(), "VR_Direct", "Test 3 SRA1 (VR_Direct)"),
49- (VR_FRM(), SRA1(), "VR_FRM", "Test 3 SRA1 (VR_FRM)"),
50- (VR_Direct(), Tsit5(), "VR_Direct", "Test 4 Tsit5 (VR_Direct, ConstantRateJump)"),
51- (VR_FRM(), Tsit5(), "VR_FRM", "Test 4 Tsit5 (VR_FRM, ConstantRateJump)"),
52- (VR_Direct(), Tsit5(), "VR_Direct", "Test 5 Tsit5 (VR_Direct)"),
53- (VR_FRM(), Tsit5(), "VR_FRM", "Test 5 Tsit5 (VR_FRM)"),
54- (VR_Direct(), SRIW1(), "VR_Direct", "Test 6 SRIW1 (VR_Direct)"),
55- (VR_FRM(), SRIW1(), "VR_FRM", "Test 6 SRIW1 (VR_FRM)"),
56- (VR_Direct(), Tsit5(), "VR_Direct", "Test 7 Tsit5 (VR_Direct)"),
57- (VR_FRM(), Tsit5(), "VR_FRM", "Test 7 Tsit5 (VR_FRM)"),
58- (VR_Direct(), Tsit5(), "VR_Direct", "Test 8 Tsit5 (VR_Direct)"),
59- (VR_FRM(), Tsit5(), "VR_FRM", "Test 8 Tsit5 (VR_FRM)"),
48+ (VR_Direct(), Tsit5(), "VR_Direct", "Test 3 Tsit5 (VR_Direct, ConstantRateJump)"),
49+ (VR_FRM(), Tsit5(), "VR_FRM", "Test 3 Tsit5 (VR_FRM, ConstantRateJump)"),
50+ (VR_Direct(), Tsit5(), "VR_Direct", "Test 4 Tsit5 (VR_Direct)"),
51+ (VR_FRM(), Tsit5(), "VR_FRM", "Test 4 Tsit5 (VR_FRM)"),
52+ (VR_Direct(), SRIW1(), "VR_Direct", "Test 5 SRIW1 (VR_Direct)"),
53+ (VR_FRM(), SRIW1(), "VR_FRM", "Test 5 SRIW1 (VR_FRM)"),
54+ (VR_Direct(), Tsit5(), "VR_Direct", "Test 6 Tsit5 (VR_Direct)"),
55+ (VR_FRM(), Tsit5(), "VR_FRM", "Test 6 Tsit5 (VR_FRM)"),
6056]
6157
6258function create_test1_problem(num_jumps, vr_aggregator, solver)
@@ -79,19 +75,6 @@ function create_test2_problem(num_jumps, vr_aggregator, solver)
7975end
8076
8177function create_test3_problem(num_jumps, vr_aggregator, solver)
82- ff = (du, u, p, t) -> (du .= p == 0 ? 1.01u : 2.01u)
83- gg = (du, u, p, t) -> begin
84- du[1, 1] = 0.3u[1]; du[1, 2] = 0.6u[1]
85- du[2, 1] = 1.2u[1]; du[2, 2] = 0.2u[2]
86- end
87- prob = SDEProblem(ff, gg, ones(2), (0.0, 1.0), 0, noise_rate_prototype=zeros(2, 2))
88- jumps = [VariableRateJump((u, p, t) -> u[1] * 1.0, (integrator) -> (integrator.p = 1)) for _ in 1:num_jumps]
89- jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator=vr_aggregator, rng=rng)
90- ensemble_prob = EnsembleProblem(prob)
91- return ensemble_prob, jump_prob
92- end
93-
94- function create_test4_problem(num_jumps, vr_aggregator, solver)
9578 f2 = (du, u, p, t) -> (du[1] = u[1])
9679 prob = ODEProblem(f2, [0.2], (0.0, 10.0))
9780 jumps = [ConstantRateJump((u, p, t) -> 2, (integrator) -> (integrator.u[1] = integrator.u[1] / 2)) for _ in 1:num_jumps]
@@ -100,7 +83,7 @@ function create_test4_problem(num_jumps, vr_aggregator, solver)
10083 return ensemble_prob, jump_prob
10184end
10285
103- function create_test5_problem (num_jumps, vr_aggregator, solver)
86+ function create_test4_problem (num_jumps, vr_aggregator, solver)
10487 f2 = (du, u, p, t) -> (du[1] = u[1])
10588 prob = ODEProblem(f2, [0.2], (0.0, 10.0))
10689 jumps = [VariableRateJump((u, p, t) -> u[1], (integrator) -> (integrator.u[1] = integrator.u[1] / 2); interp_points=100) for _ in 1:num_jumps]
@@ -109,7 +92,7 @@ function create_test5_problem(num_jumps, vr_aggregator, solver)
10992 return ensemble_prob, jump_prob
11093end
11194
112- function create_test6_problem (num_jumps, vr_aggregator, solver)
95+ function create_test5_problem (num_jumps, vr_aggregator, solver)
11396 f2 = (du, u, p, t) -> (du[1] = u[1])
11497 g2 = (du, u, p, t) -> (du[1] = u[1])
11598 prob = SDEProblem(f2, g2, [0.2], (0.0, 10.0))
@@ -119,19 +102,7 @@ function create_test6_problem(num_jumps, vr_aggregator, solver)
119102 return ensemble_prob, jump_prob
120103end
121104
122- function create_test7_problem(num_jumps, vr_aggregator, solver, matrix_size=2)
123- f3 = (du, u, p, t) -> (du .= u)
124- u0 = ones(matrix_size, matrix_size)
125- prob = ODEProblem(f3, u0, (0.0, 1.0))
126- rate3 = (u, p, t) -> sum(u[1, :])
127- affect3! = (integrator) -> (integrator.u .= range(0.25, 1.0, length=matrix_size^2))
128- jumps = [VariableRateJump(rate3, affect3!) for _ in 1:num_jumps]
129- jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator=vr_aggregator, rng=rng)
130- ensemble_prob = EnsembleProblem(prob)
131- return ensemble_prob, jump_prob
132- end
133-
134- function create_test8_problem(num_jumps, vr_aggregator, solver)
105+ function create_test6_problem(num_jumps, vr_aggregator, solver)
135106 f4 = (dx, x, p, t) -> (dx[1] = x[1])
136107 rate4 = (x, p, t) -> t
137108 affect4! = (integrator) -> (integrator.u[1] = integrator.u[1] * 0.5)
@@ -145,19 +116,15 @@ end
145116
146117# Solution Visualization
147118
148- We solve one trajectory for each test case with 2 jumps (2x2 matrix for Test 7) using `VR_Direct` and plot the state variables vs. time.
119+ We solve one trajectory for each test case with 2 jumps using `VR_Direct` and plot the state variables vs. time.
149120
150121```julia
151122let figs = []
152- for test_num in 1:8
123+ for test_num in 1:6
153124 # Select a representative solver for each test
154- algo, stepper = if test_num == 1
155- VR_Direct(), Tsit5()
156- elseif test_num == 2 || test_num == 6
125+ algo, stepper = if test_num == 2 || test_num == 5
157126 VR_Direct(), SRIW1()
158- elseif test_num == 3
159- VR_Direct(), SRA1()
160- elseif test_num in [4, 5, 7, 8]
127+ else
161128 VR_Direct(), Tsit5()
162129 end
163130 label = "Test $test_num"
@@ -175,29 +142,16 @@ let figs = []
175142 create_test5_problem(2, algo, stepper)
176143 elseif test_num == 6
177144 create_test6_problem(2, algo, stepper)
178- elseif test_num == 7
179- create_test7_problem(2, algo, stepper, 2)
180- elseif test_num == 8
181- create_test8_problem(2, algo, stepper)
182145 end
183-
184- # Solve one trajectory
185- solver_kwargs = test_num == 3 ? (dt=1.0,) : ()
186- try
187- sol = solve(jump_prob, stepper; saveat=0.01, solver_kwargs...)
188146
147+ try
148+ sol = solve(jump_prob, stepper; saveat=0.01)
149+
189150 # Plot solution
190151 fig = plot(title="Test $test_num: Solution Trajectory", xlabel="Time", ylabel="State")
191- if test_num == 7
192- # For matrix ODE, plot sum of elements
193- plot!(sol.t, [sum(sol.u[i]) for i in 1:length(sol.u)], label="Sum of Matrix Elements")
194- elseif test_num == 8
152+ if test_num == 6
195153 # For complex ODE, plot real part
196154 plot!(sol.t, real.(sol[1,:]), label="Real Part")
197- elseif test_num == 3
198- # For 2D SDE, plot both components
199- plot!(sol.t, sol[1,:], label="u[1]")
200- plot!(sol.t, sol[2,:], label="u[2]")
201155 else
202156 # For scalar problems, plot state
203157 plot!(sol.t, sol[1,:], label="u[1]")
@@ -213,11 +167,10 @@ end
213167
214168# Benchmark Execution
215169
216- We benchmark each test case for 1 to 20 jumps (2x2 to 10x10 for Test 7) , running 100 trajectories. Errors are logged to diagnose failures.
170+ We benchmark each test case for 1 to 20 jumps, running 100 trajectories. Errors are logged to diagnose failures.
217171
218172```julia
219173num_jumps_range = append!([1], 5:5:20)
220- matrix_sizes = [2, 4, 6, 8, 10]
221174bs = Vector{Vector{BenchmarkTools.Trial}}()
222175errors = Dict{String, Vector{String}}()
223176
@@ -227,38 +180,36 @@ for (algo, stepper, agg_name, label) in algorithms
227180 errors[label] = String[]
228181 _bs = bs[end]
229182 test_num = parse(Int, match(r"Test (\d+)", label).captures[1])
230- is_matrix_test = test_num == 7
231- range_var = is_matrix_test ? matrix_sizes : num_jumps_range
183+ range_var = num_jumps_range
232184 for (i, var) in enumerate(range_var)
233185 if test_num == 1
234- ensemble_prob, jump_prob = create_test1_problem(is_matrix_test ? 2 : var, algo, stepper)
186+ ensemble_prob, jump_prob = create_test1_problem(var, algo, stepper)
235187 elseif test_num == 2
236- ensemble_prob, jump_prob = create_test2_problem(is_matrix_test ? 2 : var, algo, stepper)
188+ ensemble_prob, jump_prob = create_test2_problem(var, algo, stepper)
237189 elseif test_num == 3
238- ensemble_prob, jump_prob = create_test3_problem(is_matrix_test ? 2 : var, algo, stepper)
190+ ensemble_prob, jump_prob = create_test3_problem(var, algo, stepper)
239191 elseif test_num == 4
240- ensemble_prob, jump_prob = create_test4_problem(is_matrix_test ? 2 : var, algo, stepper)
192+ ensemble_prob, jump_prob = create_test4_problem(var, algo, stepper)
241193 elseif test_num == 5
242- ensemble_prob, jump_prob = create_test5_problem(is_matrix_test ? 2 : var, algo, stepper)
194+ ensemble_prob, jump_prob = create_test5_problem(var, algo, stepper)
243195 elseif test_num == 6
244- ensemble_prob, jump_prob = create_test6_problem(is_matrix_test ? 2 : var, algo, stepper)
245- elseif test_num == 7
246- ensemble_prob, jump_prob = create_test7_problem(2, algo, stepper, var)
247- elseif test_num == 8
248- ensemble_prob, jump_prob = create_test8_problem(is_matrix_test ? 2 : var, algo, stepper)
196+ ensemble_prob, jump_prob = create_test6_problem(var, algo, stepper)
249197 end
250- solver_kwargs = test_num == 3 ? (dt=1.0,) : ""
251198 trial = try
252- @benchmark solve($ensemble_prob, $stepper, EnsembleSerial(), trajectories=100, jump_prob=$jump_prob; $solver_kwargs...) samples=50 evals=1 seconds=10
199+ @benchmark(
200+ solve($jump_prob, $stepper),
201+ samples=50,
202+ evals=1,
203+ seconds=10
204+ )
253205 catch e
254- push!(errors[label], "Error at $(is_matrix_test ? "Matrix Size" : " Num Jumps") = $var: $(sprint(showerror, e))")
206+ push!(errors[label], "Error at Num Jumps = $var: $(sprint(showerror, e))")
255207 BenchmarkTools.Trial(BenchmarkTools.Parameters(samples=50, evals=1, seconds=10))
256208 end
257209 push!(_bs, trial)
258- if (var == 1 || var % (is_matrix_test ? 2 : 5) == 0)
259- median_time = length(trial) > 0 ? "$(BenchmarkTools.prettytime(median(trial.times)))" : "nan"
260- println("algo=$label, $(is_matrix_test ? "Matrix Size" : "Num Jumps") = $var, length = $(length(trial.times)), median time = $median_time")
261- end
210+
211+ median_time = length(trial) > 0 ? "$(BenchmarkTools.prettytime(median(trial.times)))" : "nan"
212+ println("algo=$label, Num Jumps = $var, length = $(length(trial.times)), median time = $median_time")
262213 end
263214end
264215
@@ -279,7 +230,7 @@ We plot the median execution times for each test case, comparing `VR_Direct` and
279230
280231```julia
281232let figs = []
282- for test_num in 1:8
233+ for test_num in 1:6
283234 test_algorithms = filter(a -> parse(Int, match(r"Test (\d+)", a[4]).captures[1]) == test_num, algorithms)
284235 is_matrix_test = test_num == 7
285236 range_var = is_matrix_test ? matrix_sizes : num_jumps_range
@@ -307,6 +258,6 @@ let figs = []
307258 end
308259 push!(figs, fig)
309260 end
310- plot(figs..., layout=(4, 2 ), format=fmt, size=(width_px, 4 *height_px/2))
261+ plot(figs..., layout=(6, 1 ), format=fmt, size=(width_px, 8 *height_px/2))
311262end
312263```
0 commit comments