Skip to content

Commit fc47179

Browse files
committed
update
1 parent 2e3e2d4 commit fc47179

File tree

3 files changed

+82
-48
lines changed

3 files changed

+82
-48
lines changed

src/spatial_reaction_systems/spatial_ODE_systems.jl

Lines changed: 72 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,17 @@ struct LatticeTransportODEf{R,S,T}
1313
"""Temporary vector. For parameters which values are identical across the lattice, at some point these have to be converted of a length num_verts vector. To avoid re-allocation they are written to this vector."""
1414
work_vert_ps::Vector{S}
1515
"""For each parameter in vert_ps, its value is a vector with length either num_verts or 1. To know whenever a parameter's value need expanding to the work_vert_ps array, its length needs checking. This check is done once, and the value stored to this array. This field (specifically) is an enumerate over that array."""
16-
enum_v_ps_idx_types::Base.Iterators.Enumerate{Vector{Bool}}
16+
v_ps_idx_types::Vector{Bool}
1717
"""A vector of pairs, with a value for each species with transportation. The first value is the species index (in the species(::ReactionSystem) vector), and the second is a vector with its transport rate values. If the transport rate is uniform (across all edges), that value is the only value in the vector. Else, there is one value for each edge in the lattice."""
1818
transport_rates::Vector{Pair{Int64, Vector{S}}}
1919
"""A matrix, NxM, where N is the number of species with transportation and M the number of vertexes. Each value is the total rate at which that species leaves that vertex (e.g. for a species with constant diffusion rate D, in a vertex with n neighbours, this value is n*D)."""
2020
leaving_rates::Matrix{S}
2121
"""An (enumerate'ed) iterator over all the edges of the lattice."""
22-
enum_edges::T
22+
edges::Graphs.SimpleGraphs.SimpleEdgeIter{SimpleDiGraph{Int64}}
23+
"""The edge parameters used to create the spatial ODEProblem. Currently unused, but will be needed to support changing these (e.g. due to events). Contain one vector for each edge parameter (length one if uniform, else one value for each edge)."""
24+
edge_ps::Vector{Vector{T}}
2325

24-
function LatticeTransportODEf(ofunc::R, vert_ps::Vector{Vector{S}}, transport_rates::Vector{Pair{Int64, Vector{S}}}, lrs::LatticeReactionSystem) where {R,S}
26+
function LatticeTransportODEf(ofunc::R, vert_ps::Vector{Vector{S}}, transport_rates::Vector{Pair{Int64, Vector{S}}}, edge_ps::Vector{Vector{T}}, lrs::LatticeReactionSystem) where {R,S,T}
2527
leaving_rates = zeros(length(transport_rates), lrs.num_verts)
2628
for (s_idx, trpair) in enumerate(transport_rates)
2729
rates = last(trpair)
@@ -31,37 +33,38 @@ struct LatticeTransportODEf{R,S,T}
3133
end
3234
end
3335
work_vert_ps = zeros(lrs.num_verts)
34-
# 1 if ps are constant across the graph, 0 else
35-
enum_v_ps_idx_types = enumerate(map(vp -> length(vp) == 1, vert_ps))
36-
enum_edges = deepcopy(enumerate(edges(lrs.lattice))) # Creates an iterator over all the edges. Again, this is always used in the enumerated form.
37-
new{R,S,typeof(enum_edges)}(ofunc, lrs.num_verts, lrs.num_species, vert_ps, work_vert_ps, enum_v_ps_idx_types, transport_rates, leaving_rates, enum_edges)
36+
# 1 if ps are constant across the graph, 0 else.
37+
v_ps_idx_types = map(vp -> length(vp) == 1, vert_ps)
38+
eds = edges(lrs.lattice)
39+
new{R,S,T}(ofunc, lrs.num_verts, lrs.num_species, vert_ps, work_vert_ps, v_ps_idx_types, transport_rates, leaving_rates, eds, edge_ps)
3840
end
3941
end
4042

