Skip to content

Commit e99cdb5

Browse files
committed
update GPU code too
1 parent 7b949c4 commit e99cdb5

File tree

4 files changed

+49
-11
lines changed

4 files changed

+49
-11
lines changed

ext/JumpProcessesKernelAbstractionsExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ function SciMLBase.__solve(ensembleprob::SciMLBase.AbstractEnsembleProblem,
2222

2323
jump_prob = ensembleprob.prob
2424

25-
# boilerplate from SimpleTauLeaping method
26-
@assert isempty(jump_prob.jump_callback.continuous_callbacks) # still needs to be a regular jump
27-
@assert isempty(jump_prob.jump_callback.discrete_callbacks)
25+
# Validate that this is a PureLeaping JumpProblem
26+
validate_pure_leaping_inputs(jump_prob) ||
27+
error("SimpleTauLeaping can only be used with PureLeaping JumpProblems with only non-RegularJumps.")
2828
prob = jump_prob.prob
2929

3030
probs = [remake(jump_prob) for _ in 1:trajectories]

src/simple_regular_solve.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,21 @@
11
struct SimpleTauLeaping <: DiffEqBase.DEAlgorithm end
22

3+
function validate_pure_leaping_inputs(jump_prob::JumpProblem)
4+
jump_prob.aggregator isa PureLeaping &&
5+
isempty(jump_prob.jump_callback.continuous_callbacks) &&
6+
isempty(jump_prob.jump_callback.discrete_callbacks) &&
7+
isempty(jump_prob.constant_jumps) &&
8+
isempty(jump_prob.variable_jumps) &&
9+
get_num_majumps(jump_prob.massaction_jump) == 0 &&
10+
jump_prob.regular_jump !== nothing
11+
end
12+
313
function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping;
414
seed = nothing,
515
dt = error("dt is required for SimpleTauLeaping."))
616

7-
# boilerplate from SimpleTauLeaping method
8-
@assert isempty(jump_prob.jump_callback.continuous_callbacks) # still needs to be a regular jump
9-
@assert isempty(jump_prob.jump_callback.discrete_callbacks)
17+
validate_pure_leaping_inputs(jump_prob) ||
18+
error("SimpleTauLeaping can only be used with PureLeaping JumpProblems with only non-RegularJumps.")
1019
prob = jump_prob.prob
1120
rng = DEFAULT_RNG
1221
(seed !== nothing) && seed!(rng, seed)
@@ -62,4 +71,3 @@ end
6271
function EnsembleGPUKernel()
6372
EnsembleGPUKernel(nothing, 0.0)
6473
end
65-

test/gpu/regular_jumps.jl

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ let
3131

3232
prob_disc = DiscreteProblem(u0, tspan, p)
3333
rj = RegularJump(regular_rate, regular_c, 3)
34-
jump_prob = JumpProblem(prob_disc, Direct(), rj)
34+
jump_prob = JumpProblem(prob_disc, PureLeaping(), rj)
3535

3636
sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(),
3737
EnsembleGPUKernel(CUDABackend()); trajectories = Nsims, dt = 1.0)
@@ -72,7 +72,7 @@ let
7272
# Create JumpProblem
7373
prob_disc = DiscreteProblem(u0, tspan, p)
7474
rj = RegularJump(regular_rate, regular_c, 3)
75-
jump_prob = JumpProblem(prob_disc, Direct(), rj; rng = StableRNG(12345))
75+
jump_prob = JumpProblem(prob_disc, PureLeaping(), rj; rng = StableRNG(12345))
7676

7777
sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(),
7878
EnsembleGPUKernel(CUDABackend()); trajectories = Nsims, dt = 1.0)
@@ -84,3 +84,33 @@ let
8484

8585
@test isapprox(mean_kernel, mean_serial, rtol = 0.05)
8686
end
87+
88+
# Test PureLeaping validation for GPU
89+
@testset "PureLeaping GPU Validation" begin
90+
β = 0.1 / 1000.0
91+
ν = 0.01
92+
p = (β, ν)
93+
94+
regular_rate = (out, u, p, t) -> begin
95+
out[1] = p[1] * u[1] * u[2] # β*S*I (infection)
96+
out[2] = p[2] * u[2] # ν*I (recovery)
97+
end
98+
99+
regular_c = (dc, u, p, t, counts, mark) -> begin
100+
dc .= 0.0
101+
dc[1] = -counts[1] # S: -infection
102+
dc[2] = counts[1] - counts[2] # I: +infection - recovery
103+
dc[3] = counts[2] # R: +recovery
104+
end
105+
106+
u0 = [999.0, 10.0, 0.0] # S, I, R
107+
tspan = (0.0, 50.0)
108+
prob_disc = DiscreteProblem(u0, tspan, p)
109+
rj = RegularJump(regular_rate, regular_c, 2)
110+
111+
# This should fail - Direct aggregator with SimpleTauLeaping
112+
jump_prob_direct = JumpProblem(prob_disc, Direct(), rj)
113+
@test_throws ErrorException solve(EnsembleProblem(jump_prob_direct),
114+
SimpleTauLeaping(), EnsembleGPUKernel(CUDABackend());
115+
trajectories = 10, dt = 1.0)
116+
end

test/regular_jumps.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ rj = RegularJump(regular_rate, regular_c, dc; constant_c = true)
2121
jumps = JumpSet(rj)
2222

2323
prob = DiscreteProblem([999.0, 1.0, 0.0], (0.0, 250.0))
24-
jump_prob = JumpProblem(prob, Direct(), rj; rng = rng)
24+
jump_prob = JumpProblem(prob, PureLeaping(), rj; rng)
2525
sol = solve(jump_prob, SimpleTauLeaping(); dt = 1.0)
2626

2727
const _dc = zeros(3, 2)
@@ -37,7 +37,7 @@ end
3737
rj = RegularJump(regular_rate, regular_c, 2)
3838
jumps = JumpSet(rj)
3939
prob = DiscreteProblem([999, 1, 0], (0.0, 250.0))
40-
jump_prob = JumpProblem(prob, Direct(), rj; rng = rng)
40+
jump_prob = JumpProblem(prob, PureLeaping(), rj; rng)
4141
sol = solve(jump_prob, SimpleTauLeaping(); dt = 1.0)
4242

4343
# Test PureLeaping aggregator functionality

0 commit comments

Comments
 (0)