Skip to content

Commit 5b6d6ec

Browse files
committed
enable rebuildig oproblems/integrators
1 parent 4731155 commit 5b6d6ec

File tree

5 files changed

+260
-41
lines changed

5 files changed

+260
-41
lines changed

src/Catalyst.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ export make_edge_p_values, make_directed_edge_values
181181

182182
# Specific spatial problem types.
183183
include("spatial_reaction_systems/spatial_ODE_systems.jl")
184+
export rebuild_lat_internals!
184185
include("spatial_reaction_systems/lattice_jump_systems.jl")
185186

186187
# General spatial modelling utility functions.

src/spatial_reaction_systems/spatial_ODE_systems.jl

Lines changed: 132 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
### Spatial ODE Functor Structure ###
22

3-
# Functor with information about a spatial Lattice Reaction ODE;s forcing and Jacobian functions.
3+
# Functor with information about a spatial Lattice Reaction ODEs forcing and Jacobian functions.
44
# Also used as ODE Function input to corresponding `ODEProblem`.
55
struct LatticeTransportODEFunction{P,Q,R,S,T}
66
"""
@@ -59,40 +59,60 @@ struct LatticeTransportODEFunction{P,Q,R,S,T}
5959
used).
6060
"""
6161
jac_transport::T
62+
""" Whether sparse jacobian representation is used. """
63+
sparse::Bool
64+
"""Remove when we add this as problem metadata"""
65+
lrs::LatticeReactionSystem
6266

6367
function LatticeTransportODEFunction(ofunc::P, ps::Vector{<:Pair},
6468
lrs::LatticeReactionSystem, transport_rates::Vector{Pair{Int64, SparseMatrixCSC{S, Int64}}},
65-
jac_transport::Union{Nothing, Matrix{S}, SparseMatrixCSC{S, Int64}}) where {P,S}
66-
67-
# Creates a vector with the heterogeneous vertex parameters' indexes in the full parameter vector.
68-
p_dict = Dict(ps)
69-
heterogeneous_vert_p_idxs = findall((p_dict[p] isa Vector) && (length(p_dict[p]) > 1)
70-
for p in parameters(lrs))
71-
72-
# Creates the MTKParameters structure and `p_setters` vector (which are used to manage
73-
# the vertex parameter values during the simulations).
74-
nonspatial_osys = complete(convert(ODESystem, reactionsystem(lrs)))
75-
p_init = [p => p_dict[p][1] for p in parameters(nonspatial_osys)]
76-
mtk_ps = MT.MTKParameters(nonspatial_osys, p_init)
77-
p_setters = [MT.setp(nonspatial_osys, p) for p in parameters(lrs)[heterogeneous_vert_p_idxs]]
78-
79-
# Computes the transport rate type vector and leaving rate matrix.
80-
t_rate_idx_types = [size(tr[2]) == (1,1) for tr in transport_rates]
81-
leaving_rates = zeros(length(transport_rates), num_verts(lrs))
82-
for (s_idx, tr_pair) in enumerate(transport_rates)
83-
for e in Catalyst.edge_iterator(lrs)
84-
# Updates the exit rate for species s_idx from vertex e.src.
85-
leaving_rates[s_idx, e[1]] += get_transport_rate(tr_pair[2], e, t_rate_idx_types[s_idx])
86-
end
87-
end
69+
jac_transport::Union{Nothing, Matrix{S}, SparseMatrixCSC{S, Int64}}, sparse) where {P,S}
70+
# Computes `LatticeTransportODEFunction` functor fields.
71+
heterogeneous_vert_p_idxs = make_heterogeneous_vert_p_idxs(ps, lrs)
72+
mtk_ps, p_setters = make_mtk_ps_structs(ps, lrs, heterogeneous_vert_p_idxs)
73+
t_rate_idx_types, leaving_rates = make_t_types_and_leaving_rates(transport_rates, lrs)
8874

