Skip to content

Commit 6713a93

Browse files
added DNA Gene Model
1 parent 7d2b038 commit 6713a93

File tree

1 file changed

+76
-124
lines changed

1 file changed

+76
-124
lines changed

benchmarks/Jumps/VR_Aggregator_Benchmark.jmd

Lines changed: 76 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,10 @@ state variables vs. time to verify simulation behavior.
2222

2323
The test cases are:
2424
1. **Scalar ODE with Variable Rate Jumps**: Solved with `Tsit5` and `Rosenbrock23` (with/without autodiff).
25-
2. **Scalar SDE with Variable Rate Jumps**: Solved with `SRIW1`.
26-
3. **ODE with Constant Rate Jump**: Solved with `Tsit5`.
27-
4. **ODE with Variable Rate Jumps (Alternative Rate)**: Solved with `Tsit5`.
28-
5. **SDE with Variable Rate Jumps (Alternative Rate)**: Solved with `SRIW1`.
29-
6. **Complex ODE with Variable Rate Jump**: Solved with `Tsit5`.
25+
2. **Complex ODE with Variable Rate Jump**: Solved with `Tsit5`.
26+
3. **DNA Gene Model**: ODE with 10 variable rate jumps from the RSSA paper, solved with Tsit5.
3027

31-
For visualization, we solve one trajectory per test case with 2 jumps. For benchmarking,
32-
we vary jumps from 1 to 20, running 100 trajectories per configuration.
28+
For visualization, we solve one trajectory per test case with 2 jumps. For benchmarking, we vary jumps from 1 to 20.
3329

3430
# Benchmark and Visualization Setup
3531

@@ -46,135 +42,98 @@ algorithms = Tuple{Any, Any, String, String}[
4642
(VR_Direct(), Rosenbrock23(), "VR_Direct", "Test 1 Rosenbrock23 (autodiff, VR_Direct)"),
4743
(VR_DirectFW(), Rosenbrock23(), "VR_DirectFW", "Test 1 Rosenbrock23 (autodiff, VR_DirectFW)"),
4844
(VR_FRM(), Rosenbrock23(), "VR_FRM", "Test 1 Rosenbrock23 (autodiff, VR_FRM)"),
49-
(VR_Direct(), SRIW1(), "VR_Direct", "Test 2 SRIW1 (VR_Direct)"),
50-
(VR_DirectFW(), SRIW1(), "VR_DirectFW", "Test 2 SRIW1 (VR_DirectFW)"),
51-
(VR_FRM(), SRIW1(), "VR_FRM", "Test 2 SRIW1 (VR_FRM)"),
52-
(VR_Direct(), Tsit5(), "VR_Direct", "Test 3 Tsit5 (VR_Direct, ConstantRateJump)"),
53-
(VR_DirectFW(), Tsit5(), "VR_DirectFW", "Test 3 Tsit5 (VR_DirectFW, ConstantRateJump)"),
54-
(VR_FRM(), Tsit5(), "VR_FRM", "Test 3 Tsit5 (VR_FRM, ConstantRateJump)"),
55-
(VR_Direct(), Tsit5(), "VR_Direct", "Test 4 Tsit5 (VR_Direct)"),
56-
(VR_DirectFW(), Tsit5(), "VR_DirectFW", "Test 4 Tsit5 (VR_DirectFW)"),
57-
(VR_FRM(), Tsit5(), "VR_FRM", "Test 4 Tsit5 (VR_FRM)"),
58-
(VR_Direct(), SRIW1(), "VR_Direct", "Test 5 SRIW1 (VR_Direct)"),
59-
(VR_DirectFW(), SRIW1(), "VR_DirectFW", "Test 5 SRIW1 (VR_DirectFW)"),
60-
(VR_FRM(), SRIW1(), "VR_FRM", "Test 5 SRIW1 (VR_FRM)"),
61-
(VR_Direct(), Tsit5(), "VR_Direct", "Test 6 Tsit5 (VR_Direct)"),
62-
(VR_DirectFW(), Tsit5(), "VR_DirectFW", "Test 6 Tsit5 (VR_DirectFW)"),
63-
(VR_FRM(), Tsit5(), "VR_FRM", "Test 6 Tsit5 (VR_FRM)"),
45+
(VR_Direct(), Tsit5(), "VR_Direct", "Test 2 Tsit5 (VR_Direct)"),
46+
(VR_DirectFW(), Tsit5(), "VR_DirectFW", "Test 2 Tsit5 (VR_DirectFW)"),
47+
(VR_FRM(), Tsit5(), "VR_FRM", "Test 2 Tsit5 (VR_FRM)"),
48+
(VR_Direct(), Tsit5(), "VR_Direct", "Test 3 Tsit5 (VR_Direct, DNA Model)"),
49+
(VR_DirectFW(), Tsit5(), "VR_DirectFW", "Test 3 Tsit5 (VR_DirectFW, DNA Model)"),
50+
(VR_FRM(), Tsit5(), "VR_FRM", "Test 3 Tsit5 (VR_FRM, DNA Model)"),
6451
]
6552

