Skip to content

Commit 37a14d0

Browse files
committed
add solution interfacing
1 parent c70e716 commit 37a14d0

File tree

7 files changed

+402
-9
lines changed

7 files changed

+402
-9
lines changed

src/Catalyst.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,14 +171,16 @@ include("spatial_reaction_systems/spatial_reactions.jl")
171171
export TransportReaction, TransportReactions, @transport_reaction
172172
export isedgeparameter
173173

174-
# Lattice reaction systems
174+
# Lattice reaction systems.
175175
include("spatial_reaction_systems/lattice_reaction_systems.jl")
176176
export LatticeReactionSystem
177177
export spatial_species, vertex_parameters, edge_parameters
178178
export CartesianGrid, CartesianGridReJ # (Implemented in JumpProcesses)
179179
export has_cartesian_lattice, has_masked_lattice, has_grid_lattice, has_graph_lattice,
180180
grid_dims, grid_size
181181
export make_edge_p_values, make_directed_edge_values
182+
include("spatial_reaction_systems/lattice_solution_interfacing.jl")
183+
export get_lrs_vals
182184

183185
# Specific spatial problem types.
184186
include("spatial_reaction_systems/spatial_ODE_systems.jl")

src/spatial_reaction_systems/lattice_jump_systems.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,8 @@ end
122122

123123
### Extra ###
124124

