Skip to content

Commit 2228734

Browse files
authored
Merge pull request #756 from SciML/spatial_internal_update
[WIP] LatticeReactionSystem internal update
2 parents a4c75fc + e19d67d commit 2228734

16 files changed

+2827
-1371
lines changed

src/Catalyst.jl

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ using LaTeXStrings, Latexify, Requires
99
using LinearAlgebra, Combinatorics
1010
using JumpProcesses: JumpProcesses, JumpProblem,
1111
MassActionJump, ConstantRateJump, VariableRateJump,
12-
SpatialMassActionJump
12+
SpatialMassActionJump, CartesianGrid, CartesianGridRej
1313

1414
# ModelingToolkit imports and convenience functions we use
1515
using ModelingToolkit
@@ -171,18 +171,25 @@ include("spatial_reaction_systems/spatial_reactions.jl")
171171
export TransportReaction, TransportReactions, @transport_reaction
172172
export isedgeparameter
173173

174-
# Lattice reaction systems
174+
# Lattice reaction systems.
175175
include("spatial_reaction_systems/lattice_reaction_systems.jl")
176176
export LatticeReactionSystem
177177
export spatial_species, vertex_parameters, edge_parameters
178-
179-
# Various utility functions
180-
include("spatial_reaction_systems/utility.jl")
178+
export CartesianGrid, CartesianGridReJ # (Implemented in JumpProcesses)
179+
export has_cartesian_lattice, has_masked_lattice, has_grid_lattice, has_graph_lattice,
180+
grid_dims, grid_size
181+
export make_edge_p_values, make_directed_edge_values
182+
include("spatial_reaction_systems/lattice_solution_interfacing.jl")
183+
export get_lrs_vals
181184

182185
# Specific spatial problem types.
183186
include("spatial_reaction_systems/spatial_ODE_systems.jl")
187+
export rebuild_lat_internals!
184188
include("spatial_reaction_systems/lattice_jump_systems.jl")
185189

190+
# General spatial modelling utility functions.
191+
include("spatial_reaction_systems/utility.jl")
192+
186193
### ReactionSystem Serialisation ###
187194
# Has to be at the end (because it uses records of all metadata declared by Catalyst).
188195
include("reactionsystem_serialisation/serialisation_support.jl")

src/spatial_reaction_systems/lattice_jump_systems.jl

