Skip to content

Commit 6f84bde

Browse files
committed
fix conversion of ReactionSystems to NonlinearSystems
1 parent 3816968 commit 6f84bde

File tree

3 files changed

+57
-49
lines changed

3 files changed

+57
-49
lines changed

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ jacobian_sparsity(sys::NonlinearSystem) =
7575

7676
"""
7777
```julia
78-
function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem,u0map,tspan,
78+
function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem,u0map,
7979
parammap=DiffEqBase.NullParameters();
8080
jac = false, sparse=false,
8181
checkbounds = false,
@@ -104,7 +104,7 @@ end
104104

105105
"""
106106
```julia
107-
function DiffEqBase.NonlinearProblemExpr{iip}(sys::NonlinearSystem,u0map,tspan,
107+
function DiffEqBase.NonlinearProblemExpr{iip}(sys::NonlinearSystem,u0map,
108108
parammap=DiffEqBase.NullParameters();
109109
jac = false, sparse=false,
110110
checkbounds = false,
@@ -118,7 +118,7 @@ numerical enhancements.
118118
"""
119119
struct NonlinearProblemExpr{iip} end
120120

121-
function NonlinearProblemExpr{iip}(sys::NonlinearSystem,u0map,tspan,
121+
function NonlinearProblemExpr{iip}(sys::NonlinearSystem,u0map,
122122
parammap=DiffEqBase.NullParameters();
123123
jac = false, sparse=false,
124124
checkbounds = false,

src/systems/reaction/reactionsystem.jl

Lines changed: 47 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -188,25 +188,34 @@ function oderatelaw(rx; combinatoric_ratelaw=true)
188188
rl
189189
end
190190

191-
function assemble_drift(rs; combinatoric_ratelaws=true)
192-
D = Differential(rs.iv)
193-
eqs = [D(x) ~ 0 for x in rs.states]
191+
function assemble_oderhs(rs; combinatoric_ratelaws=true)
194192
species_to_idx = Dict((x => i for (i,x) in enumerate(rs.states)))
195-
193+
rhsvec = [Num(0) for i in eachindex(rs.states)]
196194
for rx in rs.eqs
197195
rl = oderatelaw(rx; combinatoric_ratelaw=combinatoric_ratelaws)
198196
for (spec,stoich) in rx.netstoich
199197
i = species_to_idx[spec]
200-
if _iszero(eqs[i].rhs)
201-
signedrl = (stoich > zero(stoich)) ? rl : -rl
202-
rhs = isone(abs(stoich)) ? signedrl : stoich * rl
198+
if _iszero(rhsvec[i])
199+
signedrl = (stoich > zero(stoich)) ? rl : -rl
200+
rhsvec[i] = isone(abs(stoich)) ? signedrl : stoich * rl
203201
else
204-
Δspec = isone(abs(stoich)) ? rl : abs(stoich) * rl
205-
rhs = (stoich > zero(stoich)) ? (eqs[i].rhs + Δspec) : (eqs[i].rhs - Δspec)
202+
Δspec = isone(abs(stoich)) ? rl : abs(stoich) * rl
203+
rhsvec[i] = (stoich > zero(stoich)) ? (rhsvec[i] + Δspec) : (rhsvec[i] - Δspec)
206204
end
207-
eqs[i] = Equation(eqs[i].lhs, rhs)
208205
end
209206
end
207+
208+
rhsvec
209+
end
210+
211+
function assemble_drift(rs; combinatoric_ratelaws=true, as_odes=true)
212+
rhsvec = assemble_oderhs(rs; combinatoric_ratelaws=combinatoric_ratelaws)
213+
if as_odes
214+
D = Differential(rs.iv)
215+
eqs = [Equation(D(x),rhs) for (x,rhs) in zip(rs.states,rhsvec)]
216+
else
217+
eqs = [Equation(Num(0),rhs) for rhs in rhsvec]
218+
end
210219
eqs
211220
end
212221

@@ -383,6 +392,24 @@ function Base.convert(::Type{<:ODESystem}, rs::ReactionSystem; combinatoric_rate
383392
systems=convert.(ODESystem,rs.systems))
384393
end
385394

395+
"""
396+
```julia
397+
Base.convert(::Type{<:NonlinearSystem},rs::ReactionSystem)
398+
```
399+
400+
Convert a [`ReactionSystem`](@ref) to an [`NonlinearSystem`](@ref).
401+
402+
Notes:
403+
- `combinatoric_ratelaws=true` uses factorial scaling factors in calculating the rate
404+
law, i.e. for `2S -> 0` at rate `k` the ratelaw would be `k*S^2/2!`. If
405+
`combinatoric_ratelaws=false` then the ratelaw is `k*S^2`, i.e. the scaling factor is
406+
ignored.
407+
"""
408+
function Base.convert(::Type{<:NonlinearSystem},rs::ReactionSystem; combinatoric_ratelaws=true)
409+
eqs = assemble_drift(rs; combinatoric_ratelaws=combinatoric_ratelaws, as_odes=false)
410+
NonlinearSystem(eqs,rs.states,rs.ps,name=rs.name,systems=convert.(NonlinearSystem,rs.systems))
411+
end
412+
386413
"""
387414
```julia
388415
Base.convert(::Type{<:SDESystem},rs::ReactionSystem)
@@ -450,55 +477,29 @@ function Base.convert(::Type{<:JumpSystem},rs::ReactionSystem; combinatoric_rate
450477
end
451478

452479

453-
"""
454-
```julia
455-
Base.convert(::Type{<:NonlinearSystem},rs::ReactionSystem)
456-
```
457-
458-
Convert a [`ReactionSystem`](@ref) to an [`NonlinearSystem`](@ref).
459-
460-
Notes:
461-
- `combinatoric_ratelaws=true` uses factorial scaling factors in calculating the rate
462-
law, i.e. for `2S -> 0` at rate `k` the ratelaw would be `k*S^2/2!`. If
463-
`combinatoric_ratelaws=false` then the ratelaw is `k*S^2`, i.e. the scaling factor is
464-
ignored.
465-
"""
466-
function Base.convert(::Type{<:NonlinearSystem},rs::ReactionSystem; combinatoric_ratelaws=true)
467-
states_swaps = value.(rs.states)
468-
eqs = map(eq -> 0 ~ make_sub!(eq,states_swaps),getproperty.(assemble_drift(rs; combinatoric_ratelaws=combinatoric_ratelaws),:rhs))
469-
NonlinearSystem(eqs,rs.states,rs.ps,name=rs.name,
470-
systems=convert.(NonlinearSystem,rs.systems))
471-
end
472-
473-
# Used for Base.convert(::Type{<:NonlinearSystem},rs::ReactionSystem) only, should likely be removed.
474-
function make_sub!(eq,states_swaps)
475-
for (i,arg) in enumerate(eq.args)
476-
if any(isequal.(states_swaps,arg))
477-
eq.args[i] = var2op(arg.op)
478-
else
479-
make_sub!(arg,states_swaps)
480-
end
481-
end
482-
return eq
483-
end
484-
485480
### Converts a reaction system to ODE or SDE problems ###
486481

487482

488483
# ODEProblem from AbstractReactionNetwork
489-
function DiffEqBase.ODEProblem(rs::ReactionSystem, u0::Union{AbstractArray, Number}, tspan, p=DiffEqBase.NullParameters(), args...; kwargs...)
484+
function DiffEqBase.ODEProblem(rs::ReactionSystem, u0, tspan, p=DiffEqBase.NullParameters(), args...; kwargs...)
490485
return ODEProblem(convert(ODESystem,rs),u0,tspan,p, args...; kwargs...)
491486
end
492487

488+
# NonlinearProblem from AbstractReactionNetwork
489+
function DiffEqBase.NonlinearProblem(rs::ReactionSystem, u0, p=DiffEqBase.NullParameters(), args...; kwargs...)
490+
return NonlinearProblem(convert(NonlinearSystem,rs), u0, p, args...; kwargs...)
491+
end
492+
493+
493494
# SDEProblem from AbstractReactionNetwork
494-
function DiffEqBase.SDEProblem(rs::ReactionSystem, u0::Union{AbstractArray, Number}, tspan, p=DiffEqBase.NullParameters(), args...; noise_scaling=nothing, kwargs...)
495+
function DiffEqBase.SDEProblem(rs::ReactionSystem, u0, tspan, p=DiffEqBase.NullParameters(), args...; noise_scaling=nothing, kwargs...)
495496
sde_sys = convert(SDESystem,rs,noise_scaling=noise_scaling)
496497
p_matrix = zeros(length(rs.states), length(rs.eqs))
497498
return SDEProblem(sde_sys,u0,tspan,p,args...; noise_rate_prototype=p_matrix,kwargs...)
498499
end
499500

500501
# DiscreteProblem from AbstractReactionNetwork
501-
function DiffEqBase.DiscreteProblem(rs::ReactionSystem, u0::Union{AbstractArray, Number}, tspan::Tuple, p=DiffEqBase.NullParameters(), args...; kwargs...)
502+
function DiffEqBase.DiscreteProblem(rs::ReactionSystem, u0, tspan::Tuple, p=DiffEqBase.NullParameters(), args...; kwargs...)
502503
return DiscreteProblem(convert(JumpSystem,rs), u0,tspan,p, args...; kwargs...)
503504
end
504505

@@ -508,7 +509,7 @@ function DiffEqJump.JumpProblem(rs::ReactionSystem, prob, aggregator, args...; k
508509
end
509510

510511
# SteadyStateProblem from AbstractReactionNetwork
511-
function DiffEqBase.SteadyStateProblem(rs::ReactionSystem, u0::Union{AbstractArray, Number}, p=DiffEqBase.NullParameters(), args...; kwargs...)
512+
function DiffEqBase.SteadyStateProblem(rs::ReactionSystem, u0, p=DiffEqBase.NullParameters(), args...; kwargs...)
512513
return SteadyStateProblem(ODEFunction(convert(ODESystem,rs)),u0,p, args...; kwargs...)
513514
end
514515

test/reactionsystem.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,13 @@ du2 = sf.f(u,p,t)
8181
G2 = sf.g(u,p,t)
8282
@test norm(G-G2) < 100*eps()
8383

84+
# test conversion to NonlinearSystem
85+
ns = convert(NonlinearSystem,rs)
86+
fnl = eval(generate_function(ns)[2])
87+
dunl = similar(du)
88+
fnl(dunl,u,p)
89+
@test norm(du-dunl) < 100*eps()
90+
8491
# tests the noise_scaling argument.
8592
p = rand(length(k)+1)
8693
u = rand(length(k))

0 commit comments

Comments
 (0)