125-
# Temporary. Awaiting implementation in SII, or proper implementation withinCatalyst (with more general functionality).
125+
# Temporary. Awaiting implementation in SII, or proper implementation within Catalyst (with
126+
# more general functionality).
126127
function int_map(map_in, sys)
127128
return [ModelingToolkit.variable_index(sys, pair[1]) => pair[2] for pair in map_in]
128129
end
@@ -141,3 +142,14 @@ end
141142
# majpmapper = ModelingToolkit.JumpSysMajParamMapper(js, p; jseqs = eqs, rateconsttype = invttype)
142143
# return ModelingToolkit.assemble_maj(eqs.x[1], statetoid, majpmapper)
143144
# end
145+
146+
147+
### Problem & Integrator Rebuilding ###
148+
149+
# Currently not implemented.
150+
function rebuild_lat_internals!(dprob::DiscreteProblem)
151+
error("Modification and/or rebuilding of `DiscreteProblem`s is currently not supported. Please create a new problem instead.")
152+
end
153+
function rebuild_lat_internals!(jprob::JumpProblem)
154+
error("Modification and/or rebuilding of `JumpProblem`s is currently not supported. Please create a new problem instead.")
155+
end
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
### Rudimentary Interfacing Function ###
2+
# A single function, `get_lrs_vals`, which contain all interfacing functionality. However,
3+
# long-term it should be replaced with a sleeker interface. Ideally as MTK-wider support for
4+
# lattice problems and solutions are introduced.
5+
6+
"""
7+
get_lrs_vals(sol, sp, lrs::LatticeReactionSystem; t = nothing)
8+
9+
A function for retrieving the solution of a `LatticeReactionSystem`-based simulation on various
10+
desired forms. Generally, for `LatticeReactionSystem`s, the values in `sol` is ordered in a
11+
way which is not directly interpretable by the user. Furthermore, the normal Catalyst interface
12+
for solutions (e.g. `sol[:X]`) does not work for these solutions. Hence this function is used instead.
13+
14+
The output is a vector, which in each position contain sp's value (either at a time step of time,
15+
depending on the input `t`). Its shape depends on the lattice (using a similar form as heterogeneous
16+
initial conditions). I.e. for a NxM cartesian grid, the values are NxM matrices. For a masked grid,
17+
the values are sparse matrices. For a graph lattice, the values are vectors (where the value in
18+
the n'th position corresponds to sp's value in the n'th vertex).
19+
20+
Arguments:
21+
- `sol`: The solution from which we wish to retrieve some values.
22+
- `sp`: The species which values we wish to retrieve. Can be either a symbol (e.g. `:X`) or a symbolic
23+
variable (e.g. `X`).
24+
- `lrs`: The `LatticeReactionSystem` which was simulated to generate the solution.
25+
- `t = nothing`: If `nothing`, we simply returns the solution across all saved timesteps. If `t`
26+
instead is a vector (or range of values), returns the solutions interpolated at these timepoints.
27+
28+
Notes:
29+
- The `get_lrs_vals` is not optimised for performance. However, it should still be quite performant,
30+
but there might be some limitations if called a very large number of times.
31+
- Long-term it is likely that this function gets replaced with a sleeker interface.
32+
33+
Example:
34+
```julia
35+
using Catalyst, OrdinaryDiffEq
36+
37+
# Prepare `LatticeReactionSystem`s.
38+
rs = @reaction_network begin
39+
(k1,k2), X1 <--> X2
40+
end
41+
tr = @transport_reaction D X1
42+
lrs = LatticeReactionSystem(rs, [tr], CartesianGrid((2,2)))
43+
44+
# Create problems.
45+
u0 = [:X1 => 1, :X2 => 2]
46+
tspan = (0.0, 10.0)
47+
ps = [:k1 => 1, :k2 => 2.0, :D => 0.1]
48+
49+
oprob = ODEProblem(lrs1, u0, tspan, ps)
50+
osol = solve(oprob1, Tsit5())
51+
get_lrs_vals(osol, :X1, lrs) # Returns the value of X1 at each timestep.
52+
get_lrs_vals(osol, :X1, lrs; t = 0.0:10.0) # Returns the value of X1 at times 0.0, 1.0, ..., 10.0
53+
```
54+
"""
55+
function get_lrs_vals(sol, sp, lrs::LatticeReactionSystem; t = nothing)
56+
# Figures out which species we wish to fetch information about.
57+
(sp isa Symbol) && (sp = Catalyst._symbol_to_var(lrs, sp))
58+
sp_idx = findfirst(isequal(sp), species(lrs))
59+
sp_tot = length(species(lrs))
60+
61+
# Extracts the lattice and calls the next function. Masked grids (Array of Bools) are converted
62+
# to sparse array using the same template size as we wish to shape the data to.
63+
lattice = Catalyst.lattice(lrs)
64+
if has_masked_lattice(lrs)
65+
if grid_dims(lrs) == 3
66+
error("The `get_lrs_vals` function is not defined for systems based on 3d sparse arrays. Please raise an issue at the Catalyst GitHub site if this is something which would be useful to you.")
67+
end
68+
lattice = sparse(lattice)
69+
end
70+
get_lrs_vals(sol, lattice, t, sp_idx, sp_tot)
71+
end
72+
73+
# Function which handles the input in the case where `t` is `nothing` (i.e. return `sp`s value
74+
# across all sample points).
75+
function get_lrs_vals(sol, lattice, t::Nothing, sp_idx, sp_tot)
76+
# ODE simulations contain, in each data point, all values in a single vector. Jump simulations
77+
# instead in a matrix (NxM, where N is the number of species and M the number of vertices). We
78+
# must consider each case separately.
79+
if sol.prob isa ODEProblem
80+
return [reshape_vals(vals[sp_idx:sp_tot:end], lattice) for vals in sol.u]
81+
elseif sol.prob isa DiscreteProblem
82+
return [reshape_vals(vals[sp_idx,:], lattice) for vals in sol.u]
83+
else
84+
error("Unknown type of solution provided to `get_lrs_vals`. Only ODE or Jump solutions are supported.")
85+
end
86+
end
87+
88+
# Function which handles the input in the case where `t` is a range of values (i.e. return `sp`s
89+
# value at all designated time points.
90+
function get_lrs_vals(sol, lattice, t::AbstractVector{T}, sp_idx, sp_tot) where {T <: Number}
91+
if (minimum(t) < sol.t[1]) || (maximum(t) > sol.t[end])
92+
error("The range of the t values provided for sampling, ($(minimum(t)),$(maximum(t))) is not fully within the range of the simulation time span ($(sol.t[1]),$(sol.t[end])).")
93+
end
94+
95+
# ODE simulations contain, in each data point, all values in a single vector. Jump simulations
96+
# instead in a matrix (NxM, where N is the number of species and M the number of vertices). We
97+
# must consider each case separately.
98+
if sol.prob isa ODEProblem
99+
return [reshape_vals(sol(ti)[sp_idx:sp_tot:end], lattice) for ti in t]
100+
elseif sol.prob isa DiscreteProblem
101+
return [reshape_vals(sol(ti)[sp_idx,:], lattice) for ti in t]
102+
else
103+
error("Unknown type of solution provided to `get_lrs_vals`. Only ODE or Jump solutions are supported.")
104+
end
105+
end
106+
107+
# Functions which in each sample point reshapes the vector of values to the correct form (depending
108+
# on the type of lattice used).
109+
function reshape_vals(vals, lattice::CartesianGridRej{N, T}) where {N,T}
110+
return reshape(vals, lattice.dims...)
111+
end
112+
function reshape_vals(vals, lattice::AbstractSparseArray{Bool, Int64, 1})
113+
return SparseVector(lattice.n, lattice.nzind, vals)
114+
end
115+
function reshape_vals(vals, lattice::AbstractSparseArray{Bool, Int64, 2})
116+
return SparseMatrixCSC(lattice.m, lattice.n, lattice.colptr, lattice.rowval, vals)
117+
end
118+
function reshape_vals(vals, lattice::DiGraph)
119+
return vals
120+
end
121+