Lines changed: 67 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -7,123 +7,123 @@ function DiffEqBase.DiscreteProblem(lrs::LatticeReactionSystem, u0_in, tspan,
77
error("Currently lattice Jump simulations only supported when all spatial reactions are transport reactions.")
88
end
99

10-
# Converts potential symmaps to varmaps
11-
# Vertex and edge parameters may be given in a tuple, or in a common vector, making parameter case complicated.
10+
# Converts potential symmaps to varmaps.
1211
u0_in = symmap_to_varmap(lrs, u0_in)
13-
p_in = (p_in isa Tuple{<:Any, <:Any}) ?
14-
(symmap_to_varmap(lrs, p_in[1]), symmap_to_varmap(lrs, p_in[2])) :
15-
symmap_to_varmap(lrs, p_in)
12+
p_in = symmap_to_varmap(lrs, p_in)
1613

1714
# Converts u0 and p to their internal forms.
15+
# u0 is simply a vector with all the species' initial condition values across all vertices.
1816
# u0 is [spec 1 at vert 1, spec 2 at vert 1, ..., spec 1 at vert 2, ...].
19-
u0 = lattice_process_u0(u0_in, species(lrs), lrs.num_verts)
20-
# Both vert_ps and edge_ps becomes vectors of vectors. Each have 1 element for each parameter.
21-
# These elements are length 1 vectors (if the parameter is uniform),
22-
# or length num_verts/nE, with unique values for each vertex/edge (for vert_ps/edge_ps, respectively).
17+
u0 = lattice_process_u0(u0_in, species(lrs), lrs)
18+
# vert_ps and `edge_ps` are vector maps, taking each parameter's Symbolics representation to its value(s).
19+
# vert_ps values are vectors. Here, index (i) is a parameter's value in vertex i.
20+
# edge_ps values are sparse matrices. Here, index (i,j) is a parameter's value in the edge from vertex i to vertex j.
21+
# Uniform vertex/edge parameters store only a single value (a length 1 vector, or size 1x1 sparse matrix).
2322
vert_ps, edge_ps = lattice_process_p(p_in, vertex_parameters(lrs),
2423
edge_parameters(lrs), lrs)
2524

26-
# Returns a DiscreteProblem.
27-
# Previously, a Tuple was used for (vert_ps, edge_ps), but this was converted to a Vector internally.
28-
return DiscreteProblem(u0, tspan, [vert_ps, edge_ps], args...; kwargs...)
25+
# Returns a DiscreteProblem (which basically just stores the processed input).
26+
return DiscreteProblem(u0, tspan, [vert_ps; edge_ps], args...; kwargs...)
2927
end
3028

31-
# Builds a spatial JumpProblem from a DiscreteProblem containing a Lattice Reaction System.
32-
function JumpProcesses.JumpProblem(lrs::LatticeReactionSystem, dprob, aggregator,
33-
args...; name = nameof(lrs.rs),
34-
combinatoric_ratelaws = get_combinatoric_ratelaws(lrs.rs), kwargs...)
29+
# Builds a spatial JumpProblem from a DiscreteProblem containing a `LatticeReactionSystem`.
30+
function JumpProcesses.JumpProblem(lrs::LatticeReactionSystem, dprob, aggregator, args...;
31+
combinatoric_ratelaws = get_combinatoric_ratelaws(reactionsystem(lrs)),
32+
name = nameof(reactionsystem(lrs)), kwargs...)
3533
# Error checks.
3634
if !isnothing(dprob.f.sys)
37-
error("Unexpected `DiscreteProblem` passed into `JumpProblem`. Was a `LatticeReactionSystem` used as input to the initial `DiscreteProblem`?")
35+
throw(ArgumentError("Unexpected `DiscreteProblem` passed into `JumpProblem`. Was a `LatticeReactionSystem` used as input to the initial `DiscreteProblem`?"))
3836
end
3937

4038
# Computes hopping constants and mass action jumps (requires some internal juggling).
41-
# Currently, JumpProcesses requires uniform vertex parameters (hence `p=first.(dprob.p[1])`).
4239
# Currently, the resulting JumpProblem does not depend on parameters (no way to incorporate these).
43-
# Hence the parameters of this one does nto actually matter. If at some point JumpProcess can
40+
# Hence the parameters of this one do not actually matter. If at some point JumpProcess can
4441
# handle parameters this can be updated and improved.
4542
# The non-spatial DiscreteProblem have a u0 matrix with entries for all combinations of species and vertexes.
4643
hopping_constants = make_hopping_constants(dprob, lrs)
4744
sma_jumps = make_spatial_majumps(dprob, lrs)
48-
non_spat_dprob = DiscreteProblem(
49-
reshape(dprob.u0, lrs.num_species, lrs.num_verts), dprob.tspan, first.(dprob.p[1]))
45+
non_spat_dprob = DiscreteProblem(reshape(dprob.u0, num_species(lrs), num_verts(lrs)),
46+
dprob.tspan, first.(dprob.p[1]))
5047

48+
# Creates and returns a spatial JumpProblem (masked lattices are not supported by these).
49+
spatial_system = has_masked_lattice(lrs) ? get_lattice_graph(lrs) : lattice(lrs)
5150
return JumpProblem(non_spat_dprob, aggregator, sma_jumps;
52-
hopping_constants, spatial_system = lrs.lattice, name, kwargs...)
51+
hopping_constants, spatial_system, name, kwargs...)
5352
end
5453

5554
# Creates the hopping constants from a discrete problem and a lattice reaction system.
5655
function make_hopping_constants(dprob::DiscreteProblem, lrs::LatticeReactionSystem)
5756
# Creates the all_diff_rates vector, containing for each species, its transport rate across all edges.
58-
# If transport rate is uniform for one species, the vector have a single element, else one for each edge.
59-
spatial_rates_dict = Dict(compute_all_transport_rates(dprob.p[1], dprob.p[2], lrs))
57+
# If the transport rate is uniform for one species, the vector has a single element, else one for each edge.
58+
spatial_rates_dict = Dict(compute_all_transport_rates(Dict(dprob.p), lrs))
6059
all_diff_rates = [haskey(spatial_rates_dict, s) ? spatial_rates_dict[s] : [0.0]
6160
for s in species(lrs)]
6261