8975
# Creates and returns the `LatticeTransportODEFunction` functor.
9076
new{P,typeof(mtk_ps),typeof(p_setters),S,typeof(jac_transport)}(ofunc, num_verts(lrs),
9177
num_species(lrs), heterogeneous_vert_p_idxs, mtk_ps, p_setters, transport_rates,
92-
t_rate_idx_types, leaving_rates, Catalyst.edge_iterator(lrs), jac_transport)
78+
t_rate_idx_types, leaving_rates, Catalyst.edge_iterator(lrs), jac_transport, sparse, lrs)
79+
end
80+
end
81+
82+
# `LatticeTransportODEFunction` helper functions (re used by rebuild function later on).
83+
84+
# Creates a vector with the heterogeneous vertex parameters' indexes in the full parameter vector.
85+
function make_heterogeneous_vert_p_idxs(ps, lrs)
86+
p_dict = Dict(ps)
87+
return findall((p_dict[p] isa Vector) && (length(p_dict[p]) > 1) for p in parameters(lrs))
88+
end
89+
90+
# Creates the MTKParameters structure and `p_setters` vector (which are used to manage
91+
# the vertex parameter values during the simulations).
92+
function make_mtk_ps_structs(ps, lrs, heterogeneous_vert_p_idxs)
93+
p_dict = Dict(ps)
94+
nonspatial_osys = complete(convert(ODESystem, reactionsystem(lrs)))
95+
p_init = [p => p_dict[p][1] for p in parameters(nonspatial_osys)]
96+
mtk_ps = MT.MTKParameters(nonspatial_osys, p_init)
97+
p_setters = [MT.setp(nonspatial_osys, p) for p in parameters(lrs)[heterogeneous_vert_p_idxs]]
98+
return mtk_ps, p_setters
99+
end
100+
101+
# Computes the transport rate type vector and leaving rate matrix.
102+
function make_t_types_and_leaving_rates(transport_rates, lrs)
103+
t_rate_idx_types = [size(tr[2]) == (1,1) for tr in transport_rates]
104+
leaving_rates = zeros(length(transport_rates), num_verts(lrs))
105+
for (s_idx, tr_pair) in enumerate(transport_rates)
106+
for e in Catalyst.edge_iterator(lrs)
107+
# Updates the exit rate for species s_idx from vertex e.src.
108+
leaving_rates[s_idx, e[1]] += get_transport_rate(tr_pair[2], e, t_rate_idx_types[s_idx])
109+
end
93110
end
111+
return t_rate_idx_types, leaving_rates
94112
end
95113