src/spatial_reaction_systems/spatial_ODE_systems.jl

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -243,10 +243,10 @@ function build_odefunction(lrs::LatticeReactionSystem, vert_ps::Vector{Pair{R, V
243243
jac_transport, transport_rates)
244244
J = (jac ? f : nothing)
245245

246-
# Extracts the `Symbol` form for species and parameters. Creates and returns the `ODEFunction`.
247-
syms = MT.getname.(species(lrs))
248-
paramsyms = MT.getname.(parameters(lrs))
249-
return ODEFunction(f; jac = J, jac_prototype, syms, paramsyms)
246+
# Extracts the `Symbol` form for parameters (but not species). Creates and returns the `ODEFunction`.
247+
paramsyms = [MT.getname(p) for p in parameters(lrs)]
248+
sys = SciMLBase.SymbolCache([], paramsyms, [])
249+
return ODEFunction(f; jac = J, jac_prototype, sys)
250250
end
251251

252252
# Builds a jacobian prototype.
@@ -325,7 +325,48 @@ end
325325

326326
### Functor Updating Functionality ###
327327

328-
# Function for rebuilding a `LatticeReactionSystem` `ODEProblem` after it has been updated.
328+
"""
329+
rebuild_lat_internals!(sciml_struct)
330+
331+
Rebuilds the internal functions for simulating a LatticeReactionSystem. WHenever a problem or
332+
integrator have had its parameter values updated, thus function should be called for the update to
333+
be taken into account. For ODE simulations, `rebuild_lat_internals!` needs only to be called when
334+
- An edge parameter have been updated.
335+
- When a parameter with spatially homogeneous values have been given spatially heterogeneous values
336+
(or vice versa).
337+
338+
Arguments:
339+
- `sciml_struct`: The problem (e.g. an `ODEProblem`) or an integrator which we wish to rebuild.
340+
341+
Notes:
342+
- Currently does not work for `DiscreteProblem`s, `JumpProblem`s, or their integrators.
343+
- The function is not build with performance in mind, so avoid calling it multiple times in
344+
performance-critical applications.
345+
346+
Example:
347+
```julia
348+
# Creates an initial `ODEProblem`
349+
rs = @reaction_network begin
350+
(k1,k2), X1 <--> X2
351+
end
352+
tr = @transport_reaction D X1
353+
grid = CartesianGrid((2,2))
354+
lrs = LatticeReactionSystem(rs, [tr], grid)
355+
356+
u0 = [:X1 => 2, :X2 => [5 6; 7 8]]
357+
tspan = (0.0, 10.0)
358+
ps = [:k1 => 1.5, :k2 => [1.0 1.5; 2.0 3.5], :D => 0.1]
359+
360+
oprob = ODEProblem(lrs, u0, tspan, ps)
361+
362+
# Updates parameter values.
363+
oprob.ps[:ks] = [2.0 2.5; 3.0 4.5]
364+
oprob.ps[:D] = 0.05
365+
366+
# Rebuilds `ODEProblem` to make changes have an effect.
367+
rebuild_lat_internals!(oprob)
368+
```
369+
"""
329370
function rebuild_lat_internals!(oprob::ODEProblem)
330371
rebuild_lat_internals!(oprob.f.f, oprob.p, oprob.f.f.lrs)
331372
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,5 +72,6 @@ using SafeTestsets, Test
7272
@time @safetestset "Spatial Lattice Variants" begin include("spatial_modelling/lattice_reaction_systems_lattice_types.jl") end
7373
@time @safetestset "ODE Lattice Systems Simulations" begin include("spatial_modelling/lattice_reaction_systems_ODEs.jl") end
7474
@time @safetestset "Jump Lattice Systems Simulations" begin include("spatial_modelling/lattice_reaction_systems_jumps.jl") end
75+
@time @safetestset "Jump Solution Interfacing" begin include("spatial_modelling/lattice_solution_interfacing.jl") end
7576

7677
end # @time

test/spatial_modelling/lattice_reaction_systems_jumps.jl

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
### Preparations ###
22

33
# Fetch packages.
4-
using JumpProcesses
5-
using Random, Statistics, SparseArrays, Test
4+
using JumpProcesses, Statistics, SparseArrays, Test
65

76
# Fetch test networks.
87
include("../spatial_test_networks.jl")
@@ -204,6 +203,27 @@ end
204203

205204
### JumpProblem & Integrator Interfacing ###
206205

206+
# Currently not supported, check that corresponding functions yields errors.
207+
let
208+
# Prepare `LatticeReactionSystem`.
209+
rs = @reaction_network begin
210+
(k1,k2), X1 <--> X2
211+
end
212+
tr = @transport_reaction D X1
213+
grid = CartesianGrid((2,2))
214+
lrs = LatticeReactionSystem(rs, [tr], grid)
215+
216+
# Create problems.
217+
u0 = [:X1 => 2, :X2 => [5 6; 7 8]]
218+
tspan = (0.0, 10.0)
219+
ps = [:k1 => 1.5, :k2 => [1.0 1.5; 2.0 3.5], :D => 0.1]
220+
dprob = DiscreteProblem(lrs, u0, tspan, ps)
221+
jprob = JumpProblem(lrs, dprob, NSM())
222+
223+
# Checks that rebuilding errors.
224+
@test_throws Exception rebuild_lat_internals!(dprob)
225+
@test_throws Exception rebuild_lat_internals!(jprob)
226+
end
207227

208228
### Other Tests ###
209229

0 commit comments

Comments
 (0)