Skip to content

Commit 0551172

Browse files
Merge pull request #710 from isaacsas/fix-reactionsys-to-nonlinsys
Fix conversion of reactionsys to nonlinsys
2 parents e799bae + 862278e commit 0551172

File tree

3 files changed

+58
-59
lines changed

3 files changed

+58
-59
lines changed

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ jacobian_sparsity(sys::NonlinearSystem) =
7575

7676
"""
7777
```julia
78-
function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem,u0map,tspan,
78+
function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem,u0map,
7979
parammap=DiffEqBase.NullParameters();
8080
jac = false, sparse=false,
8181
checkbounds = false,
@@ -104,7 +104,7 @@ end
104104

105105
"""
106106
```julia
107-
function DiffEqBase.NonlinearProblemExpr{iip}(sys::NonlinearSystem,u0map,tspan,
107+
function DiffEqBase.NonlinearProblemExpr{iip}(sys::NonlinearSystem,u0map,
108108
parammap=DiffEqBase.NullParameters();
109109
jac = false, sparse=false,
110110
checkbounds = false,
@@ -118,7 +118,7 @@ numerical enhancements.
118118
"""
119119
struct NonlinearProblemExpr{iip} end
120120

121-
function NonlinearProblemExpr{iip}(sys::NonlinearSystem,u0map,tspan,
121+
function NonlinearProblemExpr{iip}(sys::NonlinearSystem,u0map,
122122
parammap=DiffEqBase.NullParameters();
123123
jac = false, sparse=false,
124124
checkbounds = false,

src/systems/reaction/reactionsystem.jl

Lines changed: 48 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -188,25 +188,35 @@ function oderatelaw(rx; combinatoric_ratelaw=true)
188188
rl
189189
end
190190

191-
function assemble_drift(rs; combinatoric_ratelaws=true)
192-
D = Differential(rs.iv)
193-
eqs = [D(x) ~ 0 for x in rs.states]
191+
function assemble_oderhs(rs; combinatoric_ratelaws=true)
194192
species_to_idx = Dict((x => i for (i,x) in enumerate(rs.states)))
193+
rhsvec = Any[0 for i in eachindex(rs.states)]
195194

196195
for rx in rs.eqs
197196
rl = oderatelaw(rx; combinatoric_ratelaw=combinatoric_ratelaws)
198197
for (spec,stoich) in rx.netstoich
199198
i = species_to_idx[spec]
200-
if _iszero(eqs[i].rhs)
201-
signedrl = (stoich > zero(stoich)) ? rl : -rl
202-
rhs = isone(abs(stoich)) ? signedrl : stoich * rl
199+
if _iszero(rhsvec[i])
200+
signedrl = (stoich > zero(stoich)) ? rl : -rl
201+
rhsvec[i] = isone(abs(stoich)) ? signedrl : stoich * rl
203202
else
204-
Δspec = isone(abs(stoich)) ? rl : abs(stoich) * rl
205-
rhs = (stoich > zero(stoich)) ? (eqs[i].rhs + Δspec) : (eqs[i].rhs - Δspec)
203+
Δspec = isone(abs(stoich)) ? rl : abs(stoich) * rl
204+
rhsvec[i] = (stoich > zero(stoich)) ? (rhsvec[i] + Δspec) : (rhsvec[i] - Δspec)
206205
end
207-
eqs[i] = Equation(eqs[i].lhs, rhs)
208206
end
209207
end
208+
209+
rhsvec
210+
end
211+
212+
function assemble_drift(rs; combinatoric_ratelaws=true, as_odes=true)
213+
rhsvec = assemble_oderhs(rs; combinatoric_ratelaws=combinatoric_ratelaws)
214+
if as_odes
215+
D = Differential(rs.iv)
216+
eqs = [Equation(D(x),rhs) for (x,rhs) in zip(rs.states,rhsvec)]
217+
else
218+
eqs = [Equation(0,rhs) for rhs in rhsvec]
219+
end
210220
eqs
211221
end
212222

@@ -226,15 +236,6 @@ function assemble_diffusion(rs, noise_scaling; combinatoric_ratelaws=true)
226236
eqs
227237
end
228238

229-
function var2op(var)
230-
Sym{symtype(var)}(nameof(var.op))
231-
end
232-
function var2op(var::Sym)
233-
var
234-
end
235-
236-
# Calculate the Jump rate law (like ODE, but uses X instead of X(t).
237-
# The former generates a "MethodError: objects of type Int64 are not callable" when trying to solve the problem.
238239
"""
239240
jumpratelaw(rx; rxvars=get_variables(rx.rate), combinatoric_ratelaw=true)
240241
@@ -263,7 +264,6 @@ Notes:
263264
function jumpratelaw(rx; rxvars=get_variables(rx.rate), combinatoric_ratelaw=true)
264265
@unpack rate, substrates, substoich, only_use_rate = rx
265266
rl = rate
266-
#rl = substitute(rl, Dict(rxvars .=> var2op.(rxvars)))
267267
if !only_use_rate
268268
coef = one(eltype(substoich))
269269
for (i,stoich) in enumerate(substoich)
@@ -300,7 +300,7 @@ explicitly on the independent variable (usually time).
300300
- Optional: `stateset`, set of states which if the rxvars are within mean rx is non-mass action.
301301
"""
302302
function ismassaction(rx, rs; rxvars = get_variables(rx.rate),
303-
haveivdep,
303+
haveivdep = any(var -> isequal(rs.iv,var), rxvars),
304304
stateset = Set(states(rs)))
305305
# if no dependencies must be zero order
306306
(length(rxvars)==0) && return true
@@ -383,6 +383,24 @@ function Base.convert(::Type{<:ODESystem}, rs::ReactionSystem; combinatoric_rate
383383
systems=convert.(ODESystem,rs.systems))
384384
end
385385

386+
"""
387+
```julia
388+
Base.convert(::Type{<:NonlinearSystem},rs::ReactionSystem)
389+
```
390+
391+
Convert a [`ReactionSystem`](@ref) to an [`NonlinearSystem`](@ref).
392+
393+
Notes:
394+
- `combinatoric_ratelaws=true` uses factorial scaling factors in calculating the rate
395+
law, i.e. for `2S -> 0` at rate `k` the ratelaw would be `k*S^2/2!`. If
396+
`combinatoric_ratelaws=false` then the ratelaw is `k*S^2`, i.e. the scaling factor is
397+
ignored.
398+
"""
399+
function Base.convert(::Type{<:NonlinearSystem},rs::ReactionSystem; combinatoric_ratelaws=true)
400+
eqs = assemble_drift(rs; combinatoric_ratelaws=combinatoric_ratelaws, as_odes=false)
401+
NonlinearSystem(eqs,rs.states,rs.ps,name=rs.name,systems=convert.(NonlinearSystem,rs.systems))
402+
end
403+
386404
"""
387405
```julia
388406
Base.convert(::Type{<:SDESystem},rs::ReactionSystem)
@@ -450,55 +468,29 @@ function Base.convert(::Type{<:JumpSystem},rs::ReactionSystem; combinatoric_rate
450468
end
451469

452470

453-
"""
454-
```julia
455-
Base.convert(::Type{<:NonlinearSystem},rs::ReactionSystem)
456-
```
457-
458-
Convert a [`ReactionSystem`](@ref) to an [`NonlinearSystem`](@ref).
459-
460-
Notes:
461-
- `combinatoric_ratelaws=true` uses factorial scaling factors in calculating the rate
462-
law, i.e. for `2S -> 0` at rate `k` the ratelaw would be `k*S^2/2!`. If
463-
`combinatoric_ratelaws=false` then the ratelaw is `k*S^2`, i.e. the scaling factor is
464-
ignored.
465-
"""
466-
function Base.convert(::Type{<:NonlinearSystem},rs::ReactionSystem; combinatoric_ratelaws=true)
467-
states_swaps = value.(rs.states)
468-
eqs = map(eq -> 0 ~ make_sub!(eq,states_swaps),getproperty.(assemble_drift(rs; combinatoric_ratelaws=combinatoric_ratelaws),:rhs))
469-
NonlinearSystem(eqs,rs.states,rs.ps,name=rs.name,
470-
systems=convert.(NonlinearSystem,rs.systems))
471-
end
472-
473-
# Used for Base.convert(::Type{<:NonlinearSystem},rs::ReactionSystem) only, should likely be removed.
474-
function make_sub!(eq,states_swaps)
475-
for (i,arg) in enumerate(eq.args)
476-
if any(isequal.(states_swaps,arg))
477-
eq.args[i] = var2op(arg.op)
478-
else
479-
make_sub!(arg,states_swaps)
480-
end
481-
end
482-
return eq
483-
end
484-
485471
### Converts a reaction system to ODE or SDE problems ###
486472

487473

488474
# ODEProblem from AbstractReactionNetwork
489-
function DiffEqBase.ODEProblem(rs::ReactionSystem, u0::Union{AbstractArray, Number}, tspan, p=DiffEqBase.NullParameters(), args...; kwargs...)
475+
function DiffEqBase.ODEProblem(rs::ReactionSystem, u0, tspan, p=DiffEqBase.NullParameters(), args...; kwargs...)
490476
return ODEProblem(convert(ODESystem,rs),u0,tspan,p, args...; kwargs...)
491477
end
492478

479+
# NonlinearProblem from AbstractReactionNetwork
480+
function DiffEqBase.NonlinearProblem(rs::ReactionSystem, u0, p=DiffEqBase.NullParameters(), args...; kwargs...)
481+
return NonlinearProblem(convert(NonlinearSystem,rs), u0, p, args...; kwargs...)
482+
end
483+
484+
493485
# SDEProblem from AbstractReactionNetwork
494-
function DiffEqBase.SDEProblem(rs::ReactionSystem, u0::Union{AbstractArray, Number}, tspan, p=DiffEqBase.NullParameters(), args...; noise_scaling=nothing, kwargs...)
486+
function DiffEqBase.SDEProblem(rs::ReactionSystem, u0, tspan, p=DiffEqBase.NullParameters(), args...; noise_scaling=nothing, kwargs...)
495487
sde_sys = convert(SDESystem,rs,noise_scaling=noise_scaling)
496488
p_matrix = zeros(length(rs.states), length(rs.eqs))
497489
return SDEProblem(sde_sys,u0,tspan,p,args...; noise_rate_prototype=p_matrix,kwargs...)
498490
end
499491

500492
# DiscreteProblem from AbstractReactionNetwork
501-
function DiffEqBase.DiscreteProblem(rs::ReactionSystem, u0::Union{AbstractArray, Number}, tspan::Tuple, p=DiffEqBase.NullParameters(), args...; kwargs...)
493+
function DiffEqBase.DiscreteProblem(rs::ReactionSystem, u0, tspan::Tuple, p=DiffEqBase.NullParameters(), args...; kwargs...)
502494
return DiscreteProblem(convert(JumpSystem,rs), u0,tspan,p, args...; kwargs...)
503495
end
504496

@@ -508,7 +500,7 @@ function DiffEqJump.JumpProblem(rs::ReactionSystem, prob, aggregator, args...; k
508500
end
509501

510502
# SteadyStateProblem from AbstractReactionNetwork
511-
function DiffEqBase.SteadyStateProblem(rs::ReactionSystem, u0::Union{AbstractArray, Number}, p=DiffEqBase.NullParameters(), args...; kwargs...)
503+
function DiffEqBase.SteadyStateProblem(rs::ReactionSystem, u0, p=DiffEqBase.NullParameters(), args...; kwargs...)
512504
return SteadyStateProblem(ODEFunction(convert(ODESystem,rs)),u0,p, args...; kwargs...)
513505
end
514506

test/reactionsystem.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,13 @@ du2 = sf.f(u,p,t)
8181
G2 = sf.g(u,p,t)
8282
@test norm(G-G2) < 100*eps()
8383

84+
# test conversion to NonlinearSystem
85+
ns = convert(NonlinearSystem,rs)
86+
fnl = eval(generate_function(ns)[2])
87+
dunl = similar(du)
88+
fnl(dunl,u,p)
89+
@test norm(du-dunl) < 100*eps()
90+
8491
# tests the noise_scaling argument.
8592
p = rand(length(k)+1)
8693
u = rand(length(k))

0 commit comments

Comments
 (0)