4143
# Functor structure containing the information for the forcing function of a spatial ODE with spatial movement on a lattice.
42-
struct LatticeTransportODEjac{R,S,T}
44+
struct LatticeTransportODEjac{Q,R,S,T}
4345
"""The ODEFunction of the (non-spatial) reaction system which generated this function."""
44-
ofunc::R
46+
ofunc::Q
4547
"""The number of vertices."""
4648
num_verts::Int64
4749
"""The number of species."""
4850
num_species::Int64
4951
"""The values of the parameters which values are tied to vertexes."""
50-
vert_ps::Vector{Vector{S}}
52+
vert_ps::Vector{Vector{R}}
5153
"""Temporary vector. For parameters which values are identical across the lattice, at some point these have to be converted of a length(num_verts) vector. To avoid re-allocation they are written to this vector."""
52-
work_vert_ps::Vector{S}
54+
work_vert_ps::Vector{R}
5355
"""For each parameter in vert_ps, it either have length num_verts or 1. To know whenever a parameter's value need expanding to the work_vert_ps array, its length needs checking. This check is done once, and the value stored to this array. This field (specifically) is an enumerate over that array."""
54-
enum_v_ps_idx_types::Base.Iterators.Enumerate{Vector{Bool}}
56+
v_ps_idx_types::Vector{Bool}
5557
"""Whether the Jacobian is sparse or not."""
5658
sparse::Bool
5759
"""The transport rates. Can be a dense matrix (for non-sparse) or as the "nzval" field if sparse."""
58-
jac_values::T
60+
jac_transport::S
61+
"""The edge parameters used to create the spatial ODEProblem. Currently unused, but will be needed to support changing these (e.g. due to events). Contain one vector for each edge parameter (length one if uniform, else one value for each edge)."""
62+
edge_ps::Vector{Vector{T}}
5963

60-
function LatticeTransportODEjac(ofunc::R, vert_ps::Vector{Vector{S}}, lrs::LatticeReactionSystem, jac_prototype::Union{Nothing, SparseMatrixCSC{Float64, Int64}}, sparse::Bool) where {R,S}
64+
function LatticeTransportODEjac(ofunc::R, vert_ps::Vector{Vector{S}}, lrs::LatticeReactionSystem, jac_transport::Union{Nothing, SparseMatrixCSC{Float64, Int64}}, edge_ps::Vector{Vector{T}}, sparse::Bool) where {R,S,T}
6165
work_vert_ps = zeros(lrs.num_verts)
62-
enum_v_ps_idx_types = enumerate(map(vp -> length(vp) == 1, vert_ps))
63-
jac_values = sparse ? jac_prototype.nzval : Matrix(jac_prototype) # Retrieves the diffusion values (form depending on Jacobian sparsity).
64-
new{R,S,typeof(jac_values)}(ofunc, lrs.num_verts, lrs.num_species, vert_ps, work_vert_ps, enum_v_ps_idx_types, sparse, jac_values)
66+
v_ps_idx_types = map(vp -> length(vp) == 1, vert_ps)
67+
new{R,S,typeof(jac_transport),T}(ofunc, lrs.num_verts, lrs.num_species, vert_ps, work_vert_ps, v_ps_idx_types, sparse, jac_transport, edge_ps)
6568
end
6669
end
6770

@@ -70,7 +73,10 @@ end
7073
# Creates an ODEProblem from a LatticeReactionSystem.
7174
function DiffEqBase.ODEProblem(lrs::LatticeReactionSystem, u0_in, tspan,
7275
p_in = DiffEqBase.NullParameters(), args...;
73-
jac = true, sparse = jac, kwargs...)
76+
jac = true, sparse = jac,
77+
name = nameof(lrs), include_zero_odes = true,
78+
combinatoric_ratelaws = get_combinatoric_ratelaws(lrs.rs),
79+
remove_conserved = false, checks = false, kwargs...)
7480
is_transport_system(lrs) || error("Currently lattice ODE simulations are only supported when all spatial reactions are TransportReactions.")
7581