63-
# Creates the hopping constant Matrix. It contains one element for each combination of species and vertex.
64-
# Each element is a Vector, containing the outgoing hopping rates for that species, from that vertex, on that edge.
65-
hopping_constants = [Vector{Float64}(undef, length(lrs.lattice.fadjlist[j]))
66-
for i in 1:(lrs.num_species), j in 1:(lrs.num_verts)]
67-
68-
# For each edge, finds each position in `hopping_constants`.
69-
for (e_idx, e) in enumerate(edges(lrs.lattice))
70-
dst_idx = findfirst(isequal(e.dst), lrs.lattice.fadjlist[e.src])
71-
# For each species, sets that hopping rate.
72-
for s_idx in 1:(lrs.num_species)
73-
hopping_constants[s_idx, e.src][dst_idx] = get_component_value(
74-
all_diff_rates[s_idx], e_idx)
75-
end
62+
# Creates an array (of the same size as the hopping constant array) containing all edges.
63+
# First the array is a NxM matrix (number of species x number of vertices). Each element is a
64+
# vector containing all edges leading out from that vertex (sorted by destination index).
65+
edge_array = [Pair{Int64, Int64}[] for _1 in 1:num_species(lrs), _2 in 1:num_verts(lrs)]
66+
for e in edge_iterator(lrs), s_idx in 1:num_species(lrs)
67+
push!(edge_array[s_idx, e[1]], e)
7668
end
69+
foreach(e_vec -> sort!(e_vec; by = e -> e[2]), edge_array)
7770

71+
# Creates the hopping constants array. It has the same shape as the edge array, but each
72+
# element is that species transportation rate along that edge
73+
hopping_constants = [[Catalyst.get_edge_value(all_diff_rates[s_idx], e)
74+
for e in edge_array[s_idx, src_idx]]
75+
for s_idx in 1:num_species(lrs), src_idx in 1:num_verts(lrs)]
7876
return hopping_constants
7977
end
8078

8179
# Creates a SpatialMassActionJump struct from a (spatial) DiscreteProblem and a LatticeReactionSystem.
82-
# Could implementation a version which, if all reaction's rates are uniform, returns a MassActionJump.
83-
# Not sure if there is any form of performance improvement from that though. Possibly is not the case.
80+
# Could implement a version which, if all reactions' rates are uniform, returns a MassActionJump.
81+
# Not sure if there is any form of performance improvement from that though. Likely not the case.
8482
function make_spatial_majumps(dprob, lrs::LatticeReactionSystem)
8583
# Creates a vector, storing which reactions have spatial components.
86-
is_spatials = [Catalyst.has_spatial_vertex_component(rx.rate, lrs;
87-
vert_ps = dprob.p[1]) for rx in reactions(lrs.rs)]
84+
is_spatials = [has_spatial_vertex_component(rx.rate, dprob.p)
85+
for rx in reactions(reactionsystem(lrs))]
8886

8987
# Creates templates for the rates (uniform and spatial) and the stoichiometries.
9088
# We cannot fetch reactant_stoich and net_stoich from a (non-spatial) MassActionJump.
9189
# The reason is that we need to re-order the reactions so that uniform appears first, and spatial next.
92-
u_rates = Vector{Float64}(undef, length(reactions(lrs.rs)) - count(is_spatials))
93-
s_rates = Matrix{Float64}(undef, count(is_spatials), lrs.num_verts)
94-
reactant_stoich = Vector{Vector{Pair{Int64, Int64}}}(undef, length(reactions(lrs.rs)))
95-
net_stoich = Vector{Vector{Pair{Int64, Int64}}}(undef, length(reactions(lrs.rs)))
90+
num_rxs = length(reactions(reactionsystem(lrs)))
91+
u_rates = Vector{Float64}(undef, num_rxs - count(is_spatials))
92+
s_rates = Matrix{Float64}(undef, count(is_spatials), num_verts(lrs))
93+
reactant_stoich = Vector{Vector{Pair{Int64, Int64}}}(undef, num_rxs)
94+
net_stoich = Vector{Vector{Pair{Int64, Int64}}}(undef, num_rxs)
9695

