@@ -22,14 +22,10 @@ state variables vs. time to verify simulation behavior.
2222
2323The test cases are:
24241. **Scalar ODE with Variable Rate Jumps**: Solved with `Tsit5` and `Rosenbrock23` (with/without autodiff).
25- 2. **Scalar SDE with Variable Rate Jumps**: Solved with `SRIW1`.
26- 3. **ODE with Constant Rate Jump**: Solved with `Tsit5`.
27- 4. **ODE with Variable Rate Jumps (Alternative Rate)**: Solved with `Tsit5`.
28- 5. **SDE with Variable Rate Jumps (Alternative Rate)**: Solved with `SRIW1`.
29- 6. **Complex ODE with Variable Rate Jump**: Solved with `Tsit5`.
25+ 2. **Complex ODE with Variable Rate Jump**: Solved with `Tsit5`.
26+ 3. **DNA Gene Model**: ODE with 10 variable rate jumps from the RSSA paper, solved with Tsit5.
3027
31- For visualization, we solve one trajectory per test case with 2 jumps. For benchmarking,
32- we vary jumps from 1 to 20, running 100 trajectories per configuration.
28+ For visualization, we solve one trajectory per test case with 2 jumps. For benchmarking, we vary jumps from 1 to 20.
3329
3430# Benchmark and Visualization Setup
3531
@@ -46,135 +42,98 @@ algorithms = Tuple{Any, Any, String, String}[
4642 (VR_Direct(), Rosenbrock23(), "VR_Direct", "Test 1 Rosenbrock23 (autodiff, VR_Direct)"),
4743 (VR_DirectFW(), Rosenbrock23(), "VR_DirectFW", "Test 1 Rosenbrock23 (autodiff, VR_DirectFW)"),
4844 (VR_FRM(), Rosenbrock23(), "VR_FRM", "Test 1 Rosenbrock23 (autodiff, VR_FRM)"),
49- (VR_Direct(), SRIW1(), "VR_Direct", "Test 2 SRIW1 (VR_Direct)"),
50- (VR_DirectFW(), SRIW1(), "VR_DirectFW", "Test 2 SRIW1 (VR_DirectFW)"),
51- (VR_FRM(), SRIW1(), "VR_FRM", "Test 2 SRIW1 (VR_FRM)"),
52- (VR_Direct(), Tsit5(), "VR_Direct", "Test 3 Tsit5 (VR_Direct, ConstantRateJump)"),
53- (VR_DirectFW(), Tsit5(), "VR_DirectFW", "Test 3 Tsit5 (VR_DirectFW, ConstantRateJump)"),
54- (VR_FRM(), Tsit5(), "VR_FRM", "Test 3 Tsit5 (VR_FRM, ConstantRateJump)"),
55- (VR_Direct(), Tsit5(), "VR_Direct", "Test 4 Tsit5 (VR_Direct)"),
56- (VR_DirectFW(), Tsit5(), "VR_DirectFW", "Test 4 Tsit5 (VR_DirectFW)"),
57- (VR_FRM(), Tsit5(), "VR_FRM", "Test 4 Tsit5 (VR_FRM)"),
58- (VR_Direct(), SRIW1(), "VR_Direct", "Test 5 SRIW1 (VR_Direct)"),
59- (VR_DirectFW(), SRIW1(), "VR_DirectFW", "Test 5 SRIW1 (VR_DirectFW)"),
60- (VR_FRM(), SRIW1(), "VR_FRM", "Test 5 SRIW1 (VR_FRM)"),
61- (VR_Direct(), Tsit5(), "VR_Direct", "Test 6 Tsit5 (VR_Direct)"),
62- (VR_DirectFW(), Tsit5(), "VR_DirectFW", "Test 6 Tsit5 (VR_DirectFW)"),
63- (VR_FRM(), Tsit5(), "VR_FRM", "Test 6 Tsit5 (VR_FRM)"),
45+ (VR_Direct(), Tsit5(), "VR_Direct", "Test 2 Tsit5 (VR_Direct)"),
46+ (VR_DirectFW(), Tsit5(), "VR_DirectFW", "Test 2 Tsit5 (VR_DirectFW)"),
47+ (VR_FRM(), Tsit5(), "VR_FRM", "Test 2 Tsit5 (VR_FRM)"),
48+ (VR_Direct(), Tsit5(), "VR_Direct", "Test 3 Tsit5 (VR_Direct, DNA Model)"),
49+ (VR_DirectFW(), Tsit5(), "VR_DirectFW", "Test 3 Tsit5 (VR_DirectFW, DNA Model)"),
50+ (VR_FRM(), Tsit5(), "VR_FRM", "Test 3 Tsit5 (VR_FRM, DNA Model)"),
6451]
6552
6653function create_test1_problem(num_jumps, vr_aggregator, solver)
6754 f = (du, u, p, t) -> (du[1] = u[1])
6855 prob = ODEProblem(f, [0.2], (0.0, 10.0))
6956 jumps = [VariableRateJump((u, p, t) -> u[1], (integrator) -> (integrator.u[1] = integrator.u[1] / 2); interp_points=100) for _ in 1:num_jumps]
7057 jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator=vr_aggregator, rng=rng)
71- ensemble_prob = EnsembleProblem(prob )
58+ ensemble_prob = EnsembleProblem(jump_prob )
7259 return ensemble_prob, jump_prob
7360end
7461
7562function create_test2_problem(num_jumps, vr_aggregator, solver)
76- f = (du, u, p, t) -> (du[1] = u[1])
77- g = (du, u, p, t) -> (du[1] = u[1])
78- prob = SDEProblem(f, g, [0.2], (0.0, 10.0))
79- jumps = [VariableRateJump((u, p, t) -> u[1], (integrator) -> (integrator.u[1] = integrator.u[1] / 2); interp_points=100) for _ in 1:num_jumps]
80- jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator=vr_aggregator, rng=rng)
81- ensemble_prob = EnsembleProblem(prob)
82- return ensemble_prob, jump_prob
83- end
84-
85- function create_test3_problem(num_jumps, vr_aggregator, solver)
86- f2 = (du, u, p, t) -> (du[1] = u[1])
87- prob = ODEProblem(f2, [0.2], (0.0, 10.0))
88- jumps = [ConstantRateJump((u, p, t) -> 2, (integrator) -> (integrator.u[1] = integrator.u[1] / 2)) for _ in 1:num_jumps]
89- jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator=vr_aggregator, rng=rng)
90- ensemble_prob = EnsembleProblem(prob)
91- return ensemble_prob, jump_prob
92- end
93-
94- function create_test4_problem(num_jumps, vr_aggregator, solver)
95- f2 = (du, u, p, t) -> (du[1] = u[1])
96- prob = ODEProblem(f2, [0.2], (0.0, 10.0))
97- jumps = [VariableRateJump((u, p, t) -> u[1], (integrator) -> (integrator.u[1] = integrator.u[1] / 2); interp_points=100) for _ in 1:num_jumps]
98- jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator=vr_aggregator, rng=rng)
99- ensemble_prob = EnsembleProblem(prob)
100- return ensemble_prob, jump_prob
101- end
102-
103- function create_test5_problem(num_jumps, vr_aggregator, solver)
104- f2 = (du, u, p, t) -> (du[1] = u[1])
105- g2 = (du, u, p, t) -> (du[1] = u[1])
106- prob = SDEProblem(f2, g2, [0.2], (0.0, 10.0))
107- jumps = [VariableRateJump((u, p, t) -> u[1], (integrator) -> (integrator.u[1] = integrator.u[1] / 2); interp_points=100) for _ in 1:num_jumps]
108- jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator=vr_aggregator, rng=rng)
109- ensemble_prob = EnsembleProblem(prob)
110- return ensemble_prob, jump_prob
111- end
112-
113- function create_test6_problem(num_jumps, vr_aggregator, solver)
11463 f4 = (dx, x, p, t) -> (dx[1] = x[1])
11564 rate4 = (x, p, t) -> t
11665 affect4! = (integrator) -> (integrator.u[1] = integrator.u[1] * 0.5)
11766 prob = ODEProblem(f4, [1.0 + 0.0im], (0.0, 6.0))
11867 jumps = [VariableRateJump(rate4, affect4!) for _ in 1:num_jumps]
11968 jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator=vr_aggregator, rng=rng)
120- ensemble_prob = EnsembleProblem(prob )
69+ ensemble_prob = EnsembleProblem(jump_prob )
12170 return ensemble_prob, jump_prob
12271end
123- ```
124-
125- # Solution Visualization
126-
127- We solve one trajectory for each test case with 2 jumps using `VR_Direct` and plot the state variables vs. time.
12872
129- ```julia
130- let figs = []
131- for test_num in 1:6
132- # Select a representative solver for each test
133- algo, stepper = if test_num == 2 || test_num == 5
134- VR_Direct(), SRIW1()
135- else
136- VR_Direct(), Tsit5()
137- end
138- label = "Test $test_num"
139-
140- # Create problem with 2 jumps (or 2x2 matrix)
141- ensemble_prob, jump_prob = if test_num == 1
142- create_test1_problem(2, algo, stepper)
143- elseif test_num == 2
144- create_test2_problem(2, algo, stepper)
145- elseif test_num == 3
146- create_test3_problem(2, algo, stepper)
147- elseif test_num == 4
148- create_test4_problem(2, algo, stepper)
149- elseif test_num == 5
150- create_test5_problem(2, algo, stepper)
151- elseif test_num == 6
152- create_test6_problem(2, algo, stepper)
153- end
154-
155- try
156- sol = solve(jump_prob, stepper; saveat=0.01)
157- # Plot solution
158- fig = plot(title="Test $test_num: Solution Trajectory", xlabel="Time", ylabel="State")
159- if test_num == 6
160- # For complex ODE, plot real part
161- plot!(sol.t, real.(sol[1,:]), label="Real Part")
162- else
163- # For scalar problems, plot state
164- plot!(sol.t, sol[1,:], label="u[1]")
165- end
166- push!(figs, fig)
167- catch e
168- @warn "Failed to solve Test $test_num: $(sprint(showerror, e))"
169- end
73+ function create_test3_problem(num_jumps, vr_aggregator, solver)
74+ # Parameters from the RSSA paper
75+ r = [0.043, 0.0007, 0.0715, 0.0039, 0.0199, 0.4791, 0.00019, 0.8765, 0.083, 0.5]
76+ k = -log(2) / 30
77+ u0 = [10.0, 10.0, 30.0, 0.0, 0.0, 0.0] # [DNA, M, D, RNA, DNAD, DNA2D]
78+ tspan = (0.0, 120.0)
79+
80+ function f_dna(du, u, p, t)
81+ du .= 0.0
82+ nothing
17083 end
171- plot(figs..., layout=(4, 2), format=fmt, size=(width_px, 4*height_px/2))
84+
85+ # Define 10 variable rate jumps (fixed set, num_jumps ignored for consistency)
86+ function rate1(u, p, t) r[1] * u[4] end
87+ function affect1!(integrator) integrator.u[2] += 1; nothing end
88+ jump1 = VariableRateJump(rate1, affect1!)
89+
90+ function rate2(u, p, t) r[2] * u[2] end
91+ function affect2!(integrator) integrator.u[2] -= 1; nothing end
92+ jump2 = VariableRateJump(rate2, affect2!)
93+
94+ function rate3(u, p, t) r[3] * u[5] end
95+ function affect3!(integrator) integrator.u[4] += 1; nothing end
96+ jump3 = VariableRateJump(rate3, affect3!)
97+
98+ function rate4(u, p, t) r[4] * u[4] end
99+ function affect4!(integrator) integrator.u[4] -= 1; nothing end
100+ jump4 = VariableRateJump(rate4, affect4!)
101+
102+ function rate5(u, p, t) r[5] * exp(k * t) * u[1] * u[3] end
103+ function affect5!(integrator) integrator.u[1] -= 1; integrator.u[3] -= 1; integrator.u[5] += 1; nothing end
104+ jump5 = VariableRateJump(rate5, affect5!)
105+
106+ function rate6(u, p, t) r[6] * u[5] end
107+ function affect6!(integrator) integrator.u[5] -= 1; integrator.u[1] += 1; integrator.u[3] += 1; nothing end
108+ jump6 = VariableRateJump(rate6, affect6!)
109+
110+ function rate7(u, p, t) r[7] * exp(k * t) * u[5] * u[3] end
111+ function affect7!(integrator) integrator.u[5] -= 1; integrator.u[3] -= 1; integrator.u[6] += 1; nothing end
112+ jump7 = VariableRateJump(rate7, affect7!)
113+
114+ function rate8(u, p, t) r[8] * u[6] end
115+ function affect8!(integrator) integrator.u[6] -= 1; integrator.u[1] += 1; integrator.u[3] += 1; nothing end
116+ jump8 = VariableRateJump(rate8, affect8!)
117+
118+ function rate9(u, p, t) r[9] * exp(k * t) * u[2] * (u[2] - 1) / 2 end
119+ function affect9!(integrator) integrator.u[2] -= 2; integrator.u[3] += 1; nothing end
120+ jump9 = VariableRateJump(rate9, affect9!)
121+
122+ function rate10(u, p, t) r[10] * u[3] end
123+ function affect10!(integrator) integrator.u[3] -= 1; integrator.u[2] += 2; nothing end
124+ jump10 = VariableRateJump(rate10, affect10!)
125+
126+ prob = ODEProblem(f_dna, u0, tspan)
127+ jumps = (jump1, jump2, jump3, jump4, jump5, jump6, jump7, jump8, jump9, jump10)
128+ jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator=vr_aggregator, rng=rng)
129+ ensemble_prob = EnsembleProblem(jump_prob)
130+ return ensemble_prob, jump_prob
172131end
173132```
174133
175134# Benchmark Execution
176135
177- We benchmark each test case for 1 to 20 jumps, running 100 trajectories . Errors are logged to diagnose failures.
136+ We benchmark each test case for 1 to 20 jumps. Errors are logged to diagnose failures.
178137
179138```julia
180139num_jumps_range = append!([1], 5:5:20)
@@ -195,18 +154,12 @@ for (algo, stepper, agg_name, label) in algorithms
195154 ensemble_prob, jump_prob = create_test2_problem(var, algo, stepper)
196155 elseif test_num == 3
197156 ensemble_prob, jump_prob = create_test3_problem(var, algo, stepper)
198- elseif test_num == 4
199- ensemble_prob, jump_prob = create_test4_problem(var, algo, stepper)
200- elseif test_num == 5
201- ensemble_prob, jump_prob = create_test5_problem(var, algo, stepper)
202- elseif test_num == 6
203- ensemble_prob, jump_prob = create_test6_problem(var, algo, stepper)
204157 end
205158 trial = try
206159 @benchmark(
207- solve($jump_prob, $stepper),
208- samples=50,
209- evals=1,
160+ solve($jump_prob, $stepper),
161+ samples=50,
162+ evals=1,
210163 seconds=10
211164 )
212165 catch e
@@ -237,13 +190,12 @@ We plot the median execution times for each test case, comparing `VR_Direct`,`VR
237190
238191```julia
239192let figs = []
240- for test_num in 1:6
193+ for test_num in 1:3
241194 test_algorithms = filter(a -> parse(Int, match(r"Test (\d+)", a[4]).captures[1]) == test_num, algorithms)
242- is_matrix_test = test_num == 7
243- range_var = is_matrix_test ? matrix_sizes : num_jumps_range
195+ range_var = num_jumps_range
244196 fig = plot(
245197 yscale=:log10,
246- xlabel=is_matrix_test ? "Matrix Size" : "Number of Jumps",
198+ xlabel="Number of Jumps",
247199 ylabel="Time (ns)",
248200 legend_position=:outertopright,
249201 title="Test $test_num: Simulations, 50 samples"
@@ -265,6 +217,6 @@ let figs = []
265217 end
266218 push!(figs, fig)
267219 end
268- plot(figs..., layout=(6 , 1), format=fmt, size=(width_px, 8 *height_px/2 ))
220+ plot(figs..., layout=(3 , 1), format=fmt, size=(width_px, 4 *height_px))
269221end
270222```
0 commit comments