114+
### Spatial ODE Functor Functions ###
115+
96116
# Defines the functor's effect when applied as a forcing function.
97117
function (lt_ofun::LatticeTransportODEFunction)(du::AbstractVector, u, p, t)
98118
# Updates for non-spatial reactions.
@@ -198,7 +218,7 @@ function build_odefunction(lrs::LatticeReactionSystem, vert_ps::Vector{Pair{R,Ve
198218
transport_rates = make_sidxs_to_transrate_map(vert_ps, edge_ps, lrs)
199219

200220
# Depending on Jacobian and sparsity options, computes the Jacobian transport matrix and prototype.
201-
if sparse && !jac
221+
if !sparse && !jac
202222
jac_transport = nothing
203223
jac_prototype = nothing
204224
else
@@ -209,7 +229,7 @@ function build_odefunction(lrs::LatticeReactionSystem, vert_ps::Vector{Pair{R,Ve
209229
end
210230

211231
# Creates the `LatticeTransportODEFunction` functor (if `jac`, sets it as the Jacobian as well).
212-
f = LatticeTransportODEFunction(ofunc_dense, [vert_ps; edge_ps], lrs, transport_rates, jac_transport)
232+
f = LatticeTransportODEFunction(ofunc_dense, [vert_ps; edge_ps], lrs, transport_rates, jac_transport, sparse)
213233
J = (jac ? f : nothing)
214234

215235
# Extracts the `Symbol` form for species and parameters. Creates and returns the `ODEFunction`.
@@ -267,23 +287,95 @@ function build_jac_prototype(ns_jac_prototype::SparseMatrixCSC{Float64, Int64},
267287
end
268288
end
269289

270-
# Create a sparse Jacobian prototype with 0-valued entries.
290+
# Create a sparse Jacobian prototype with 0-valued entries. If requested,
291+
# updates values with non-zero entries.
271292
jac_prototype = sparse(i_idxs, j_idxs, zeros(T, num_entries))
293+
set_nonzero && set_jac_transport_values!(jac_prototype, transport_rates, lrs)
272294

273-
# Set element values.
274-
if set_nonzero
275-
for (s, rates) in transport_rates, e in edge_iterator(lrs)
276-
idx_src = get_index(e[1], s, num_species(lrs))
277-
idx_dst = get_index(e[2], s, num_species(lrs))
278-
val = get_transport_rate(rates, e, size(rates)==(1,1))
295+
return jac_prototype
296+
end
279297

280-
# Term due to species leaving source vertex.
281-
jac_prototype[idx_src, idx_src] -= val
298+
# For a Jacobian prototype with zero-valued entries. Set entry values according to a set of
299+
# transport reaction values.
300+
function set_jac_transport_values!(jac_prototype, transport_rates, lrs)
301+
for (s, rates) in transport_rates, e in edge_iterator(lrs)
302+
idx_src = get_index(e[1], s, num_species(lrs))
303+
idx_dst = get_index(e[2], s, num_species(lrs))
304+
val = get_transport_rate(rates, e, size(rates)==(1,1))
282305

283-
# Term due to species arriving to destination vertex.
284-
jac_prototype[idx_src, idx_dst] += val
285-
end
306+
# Term due to species leaving source vertex.
307+
jac_prototype[idx_src, idx_src] -= val
308+
309+
# Term due to species arriving to destination vertex.
310+
jac_prototype[idx_src, idx_dst] += val
286311
end
312+
end
287313

288-
return jac_prototype
314+
### Functor Updating Functionality ###
315+
316+
# Function for rebuilding a `LatticeReactionSystem` `ODEProblem` after it has been updated.
317+
function rebuild_lat_internals!(oprob::ODEProblem)
318+
rebuild_lat_internals!(oprob.f.f, oprob.p, oprob.f.f.lrs)
319+
end
320+
321+
# Function for rebuilding a `LatticeReactionSystem` integrator after it has been updated.
322+
# We could specify `integrator`'s type, but that required adding OrdinaryDiffEq as a direct
323+
# dependency of Catalyst.
324+
function rebuild_lat_internals!(integrator)
325+
rebuild_lat_internals!(integrator.f.f, integrator.p, integrator.f.f.lrs)
289326
end
327+
328+
# Function which rebuilds a `LatticeTransportODEFunction` functor for a new parameter set.
329+
function rebuild_lat_internals!(lt_ofun::LatticeTransportODEFunction, ps_new, lrs::LatticeReactionSystem)
330+
# Computes Jacobian properties.
331+
jac = !isnothing(lt_ofun.jac_transport)
332+
sparse = lt_ofun.sparse
333+
334+
# Recreates the new parameters on the requisite form.
335+
ps_new = [(length(p) == 1) ? p[1] : p for p in deepcopy(ps_new)]
336+
ps_new = [p => p_val for (p, p_val) in zip(parameters(lrs), deepcopy(ps_new))]
337+
vert_ps, edge_ps = lattice_process_p(ps_new, vertex_parameters(lrs), edge_parameters(lrs), lrs)
338+
ps_new = [vert_ps; edge_ps]
339+
340+
# Creates the new transport rates and transport Jacobian part.
341+
transport_rates = make_sidxs_to_transrate_map(vert_ps, edge_ps, lrs)
342+
if !isnothing(lt_ofun.jac_transport)
343+
lt_ofun.jac_transport .= 0.0
344+
set_jac_transport_values!(lt_ofun.jac_transport, transport_rates, lrs)
345+
end
346+
347+
# Computes new field values.
348+
heterogeneous_vert_p_idxs = make_heterogeneous_vert_p_idxs(ps_new, lrs)
349+
mtk_ps, p_setters = make_mtk_ps_structs(ps_new, lrs, heterogeneous_vert_p_idxs)
350+
t_rate_idx_types, leaving_rates = make_t_types_and_leaving_rates(transport_rates, lrs)
351+
352+
# Updates functor fields.
353+
replace_vec!(lt_ofun.heterogeneous_vert_p_idxs, heterogeneous_vert_p_idxs)
354+
replace_vec!(lt_ofun.p_setters, p_setters)
355+
replace_vec!(lt_ofun.transport_rates, transport_rates)
356+
replace_vec!(lt_ofun.t_rate_idx_types, t_rate_idx_types)
357+
lt_ofun.leaving_rates .= leaving_rates
358+
359+
# Updating the `MTKParameters` structure is a bit more complicated.
360+
p_dict = Dict(ps_new)
361+
osys = complete(convert(ODESystem, reactionsystem(lrs)))
362+
for p in parameters(osys)
363+
MT.setp(osys, p)(lt_ofun.mtk_ps, (p_dict[p] isa Number) ? p_dict[p] : p_dict[p][1])
364+
end
365+
366+
return nothing
367+
end
368+
369+
# Specialised function which replaced one vector in another in a mutating way.
370+
# Required to update the vectors in the `LatticeTransportODEFunction` functor.
371+
function replace_vec!(vec1, vec2)
372+
l1 = length(vec1)
373+
l2 = length(vec2)
374+
375+
# Updates the fields, then deletes superfluous fields, or additional ones.
376+
for (i, v) in enumerate(vec2[1:min(l1, l2)])
377+
vec1[i] = v
378+
end
379+
foreach(idx -> deleteat!(vec1, idx), l1:-1:(l2 + 1))
380+
foreach(val -> push!(vec1, val), vec2[l1+1:l2])
381+
end

src/spatial_reaction_systems/spatial_reactions.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,11 @@ function make_transport_reaction(rateex, species)
6363
iv = :(@variables $(DEFAULT_IV_SYM))
6464
trxexpr = :(TransportReaction($rateex, $species))
6565

66+
# Appends `edgeparameter` metadata to all declared parameters.
67+
for idx = 4:2:(2 + 2*length(parameters))
68+
insert!(pexprs.args, idx, :([edgeparameter=true]))
69+
end
70+
6671
quote
6772
$pexprs
6873
$iv

test/spatial_modelling/lattice_reaction_systems_ODEs.jl

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,124 @@ let
547547
@test all(isequal.(ss_1, ss_2))
548548
end
549549

550+
### ODEProblem & Integrator Interfacing ###
551+
552+
# Checks that basic interfacing with ODEProblem parameters (getting and setting) works.
553+
let
554+
# Creates an initial `ODEProblem`.
555+
lrs = LatticeReactionSystem(brusselator_system, brusselator_srs_1, small_1d_cartesian_grid)
556+
u0 = [:X => 1.0, :Y => 2.0]
557+
ps = [:A => 1.0, :B => [1.0, 2.0, 3.0, 4.0, 5.0], :dX => 0.1]
558+
oprob = ODEProblem(lrs, u0, (0.0, 10.0), ps)
559+
560+
# Checks that retrieved parameters are correct.
561+
@test oprob.ps[:A] == [1.0]
562+
@test oprob.ps[:B] == [1.0, 2.0, 3.0, 4.0, 5.0]
563+
@test oprob.ps[:dX] == sparse([1], [1], [0.1])
564+
565+
# Updates content.
566+
oprob.ps[:A] = [10.0, 20.0, 30.0, 40.0, 50.0]
567+
oprob.ps[:B] = [10.0]
568+
oprob.ps[:dX] = [0.01]
569+
570+
# Checks that content is correct.
571+
@test oprob.ps[:A] == [10.0, 20.0, 30.0, 40.0, 50.0]
572+
@test oprob.ps[:B] == [10.0]
573+
@test oprob.ps[:dX] == [0.01]
574+
end
575+
576+
# Checks that the `rebuild_lat_internals!` function is correctly applied to an ODEProblem.
577+
let
578+
# Creates a brusselator `LatticeReactionSystem`.
579+
lrs = LatticeReactionSystem(brusselator_system, brusselator_srs_2, very_small_2d_cartesian_grid)
580+
581+
# Checks for all combinations of Jacobian and sparsity.
582+
for jac in [false, true], sparse in [false, true]
583+
# Creates an initial ODEProblem.
584+
u0 = [:X => 1.0, :Y => [1.0 2.0; 3.0 4.0]]
585+
dY_vals = spzeros(4,4)
586+
dY_vals[1,2] = 0.1; dY_vals[2,1] = 0.1;
587+
dY_vals[1,3] = 0.2; dY_vals[3,1] = 0.2;
588+
dY_vals[2,4] = 0.3; dY_vals[4,2] = 0.3;
589+
dY_vals[3,4] = 0.4; dY_vals[4,3] = 0.4;
590+
ps = [:A => 1.0, :B => [4.0 5.0; 6.0 7.0], :dX => 0.1, :dY => dY_vals]
591+
oprob_1 = ODEProblem(lrs, u0, (0.0, 10.0), ps; jac, sparse)
592+
593+
# Creates an alternative version of the ODEProblem.
594+
dX_vals = spzeros(4,4)
595+
dX_vals[1,2] = 0.01; dX_vals[2,1] = 0.01;
596+
dX_vals[1,3] = 0.02; dX_vals[3,1] = 0.02;
597+
dX_vals[2,4] = 0.03; dX_vals[4,2] = 0.03;
598+
dX_vals[3,4] = 0.04; dX_vals[4,3] = 0.04;
599+
ps = [:A => [1.1 1.2; 1.3 1.4], :B => 5.0, :dX => dX_vals, :dY => 0.01]
600+
oprob_2 = ODEProblem(lrs, u0, (0.0, 10.0), ps; jac, sparse)
601+
602+
# Modifies the initial ODEProblem to be identical to the new one.
603+
oprob_1.ps[:A] = [1.1 1.2; 1.3 1.4]
604+
oprob_1.ps[:B] = [5.0]
605+
oprob_1.ps[:dX] = dX_vals
606+
oprob_1.ps[:dY] = [0.01]
607+
rebuild_lat_internals!(oprob_1)
608+
609+
# Checks that simulations of the two `ODEProblem`s are identical.
610+
@test solve(oprob_1, Rodas5P()) solve(oprob_2, Rodas5P())
611+
end
612+
end
613+
614+
# Checks that the `rebuild_lat_internals!` function is correctly applied to an integrator.
615+
# Does through by applying it within a callback, and compare to simulations without callback.
616+
let
617+
# Prepares problem inputs.
618+
lrs = LatticeReactionSystem(brusselator_system, brusselator_srs_2, very_small_2d_cartesian_grid)
619+
u0 = [:X => 1.0, :Y => [1.0 2.0; 3.0 4.0]]
620+
A1 = 1.0
621+
B1 = [4.0 5.0; 6.0 7.0]
622+
A2 = [1.1 1.2; 1.3 1.4]
623+
B2 = 5.0
624+
dY_vals = spzeros(4,4)
625+
dY_vals[1,2] = 0.1; dY_vals[2,1] = 0.1;
626+
dY_vals[1,3] = 0.2; dY_vals[3,1] = 0.2;
627+
dY_vals[2,4] = 0.3; dY_vals[4,2] = 0.3;
628+
dY_vals[3,4] = 0.4; dY_vals[4,3] = 0.4;
629+
dX_vals = spzeros(4,4)
630+
dX_vals[1,2] = 0.01; dX_vals[2,1] = 0.01;
631+
dX_vals[1,3] = 0.02; dX_vals[3,1] = 0.02;
632+
dX_vals[2,4] = 0.03; dX_vals[4,2] = 0.03;
633+
dX_vals[3,4] = 0.04; dX_vals[4,3] = 0.04;
634+
dX1 = 0.1
635+
dY1 = dY_vals
636+
dX2 = dX_vals
637+
dY2 = 0.01
638+
ps_1 = [:A => A1, :B => B1, :dX => dX1, :dY => dY1]
639+
ps_2 = [:A => A2, :B => B2, :dX => dX2, :dY => dY2]
640+
641+
# Checks for all combinations of Jacobian and sparsity.
642+
for jac in [false, true], sparse in [false, true]
643+
# Creates simulation through two different separate simulations.
644+
oprob_1_1 = ODEProblem(lrs, u0, (0.0, 5.0), ps_1; jac, sparse)
645+
sol_1_1 = solve(oprob_1_1, Rosenbrock23(); saveat = 1.0, abstol = 1e-8, reltol = 1e-8)
646+
u0_1_2 = [:X => sol_1_1.u[end][1:2:end], :Y => sol_1_1.u[end][2:2:end]]
647+
oprob_1_2 = ODEProblem(lrs, u0_1_2, (0.0, 5.0), ps_2; jac, sparse)
648+
sol_1_2 = solve(oprob_1_2, Rosenbrock23(); saveat = 1.0, abstol = 1e-8, reltol = 1e-8)
649+
650+
# Creates simulation through a single simulation with a callback
651+
oprob_2 = ODEProblem(lrs, u0, (0.0, 10.0), ps_1; jac, sparse)
652+
condition(u, t, integrator) = (t == 5.0)
653+
function affect!(integrator)
654+
integrator.ps[:A] = A2
655+
integrator.ps[:B] = [B2]
656+
integrator.ps[:dX] = dX2
657+
integrator.ps[:dY] = [dY2]
658+
rebuild_lat_internals!(integrator)
659+
end
660+
callback = DiscreteCallback(condition, affect!)
661+
sol_2 = solve(oprob_2, Rosenbrock23(); saveat = 1.0, tstops = [5.0], callback, abstol = 1e-8, reltol = 1e-8)
662+
663+
# Check that trajectories are equivalent.
664+
@test [sol_1_1.u; sol_1_2.u] sol_2.u
665+
end
666+
end
667+
550668
### Tests Special Cases ###
551669

552670
# Create network using either graphs or di-graphs.

test/spatial_modelling/lattice_reaction_systems_jumps.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,4 +200,7 @@ let
200200
@test abs(d) < reltol * non_spatial_mean[i]
201201
end
202202
end
203-
end
203+
end
204+
205+
206+
### JumpProblem & Integrator Interfacing ###

0 commit comments

Comments
 (0)