7682
# Converts potential symmaps to varmaps (parameter conversion is more involved since the vertex and edge parameters may be given in a tuple, or in a common vector).
@@ -85,26 +91,49 @@ function DiffEqBase.ODEProblem(lrs::LatticeReactionSystem, u0_in, tspan,
8591
vert_ps, edge_ps = lattice_process_p(p_in, vertex_parameters(lrs), edge_parameters(lrs), lrs)
8692

8793
# Creates ODEProblem.
88-
ofun = build_odefunction(lrs, vert_ps, edge_ps, jac, sparse) # Builds the ODEFunction.
89-
return ODEProblem(ofun, u0, tspan, vert_ps, args...; kwargs...) # Creates a normal ODEProblem.
94+
ofun = build_odefunction(lrs, vert_ps, edge_ps, jac, sparse, name, include_zero_odes, combinatoric_ratelaws, remove_conserved, checks)
95+
return ODEProblem(ofun, u0, tspan, vert_ps, args...; kwargs...)
9096
end
9197

9298
# Builds an ODEFunction for a spatial ODEProblem.
9399
function build_odefunction(lrs::LatticeReactionSystem, vert_ps::Vector{Vector{T}},
94-
edge_ps::Vector{Vector{T}}, use_jac::Bool, sparse::Bool) where {T}
95-
# Prepares (non-spatial) ODE functions and list of spatially moving species and their rates.
96-
ofunc = ODEFunction(convert(ODESystem, lrs.rs); jac = use_jac, sparse = false) # Creates the (non-spatial) ODEFunction corresponding to the (non-spatial) reaction network.
97-
ofunc_sparse = ODEFunction(convert(ODESystem, lrs.rs); jac = use_jac, sparse = true) # Creates the same function, but sparse. Could insert so this is only computed for sparse cases.
98-
transport_rates_speciesmap = compute_all_transport_rates(vert_ps, edge_ps, lrs) # Creates a map (Vector{Pair}), mapping each species that is transported to a vector with its transportation rate. If the rate is uniform across all edges, the vector will be length 1 (with this value), else there will be a separate value for each edge.
99-
transport_rates = Pair{Int64, Vector{T}}[findfirst(isequal(spat_rates[1]), species(lrs)) => spat_rates[2]
100-
for spat_rates in transport_rates_speciesmap] # Remakes "transport_rates_speciesmap". Rates are identical, but the species are represented as their index (in the species(::ReactionSystem) vector). In "transport_rates_speciesmap" they instead were Symbolics. Pair{Int64, Vector{T}}[] is required in case vector is empty (otherwise it becomes Any[], causing type error later).
101-
102-
f = LatticeTransportODEf(ofunc, vert_ps, transport_rates, lrs)
103-
jac_prototype = (use_jac || sparse) ?
104-
build_jac_prototype(ofunc_sparse.jac_prototype, transport_rates,
105-
lrs; set_nonzero = use_jac) : nothing # Computes the Jacobian prototype (nothing if `jac=false`).
106-
jac = use_jac ? LatticeTransportODEjac(ofunc, vert_ps, lrs, jac_prototype, sparse) : nothing # (Potentially) Creates a functor for the ODE Jacobian function (incorporating spatial and non-spatial reactions).
107-
return ODEFunction(f; jac = jac, jac_prototype = (sparse ? jac_prototype : nothing)) # Creates the ODEFunction used in the ODEProblem.
100+
edge_ps::Vector{Vector{T}}, jac::Bool, sparse::Bool,
101+
name, include_zero_odes, combinatoric_ratelaws, remove_conserved, checks) where {T}
102+
println()
103+
remove_conserved && error("Removal of conserved quantities is currently not supported for `LatticeReactionSystem`s")
104+
105+
# Creates a map, taking (the index in species(lrs) each species (with transportation) to its transportation rate (uniform or one value for each edge).
106+
transport_rates = make_sidxs_to_transrate_map(vert_ps, edge_ps, lrs)
107+
108+
# Prepares the Jacobian and forcing functions (depending on jacobian and sparsity selection).
109+
if jac
110+
ofunc_dense = ODEFunction(convert(ODESystem, lrs.rs; name, combinatoric_ratelaws, include_zero_odes, checks); jac = true, sparse = false) # Always used for build_jac_prototype.
111+
ofunc_sparse = ODEFunction(convert(ODESystem, lrs.rs; name, combinatoric_ratelaws, include_zero_odes, checks); jac = true, sparse = true) # Always used for LatticeTransportODEjac.
112+
jac_vals = build_jac_prototype(ofunc_sparse.jac_prototype, transport_rates, lrs; set_nonzero = true)
113+
if sparse
114+
f = LatticeTransportODEf(ofunc_sparse, vert_ps, transport_rates, edge_ps, lrs)
115+
jac_vals = build_jac_prototype(ofunc_sparse.jac_prototype, transport_rates, lrs; set_nonzero = true)
116+
J = LatticeTransportODEjac(ofunc_dense, vert_ps, lrs, jac_vals, edge_ps, true)
117+
jac_prototype = jac_vals
118+
else
119+
f = LatticeTransportODEf(ofunc_dense, vert_ps, transport_rates, edge_ps, lrs)
120+
J = LatticeTransportODEjac(ofunc_dense, vert_ps, lrs, jac_vals, edge_ps, false)
121+
jac_prototype = nothing
122+
end
123+
else
124+
if sparse
125+
ofunc_sparse = ODEFunction(convert(ODESystem, lrs.rs; name, combinatoric_ratelaws, include_zero_odes, checks); jac = false, sparse = true)
126+
f = LatticeTransportODEf(ofunc_sparse, vert_ps, transport_rates, edge_ps, lrs)
127+
jac_prototype = build_jac_prototype(ofunc_sparse.jac_prototype, transport_rates, lrs; set_nonzero = false)
128+
else
129+
ofunc_dense = ODEFunction(convert(ODESystem, lrs.rs; name, combinatoric_ratelaws, include_zero_odes, checks); jac = false, sparse = false)
130+
f = LatticeTransportODEf(ofunc_dense, vert_ps, transport_rates, edge_ps, lrs)
131+
jac_prototype = nothing
132+
end
133+
J = nothing
134+
end
135+
136+
return ODEFunction(f; jac = J, jac_prototype = jac_prototype)
108137
end
109138