6653
function create_test1_problem(num_jumps, vr_aggregator, solver)
6754
f = (du, u, p, t) -> (du[1] = u[1])
6855
prob = ODEProblem(f, [0.2], (0.0, 10.0))
6956
jumps = [VariableRateJump((u, p, t) -> u[1], (integrator) -> (integrator.u[1] = integrator.u[1] / 2); interp_points=100) for _ in 1:num_jumps]
7057
jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator=vr_aggregator, rng=rng)
71-
ensemble_prob = EnsembleProblem(prob)
58+
ensemble_prob = EnsembleProblem(jump_prob)
7259
return ensemble_prob, jump_prob
7360
end
7461

7562
function create_test2_problem(num_jumps, vr_aggregator, solver)
76-
f = (du, u, p, t) -> (du[1] = u[1])
77-
g = (du, u, p, t) -> (du[1] = u[1])
78-
prob = SDEProblem(f, g, [0.2], (0.0, 10.0))
79-
jumps = [VariableRateJump((u, p, t) -> u[1], (integrator) -> (integrator.u[1] = integrator.u[1] / 2); interp_points=100) for _ in 1:num_jumps]
80-
jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator=vr_aggregator, rng=rng)
81-
ensemble_prob = EnsembleProblem(prob)
82-
return ensemble_prob, jump_prob
83-
end
84-
85-
function create_test3_problem(num_jumps, vr_aggregator, solver)
86-
f2 = (du, u, p, t) -> (du[1] = u[1])
87-
prob = ODEProblem(f2, [0.2], (0.0, 10.0))
88-
jumps = [ConstantRateJump((u, p, t) -> 2, (integrator) -> (integrator.u[1] = integrator.u[1] / 2)) for _ in 1:num_jumps]
89-
jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator=vr_aggregator, rng=rng)
90-
ensemble_prob = EnsembleProblem(prob)
91-
return ensemble_prob, jump_prob
92-
end
93-
94-
function create_test4_problem(num_jumps, vr_aggregator, solver)
95-
f2 = (du, u, p, t) -> (du[1] = u[1])
96-
prob = ODEProblem(f2, [0.2], (0.0, 10.0))
97-
jumps = [VariableRateJump((u, p, t) -> u[1], (integrator) -> (integrator.u[1] = integrator.u[1] / 2); interp_points=100) for _ in 1:num_jumps]
98-
jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator=vr_aggregator, rng=rng)
99-
ensemble_prob = EnsembleProblem(prob)
100-
return ensemble_prob, jump_prob
101-
end
102-
103-
function create_test5_problem(num_jumps, vr_aggregator, solver)
104-
f2 = (du, u, p, t) -> (du[1] = u[1])
105-
g2 = (du, u, p, t) -> (du[1] = u[1])
106-
prob = SDEProblem(f2, g2, [0.2], (0.0, 10.0))
107-
jumps = [VariableRateJump((u, p, t) -> u[1], (integrator) -> (integrator.u[1] = integrator.u[1] / 2); interp_points=100) for _ in 1:num_jumps]
108-
jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator=vr_aggregator, rng=rng)
109-
ensemble_prob = EnsembleProblem(prob)
110-
return ensemble_prob, jump_prob
111-
end
112-
113-
function create_test6_problem(num_jumps, vr_aggregator, solver)
11463
f4 = (dx, x, p, t) -> (dx[1] = x[1])
11564
rate4 = (x, p, t) -> t
11665
affect4! = (integrator) -> (integrator.u[1] = integrator.u[1] * 0.5)
11766
prob = ODEProblem(f4, [1.0 + 0.0im], (0.0, 6.0))
11867
jumps = [VariableRateJump(rate4, affect4!) for _ in 1:num_jumps]
11968
jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator=vr_aggregator, rng=rng)
120-
ensemble_prob = EnsembleProblem(prob)
69+
ensemble_prob = EnsembleProblem(jump_prob)
12170
return ensemble_prob, jump_prob
12271
end
123-
```
124-
125-
# Solution Visualization
126-
127-
We solve one trajectory for each test case with 2 jumps using `VR_Direct` and plot the state variables vs. time.
12872

129-
```julia
130-
let figs = []
131-
for test_num in 1:6
132-
# Select a representative solver for each test
133-
algo, stepper = if test_num == 2 || test_num == 5
134-
VR_Direct(), SRIW1()
135-
else
136-
VR_Direct(), Tsit5()
137-
end
138-
label = "Test $test_num"
139-
140-
# Create problem with 2 jumps (or 2x2 matrix)
141-
ensemble_prob, jump_prob = if test_num == 1
142-
create_test1_problem(2, algo, stepper)
143-
elseif test_num == 2
144-
create_test2_problem(2, algo, stepper)
145-
elseif test_num == 3
146-
create_test3_problem(2, algo, stepper)
147-
elseif test_num == 4
148-
create_test4_problem(2, algo, stepper)
149-
elseif test_num == 5
150-
create_test5_problem(2, algo, stepper)
151-
elseif test_num == 6
152-
create_test6_problem(2, algo, stepper)
153-
end
154-
155-
try
156-
sol = solve(jump_prob, stepper; saveat=0.01)
157-
# Plot solution
158-
fig = plot(title="Test $test_num: Solution Trajectory", xlabel="Time", ylabel="State")
159-
if test_num == 6
160-
# For complex ODE, plot real part
161-
plot!(sol.t, real.(sol[1,:]), label="Real Part")
162-
else
163-
# For scalar problems, plot state
164-
plot!(sol.t, sol[1,:], label="u[1]")
165-
end
166-
push!(figs, fig)
167-
catch e
168-
@warn "Failed to solve Test $test_num: $(sprint(showerror, e))"
169-
end
73+
function create_test3_problem(num_jumps, vr_aggregator, solver)
74+
# Parameters from the RSSA paper
75+
r = [0.043, 0.0007, 0.0715, 0.0039, 0.0199, 0.4791, 0.00019, 0.8765, 0.083, 0.5]
76+
k = -log(2) / 30
77+
u0 = [10.0, 10.0, 30.0, 0.0, 0.0, 0.0] # [DNA, M, D, RNA, DNAD, DNA2D]
78+
tspan = (0.0, 120.0)
79+
80+
function f_dna(du, u, p, t)
81+
du .= 0.0
82+
nothing
17083
end
171-
plot(figs..., layout=(4, 2), format=fmt, size=(width_px, 4*height_px/2))
84+
85+
# Define 10 variable rate jumps (fixed set, num_jumps ignored for consistency)
86+
function rate1(u, p, t) r[1] * u[4] end
87+
function affect1!(integrator) integrator.u[2] += 1; nothing end
88+
jump1 = VariableRateJump(rate1, affect1!)
89+
90+
function rate2(u, p, t) r[2] * u[2] end
91+
function affect2!(integrator) integrator.u[2] -= 1; nothing end
92+
jump2 = VariableRateJump(rate2, affect2!)
93+
94+
function rate3(u, p, t) r[3] * u[5] end
95+
function affect3!(integrator) integrator.u[4] += 1; nothing end
96+
jump3 = VariableRateJump(rate3, affect3!)
97+
98+
function rate4(u, p, t) r[4] * u[4] end
99+
function affect4!(integrator) integrator.u[4] -= 1; nothing end
100+
jump4 = VariableRateJump(rate4, affect4!)
101+
102+
function rate5(u, p, t) r[5] * exp(k * t) * u[1] * u[3] end
103+
function affect5!(integrator) integrator.u[1] -= 1; integrator.u[3] -= 1; integrator.u[5] += 1; nothing end
104+
jump5 = VariableRateJump(rate5, affect5!)
105+
106+
function rate6(u, p, t) r[6] * u[5] end
107+
function affect6!(integrator) integrator.u[5] -= 1; integrator.u[1] += 1; integrator.u[3] += 1; nothing end
108+
jump6 = VariableRateJump(rate6, affect6!)
109+
110+
function rate7(u, p, t) r[7] * exp(k * t) * u[5] * u[3] end
111+
function affect7!(integrator) integrator.u[5] -= 1; integrator.u[3] -= 1; integrator.u[6] += 1; nothing end
112+
jump7 = VariableRateJump(rate7, affect7!)
113+
114+
function rate8(u, p, t) r[8] * u[6] end
115+
function affect8!(integrator) integrator.u[6] -= 1; integrator.u[1] += 1; integrator.u[3] += 1; nothing end
116+
jump8 = VariableRateJump(rate8, affect8!)
117+
118+
function rate9(u, p, t) r[9] * exp(k * t) * u[2] * (u[2] - 1) / 2 end
119+
function affect9!(integrator) integrator.u[2] -= 2; integrator.u[3] += 1; nothing end
120+
jump9 = VariableRateJump(rate9, affect9!)
121+
122+
function rate10(u, p, t) r[10] * u[3] end
123+
function affect10!(integrator) integrator.u[3] -= 1; integrator.u[2] += 2; nothing end
124+
jump10 = VariableRateJump(rate10, affect10!)
125+
126+
prob = ODEProblem(f_dna, u0, tspan)
127+
jumps = (jump1, jump2, jump3, jump4, jump5, jump6, jump7, jump8, jump9, jump10)
128+
jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator=vr_aggregator, rng=rng)
129+
ensemble_prob = EnsembleProblem(jump_prob)
130+
return ensemble_prob, jump_prob
172131
end
173132
```
174133

175134
# Benchmark Execution
176135

177-
We benchmark each test case for 1 to 20 jumps, running 100 trajectories. Errors are logged to diagnose failures.
136+
We benchmark each test case for 1 to 20 jumps. Errors are logged to diagnose failures.
178137

179138
```julia
180139
num_jumps_range = append!([1], 5:5:20)
@@ -195,18 +154,12 @@ for (algo, stepper, agg_name, label) in algorithms
195154
ensemble_prob, jump_prob = create_test2_problem(var, algo, stepper)
196155
elseif test_num == 3
197156
ensemble_prob, jump_prob = create_test3_problem(var, algo, stepper)
198-
elseif test_num == 4
199-
ensemble_prob, jump_prob = create_test4_problem(var, algo, stepper)
200-
elseif test_num == 5
201-
ensemble_prob, jump_prob = create_test5_problem(var, algo, stepper)
202-
elseif test_num == 6
203-
ensemble_prob, jump_prob = create_test6_problem(var, algo, stepper)
204157
end
205158
trial = try
206159
@benchmark(
207-
solve($jump_prob, $stepper),
208-
samples=50,
209-
evals=1,
160+
solve($jump_prob, $stepper),
161+
samples=50,
162+
evals=1,
210163
seconds=10
211164
)
212165
catch e
@@ -237,13 +190,12 @@ We plot the median execution times for each test case, comparing `VR_Direct`,`VR
237190

238191
```julia
239192
let figs = []
240-
for test_num in 1:6
193+
for test_num in 1:3
241194
test_algorithms = filter(a -> parse(Int, match(r"Test (\d+)", a[4]).captures[1]) == test_num, algorithms)
242-
is_matrix_test = test_num == 7
243-
range_var = is_matrix_test ? matrix_sizes : num_jumps_range
195+
range_var = num_jumps_range
244196
fig = plot(
245197
yscale=:log10,
246-
xlabel=is_matrix_test ? "Matrix Size" : "Number of Jumps",
198+
xlabel="Number of Jumps",
247199
ylabel="Time (ns)",
248200
legend_position=:outertopright,
249201
title="Test $test_num: Simulations, 50 samples"
@@ -265,6 +217,6 @@ let figs = []
265217
end
266218
push!(figs, fig)
267219
end
268-
plot(figs..., layout=(6, 1), format=fmt, size=(width_px, 8*height_px/2))
220+
plot(figs..., layout=(3, 1), format=fmt, size=(width_px, 4*height_px))
269221
end
270222
```

0 commit comments

Comments
 (0)