9796
# Loops through reactions with non-spatial rates, computes their rates and stoichiometries.
9897
cur_rx = 1
99-
for (is_spat, rx) in zip(is_spatials, reactions(lrs.rs))
98+
for (is_spat, rx) in zip(is_spatials, reactions(reactionsystem(lrs)))
10099
is_spat && continue
101-
u_rates[cur_rx] = compute_vertex_value(rx.rate, lrs; vert_ps = dprob.p[1])[1]
100+
u_rates[cur_rx] = compute_vertex_value(rx.rate, lrs; ps = dprob.p)[1]
102101
substoich_map = Pair.(rx.substrates, rx.substoich)
103-
reactant_stoich[cur_rx] = int_map(substoich_map, lrs.rs)
104-
net_stoich[cur_rx] = int_map(rx.netstoich, lrs.rs)
102+
reactant_stoich[cur_rx] = int_map(substoich_map, reactionsystem(lrs))
103+
net_stoich[cur_rx] = int_map(rx.netstoich, reactionsystem(lrs))
105104
cur_rx += 1
106105
end
107106
# Loops through reactions with spatial rates, computes their rates and stoichiometries.
108-
for (is_spat, rx) in zip(is_spatials, reactions(lrs.rs))
107+
for (is_spat, rx) in zip(is_spatials, reactions(reactionsystem(lrs)))
109108
is_spat || continue
110-
s_rates[cur_rx - length(u_rates), :] = compute_vertex_value(rx.rate, lrs;
111-
vert_ps = dprob.p[1])
109+
s_rates[cur_rx - length(u_rates), :] .= compute_vertex_value(rx.rate, lrs;
110+
ps = dprob.p)
112111
substoich_map = Pair.(rx.substrates, rx.substoich)
113-
reactant_stoich[cur_rx] = int_map(substoich_map, lrs.rs)
114-
net_stoich[cur_rx] = int_map(rx.netstoich, lrs.rs)
112+
reactant_stoich[cur_rx] = int_map(substoich_map, reactionsystem(lrs))
113+
net_stoich[cur_rx] = int_map(rx.netstoich, reactionsystem(lrs))
115114
cur_rx += 1
116115
end
117116
# SpatialMassActionJump expects empty rate containers to be nothing.
118117
isempty(u_rates) && (u_rates = nothing)
119118
(count(is_spatials) == 0) && (s_rates = nothing)
120119

121-
return SpatialMassActionJump(u_rates, s_rates, reactant_stoich, net_stoich)
120+
return SpatialMassActionJump(u_rates, s_rates, reactant_stoich, net_stoich, nothing)
122121
end
123122

124123
### Extra ###
125124

126-
# Temporary. Awaiting implementation in SII, or proper implementation withinCatalyst (with more general functionality).
125+
# Temporary. Awaiting implementation in SII, or proper implementation within Catalyst (with
126+
# more general functionality).
127127
function int_map(map_in, sys)
128128
return [ModelingToolkit.variable_index(sys, pair[1]) => pair[2] for pair in map_in]
129129
end
@@ -133,7 +133,7 @@ end
133133
# function make_majumps(non_spat_dprob, rs::ReactionSystem)
134134
# # Computes various required inputs for assembling the mass action jumps.
135135
# js = convert(JumpSystem, rs)
136-
# statetoid = Dict(ModelingToolkit.value(state) => i for (i, state) in enumerate(states(rs)))
136+
# statetoid = Dict(ModelingToolkit.value(state) => i for (i, state) in enumerate(unknowns(rs)))
137137
# eqs = equations(js)
138138
# invttype = non_spat_dprob.tspan[1] === nothing ? Float64 : typeof(1 / non_spat_dprob.tspan[2])
139139
#
@@ -142,3 +142,13 @@ end
142142
# majpmapper = ModelingToolkit.JumpSysMajParamMapper(js, p; jseqs = eqs, rateconsttype = invttype)
143143
# return ModelingToolkit.assemble_maj(eqs.x[1], statetoid, majpmapper)
144144
# end
145+
146+
### Problem & Integrator Rebuilding ###
147+
148+
# Currently not implemented.
149+
function rebuild_lat_internals!(dprob::DiscreteProblem)
150+
error("Modification and/or rebuilding of `DiscreteProblem`s is currently not supported. Please create a new problem instead.")
151+
end
152+
function rebuild_lat_internals!(jprob::JumpProblem)
153+
error("Modification and/or rebuilding of `JumpProblem`s is currently not supported. Please create a new problem instead.")
154+
end

0 commit comments

Comments
 (0)