110139
# Builds a jacobian prototype. If requested, populate it with the Jacobian's (constant) values as well.
@@ -120,7 +149,7 @@ function build_jac_prototype(ns_jac_prototype::SparseMatrixCSC{Float64, Int64},
120149
ns_j_idxs = ns_jac_prototype_idxs[2]
121150

122151
# List the indexes of all non-zero Jacobian terms.
123-
non_spat_terms = [[get_index(vert, s_i, lrs.num_species), get_index(vert, s_j, lrs.num_species)] for vert in 1:(lrs.num_verts) for (s_i, s_j) in zip(ns_i_idxs,ns_j_idxs)] # Indexes of elements due to non-spatial dynamics.
152+
non_spat_terms = [[get_index(vert, s_i, lrs.num_species), get_index(vert, s_j, lrs.num_species)] for vert in 1:(lrs.num_verts) for (s_i, s_j) in zip(ns_i_idxs,ns_j_idxs)] # Indexes of elements due to non-spatial dynamics.
124153
trans_only_leaving_terms = [[get_index(e.src, s_idx, lrs.num_species), get_index(e.src, s_idx, lrs.num_species)] for e in edges(lrs.lattice) for s_idx in trans_only_species] # Indexes due to terms for a species leaves its current vertex (but does not have non-spatial dynamics). If the non-spatial Jacobian is fully dense, these would already be accounted for.
125154
trans_arriving_terms = [[get_index(e.src, s_idx, lrs.num_species), get_index(e.dst, s_idx, lrs.num_species)] for e in edges(lrs.lattice) for s_idx in trans_species] # Indexes due to terms for species arriving into a new vertex.
126155
all_terms = [non_spat_terms; trans_only_leaving_terms; trans_arriving_terms]
@@ -130,7 +159,7 @@ function build_jac_prototype(ns_jac_prototype::SparseMatrixCSC{Float64, Int64},
130159

131160
# Set element values.
132161
if set_nonzero
133-
for (s, rates) in trans_rates, (e_idx, e) in enumerate(edges(lrs.lattice)) # Loops through all species with transportation and all edges along which the can be transported.
162+
for (s, rates) in trans_rates, (e_idx, e) in enumerate(edges(lrs.lattice))
134163
jac_prototype[get_index(e.src, s, lrs.num_species), get_index(e.src, s, lrs.num_species)] -= get_component_value(rates, e_idx) # Term due to species leaving source vertex.
135164
jac_prototype[get_index(e.src, s, lrs.num_species), get_index(e.dst, s, lrs.num_species)] += get_component_value(rates, e_idx) # Term due to species arriving to destination vertex.
136165
end
@@ -147,7 +176,7 @@ function (f_func::LatticeTransportODEf)(du, u, p, t)
147176
idxs = get_indexes(vert_i, f_func.num_species)
148177

149178
# vector of vertex ps at vert_i
150-
vert_i_ps = view_vert_ps_vector!(f_func.work_vert_ps, p, vert_i, f_func.enum_v_ps_idx_types)
179+
vert_i_ps = view_vert_ps_vector!(f_func.work_vert_ps, p, vert_i, enumerate(f_func.v_ps_idx_types))
151180

152181
# evaluate reaction contributions to du at vert_i
153182
f_func.ofunc((@view du[idxs]), (@view u[idxs]), vert_i_ps, t)
@@ -156,13 +185,13 @@ function (f_func::LatticeTransportODEf)(du, u, p, t)
156185
# s_idx is species index among transport species, s is index among all species
157186
# rates are the species' transport rates
158187
for (s_idx, (s, rates)) in enumerate(f_func.transport_rates)
159-
# rate for leaving vert_i
188+
# Rate for leaving vert_i
160189
for vert_i in 1:(f_func.num_verts)
161190
idx = get_index(vert_i, s, f_func.num_species)
162191
du[idx] -= f_func.leaving_rates[s_idx, vert_i] * u[idx]
163192
end
164-
# add rates for entering a given vertex via an incoming edge
165-
for (e_idx, e) in f_func.enum_edges
193+
# Add rates for entering a given vertex via an incoming edge
194+
for (e_idx, e) in enumerate(f_func.edges)
166195
idx_dst = get_index(e.dst, s, f_func.num_species)
167196
idx_src = get_index(e.src, s, f_func.num_species)
168197
du[idx_dst] += get_component_value(rates, e_idx) * u[idx_src]
@@ -177,14 +206,10 @@ function (jac_func::LatticeTransportODEjac)(J, u, p, t)
177206
# Update the Jacobian from reaction terms
178207
for vert_i in 1:(jac_func.num_verts)
179208
idxs = get_indexes(vert_i, jac_func.num_species)
180-
vert_ps = view_vert_ps_vector!(jac_func.work_vert_ps, p, vert_i, jac_func.enum_v_ps_idx_types)
209+
vert_ps = view_vert_ps_vector!(jac_func.work_vert_ps, p, vert_i, enumerate(jac_func.v_ps_idx_types))
181210
jac_func.ofunc.jac((@view J[idxs, idxs]), (@view u[idxs]), vert_ps, t)
182211
end
183212

184213
# Updates for the spatial reactions (adds the Jacobian values from the diffusion reactions).
185-
add_spat_J_vals!(J, jac_func)
186-
end
187-
188-
# Updates the jacobian matrix with the diffusion values. Separate for spatial and non-spatial cases.
189-
add_spat_J_vals!(J::SparseMatrixCSC, jac_func::LatticeTransportODEjac) = (J.nzval .+= jac_func.jac_values)
190-
add_spat_J_vals!(J::Matrix, jac_func::LatticeTransportODEjac) = (J .+= jac_func.jac_values)
214+
J .+= jac_func.jac_transport
215+
end

src/spatial_reaction_systems/spatial_reactions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ function find_parameters_in_rate!(parameters, rateex::ExprValues)
105105
push!(parameters, rateex)
106106
end
107107
elseif rateex isa Expr
108-
# note, this (correctly) skips $(...) expressions
108+
# Note, this (correctly) skips $(...) expressions
109109
for i in 2:length(rateex.args)
110110
find_parameters_in_rate!(parameters, rateex.args[i])
111111
end

src/spatial_reaction_systems/utility.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,15 @@ function compute_transport_rates(rate_law::Num,
131131
for p in relevant_parameters)) for idxE in 1:num_edges]
132132
end
133133

134+
# Creates a map, taking each species (with transportation) to its transportation rate.
135+
# The species is represented by its index (in species(lrs).
136+
# If the rate is uniform across all edges, the vector will be length 1 (with this value), else there will be a separate value for each edge.
137+
# Pair{Int64, Vector{T}}[] is required in case vector is empty (otherwise it becomes Any[], causing type error later).
138+
function make_sidxs_to_transrate_map(vert_ps::Vector{Vector{Float64}}, edge_ps::Vector{Vector{T}}, lrs::LatticeReactionSystem) where T
139+
transport_rates_speciesmap = compute_all_transport_rates(vert_ps, edge_ps, lrs)
140+
return Pair{Int64, Vector{T}}[speciesmap(lrs.rs)[spat_rates[1]] => spat_rates[2] for spat_rates in transport_rates_speciesmap]
141+
end
142+
134143
### Accessing State & Parameter Array Values ###
135144

136145
# Gets the index in the u array of species s in vertex vert (when their are num_species species).

0 commit comments

Comments
 (0)