Skip to content

Commit 38c2869

Browse files
committed
Use SpatialMassActionJump
1 parent 50c20be commit 38c2869

File tree

4 files changed

+222
-23
lines changed

4 files changed

+222
-23
lines changed

src/Catalyst.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ module Catalyst
66
using DocStringExtensions
77
using SparseArrays, DiffEqBase, Reexport, Setfield
88
using LaTeXStrings, Latexify, Requires
9-
using JumpProcesses: JumpProcesses,
10-
JumpProblem, MassActionJump, ConstantRateJump,
11-
VariableRateJump
9+
using JumpProcesses: JumpProcesses, JumpProblem,
10+
MassActionJump, ConstantRateJump, VariableRateJump,
11+
SpatialMassActionJump
1212

1313
# ModelingToolkit imports and convenience functions we use
1414
using ModelingToolkit

src/spatial_reaction_systems/lattice_jump_systems.jl

Lines changed: 67 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,19 @@ end
2828
function JumpProcesses.JumpProblem(lrs::LatticeReactionSystem, dprob, aggregator, args...; name = nameof(lrs.rs),
2929
combinatoric_ratelaws = get_combinatoric_ratelaws(lrs.rs), kwargs...)
3030
# Error checks.
31-
(dprob.p isa Vector{Vector{Vector{Float64}}}) || dprob.p isa Vector{Vector} || error("Parameters in input DiscreteProblem is of an unexpected type: $(typeof(dprob.p)). Was a LatticeReactionProblem passed into the DiscreteProblem when it was created?") # The second check (Vector{Vector} is needed becaus on the CI server somehow the Tuple{..., ...} is covnerted into a Vector[..., ...]). It does not happen when I run tests locally, so no ideal how to fix.
32-
any(length.(dprob.p[1]) .> 1) && error("Spatial reaction rates are currently not supported in lattice jump simulations.")
31+
# The second check (Vector{Vector} is needed because on the CI server somehow the Tuple{..., ...} is converted into a Vector[..., ...]).
32+
# It does not happen when I run tests locally, so no ideal how to fix.
33+
(dprob.p isa Vector{Vector{Vector{Float64}}}) || dprob.p isa Vector{Vector} || error("Parameters in input DiscreteProblem is of an unexpected type: $(typeof(dprob.p)). Was a LatticeReactionProblem passed into the DiscreteProblem when it was created?")
3334

3435
# Computes hopping constants and mass action jumps (requires some internal juggling).
35-
# The non-spatial DiscreteProblem have a u0 matrix with entries for all combinations of species and vertexes.
3636
# Currently, JumpProcesses requires uniform vertex parameters (hence `p=first.(dprob.p[1])`).
37+
# Currently, the resulting JumpProblem does not depend on parameters (no way to incorporate these).
38+
# Hence the parameters of this one does nto actually matter. If at some point JumpProcess can
39+
# handle parameters this can be updated and improved.
40+
# The non-spatial DiscreteProblem have a u0 matrix with entries for all combinations of species and vertexes.
3741
hopping_constants = make_hopping_constants(dprob, lrs)
42+
sma_jumps = make_spatial_majumps(dprob, lrs)
3843
non_spat_dprob = DiscreteProblem(reshape(dprob.u0, lrs.num_species, lrs.num_verts), dprob.tspan, first.(dprob.p[1]))
39-
sma_jumps = make_spatial_majumps(non_spat_dprob, dprob, lrs)
4044

4145
return JumpProblem(non_spat_dprob, aggregator, sma_jumps;
4246
hopping_constants, spatial_system = lrs.lattice, name, kwargs...)
@@ -66,22 +70,65 @@ function make_hopping_constants(dprob::DiscreteProblem, lrs::LatticeReactionSyst
6670
return hopping_constants
6771
end
6872

69-
# Creates the (spatial) mass action jumps from a (spatial) DiscreteProblem its non-spatial version, and a LatticeReactionSystem.
70-
function make_spatial_majumps(non_spat_dprob, dprob, rs::LatticeReactionSystem)
71-
ma_jumps = make_majumps(non_spat_dprob, lrs.rs)
72-
73+
# Creates a SpatialMassActionJump struct from a (spatial) DiscreteProblem and a LatticeReactionSystem.
74+
# Could implementation a version which, if all reaction's rates are uniform, returns a MassActionJump.
75+
# Not sure if there is any form of performance improvement from that though. Possibly is not the case.
76+
function make_spatial_majumps(dprob, lrs::LatticeReactionSystem)
77+
# Creates a vector, storing which reactions have spatial components.
78+
is_spatials = [Catalyst.has_spatial_vertex_component(rx.rate, lrs; vert_ps = dprob.p[1]) for rx in reactions(lrs.rs)]
79+
80+
# Creates templates for the rates (uniform and spatial) and the stoichiometries.
81+
# We cannot fetch reactant_stoich and net_stoich from a (non-spatial) MassActionJump.
82+
# The reason is that we need to re-order the reactions so that uniform appears first, and spatial next.
83+
u_rates = Vector{Float64}(undef, length(reactions(lrs.rs)) - count(is_spatials))
84+
s_rates = Matrix{Float64}(undef, count(is_spatials), lrs.num_verts)
85+
reactant_stoich = Vector{Vector{Pair{Int64, Int64}}}(undef, length(reactions(lrs.rs)))
86+
net_stoich = Vector{Vector{Pair{Int64, Int64}}}(undef, length(reactions(lrs.rs)))
87+
88+
# Loops through reactions with non-spatial rates, computes their rates and stoichiometries.
89+
cur_rx = 1;
90+
for (is_spat, rx) in zip(is_spatials, reactions(lrs.rs))
91+
is_spat && continue
92+
u_rates[cur_rx] = compute_vertex_value(rx.rate, lrs; vert_ps = dprob.p[1])[1]
93+
substoich_map = Pair.(rx.substrates, rx.substoich)
94+
reactant_stoich[cur_rx] = int_map(substoich_map, lrs.rs)
95+
net_stoich[cur_rx] = int_map(rx.netstoich, lrs.rs)
96+
cur_rx += 1
97+
end
98+
# Loops through reactions with spatial rates, computes their rates and stoichiometries.
99+
for (is_spat, rx) in zip(is_spatials, reactions(lrs.rs))
100+
is_spat || continue
101+
s_rates[cur_rx-length(u_rates),:] = compute_vertex_value(rx.rate, lrs; vert_ps = dprob.p[1])
102+
substoich_map = Pair.(rx.substrates, rx.substoich)
103+
reactant_stoich[cur_rx] = int_map(substoich_map, lrs.rs)
104+
net_stoich[cur_rx] = int_map(rx.netstoich, lrs.rs)
105+
cur_rx += 1
106+
end
107+
# SpatialMassActionJump expects empty rate containers to be nothing.
108+
isempty(u_rates) && (u_rates = nothing)
109+
(count(is_spatials)==0) && (s_rates = nothing)
110+
111+
return SpatialMassActionJump(u_rates, s_rates, reactant_stoich, net_stoich)
73112
end
74113

75-
# Creates the (non-spatial) mass action jumps from a (non-spatial) DiscreteProblem (and its Reaction System of origin).
76-
function make_majumps(non_spat_dprob, rs::ReactionSystem)
77-
# Computes various required inputs for assembling the mass action jumps.
78-
js = convert(JumpSystem, rs)
79-
statetoid = Dict(ModelingToolkit.value(state) => i for (i, state) in enumerate(states(rs)))
80-
eqs = equations(js)
81-
invttype = non_spat_dprob.tspan[1] === nothing ? Float64 : typeof(1 / non_spat_dprob.tspan[2])
82-
83-
# Assembles the non-spatial mass action jumps.
84-
p = (non_spat_dprob.p isa DiffEqBase.NullParameters || non_spat_dprob.p === nothing) ? Num[] : non_spat_dprob.p
85-
majpmapper = ModelingToolkit.JumpSysMajParamMapper(js, p; jseqs = eqs, rateconsttype = invttype)
86-
return ModelingToolkit.assemble_maj(eqs.x[1], statetoid, majpmapper)
114+
### Extra ###
115+
116+
# Temporary. Awaiting implementation in SII, or proper implementation withinCatalyst (with more general functionality).
117+
function int_map(map_in, sys) where {T,S}
118+
return [ModelingToolkit.variable_index(sys, pair[1]) => pair[2] for pair in map_in]
87119
end
120+
121+
# Currently unused. If we want to create certain types of MassActionJumps (instead of SpatialMassActionJumps) we can take this one back.
122+
# Creates the (non-spatial) mass action jumps from a (non-spatial) DiscreteProblem (and its Reaction System of origin).
123+
# function make_majumps(non_spat_dprob, rs::ReactionSystem)
124+
# # Computes various required inputs for assembling the mass action jumps.
125+
# js = convert(JumpSystem, rs)
126+
# statetoid = Dict(ModelingToolkit.value(state) => i for (i, state) in enumerate(states(rs)))
127+
# eqs = equations(js)
128+
# invttype = non_spat_dprob.tspan[1] === nothing ? Float64 : typeof(1 / non_spat_dprob.tspan[2])
129+
#
130+
# # Assembles the non-spatial mass action jumps.
131+
# p = (non_spat_dprob.p isa DiffEqBase.NullParameters || non_spat_dprob.p === nothing) ? Num[] : non_spat_dprob.p
132+
# majpmapper = ModelingToolkit.JumpSysMajParamMapper(js, p; jseqs = eqs, rateconsttype = invttype)
133+
# return ModelingToolkit.assemble_maj(eqs.x[1], statetoid, majpmapper)
134+
# end

src/spatial_reaction_systems/utility.jl

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,3 +295,99 @@ end
295295
function matrix_expand_component_values(values::Vector{<:Vector}, n)
296296
reshape(expand_component_values(values, n), length(values), n)
297297
end
298+
299+
# For an expression, computes its values using the provided state and parameter vectors.
300+
# The expression is assumed to be valid in edges (and can have edges parameter components).
301+
# If some component is non-uniform, output is a vector of length equal to the number of vertexes.
302+
# If all components are uniform, the output is a length one vector.
303+
function compute_edge_value(exp, lrs::LatticeReactionSystem, edge_ps)
304+
# Finds the symbols in the expression. Checks that all correspond to edge parameters.
305+
relevant_syms = Symbolics.get_variables(exp)
306+
if !all(any(isequal(sym, p) for p in edge_parameters(lrs)) for sym in relevant_syms)
307+
error("An non-edge parameter was encountered in expressions: $exp. Here, only edge parameters are expected.")
308+
end
309+
310+
# Creates a Function tha computes the expressions value for a parameter set.
311+
exp_func = drop_expr(@RuntimeGeneratedFunction(build_function(exp, relevant_syms...)))
312+
# Creates a dictionary with the value(s) for all edge parameters.
313+
sym_val_dict = vals_to_dict(edge_parameters(lrs), edge_ps)
314+
315+
# If all values are uniform, compute value once. Else, do it at all edges.
316+
if !has_spatial_edge_component(exp, lrs, edge_ps)
317+
return [exp_func([sym_val_dict[sym][1] for sym in relevant_syms]...)]
318+
end
319+
return [exp_func([get_component_value(sym_val_dict[sym], idxE) for sym in relevant_syms]...)
320+
for idxE in 1:lrs.num_edges]
321+
end
322+
323+
# For an expression, computes its values using the provided state and parameter vectors.
324+
# The expression is assumed to be valid in vertexes (and can have vertex parameter and state components).
325+
# If at least one component is non-uniform, output is a vector of length equal to the number of vertexes.
326+
# If all components are uniform, the output is a length one vector.
327+
function compute_vertex_value(exp, lrs::LatticeReactionSystem; u=nothing, vert_ps=nothing)
328+
# Finds the symbols in the expression. Checks that all correspond to states or vertex parameters.
329+
relevant_syms = Symbolics.get_variables(exp)
330+
if any(any(isequal(sym) in edge_parameters(lrs)) for sym in relevant_syms)
331+
error("An edge parameter was encountered in expressions: $exp. Here, on vertex-based components are expected.")
332+
end
333+
# Creates a Function tha computes the expressions value for a parameter set.
334+
exp_func = drop_expr(@RuntimeGeneratedFunction(build_function(exp, relevant_syms...)))
335+
# Creates a dictionary with the value(s) for all edge parameters.
336+
if !isnothing(u) && !isnothing(vert_ps)
337+
all_syms = [species(lrs); vertex_parameters(lrs)]
338+
all_vals = [u; vert_ps]
339+
elseif !isnothing(u) && isnothing(vert_ps)
340+
all_syms = species(lrs)
341+
all_vals = u
342+
343+
elseif isnothing(u) && !isnothing(vert_ps)
344+
all_syms = vertex_parameters(lrs)
345+
all_vals = vert_ps
346+
else
347+
error("Either u or vertex_ps have to be provided to has_spatial_vertex_component.")
348+
end
349+
sym_val_dict = vals_to_dict(all_syms, all_vals)
350+
351+
# If all values are uniform, compute value once. Else, do it at all edges.
352+
if !has_spatial_vertex_component(exp, lrs; u, vert_ps)
353+
return [exp_func([sym_val_dict[sym][1] for sym in relevant_syms]...)]
354+
end
355+
return [exp_func([get_component_value(sym_val_dict[sym], idxV) for sym in relevant_syms]...)
356+
for idxV in 1:lrs.num_verts]
357+
end
358+
359+
### System Property Checks ###
360+
361+
# For a Symbolic expression, a LatticeReactionSystem, and a parameter list of the internal format:
362+
# Checks if any edge parameter in the expression have a spatial component (that is, is not uniform).
363+
function has_spatial_edge_component(exp, lrs::LatticeReactionSystem, edge_ps)
364+
# Finds the edge parameters in the expression. Computes their indexes.
365+
exp_syms = Symbolics.get_variables(exp)
366+
exp_edge_ps = filter(sym -> any(isequal(sym), edge_parameters(lrs)), exp_syms)
367+
p_idxs = [findfirst(isequal(sym, edge_p) for edge_p in edge_parameters(lrs)) for sym in exp_syms]
368+
# Checks if any of the corresponding value vectors have length != 1 (that is, is not uniform).
369+
return any(length(edge_ps[p_idx]) != 1 for p_idx in p_idxs)
370+
end
371+
372+
# For a Symbolic expression, a LatticeReactionSystem, and a parameter list of the internal format (vector of vectors):
373+
# Checks if any vertex parameter in the expression have a spatial component (that is, is not uniform).
374+
function has_spatial_vertex_component(exp, lrs::LatticeReactionSystem; u=nothing, vert_ps=nothing)
375+
# Finds all the symbols in the expression.
376+
exp_syms = Symbolics.get_variables(exp)
377+
378+
# If vertex parameter values where given, checks if any of these have non-uniform values.
379+
if !isnothing(vert_ps)
380+
exp_vert_ps = filter(sym -> any(isequal(sym), vertex_parameters(lrs)), exp_syms)
381+
p_idxs = [ModelingToolkit.parameter_index(lrs.rs, sym) for sym in exp_vert_ps]
382+
any(length(vert_ps[p_idx]) != 1 for p_idx in p_idxs) && return true
383+
end
384+
385+
# If states values where given, checks if any of these have non-uniform values.
386+
if !isnothing(u)
387+
exp_u = filter(sym -> any(isequal(sym), species(lrs)), exp_syms)
388+
u_idxs = [ModelingToolkit.variable_index(lrs.rs, sym) for sym in exp_u]
389+
any(length(u[u_idx]) != 1 for u_idx in u_idxs) && return true
390+
end
391+
392+
return false
393+
end

test/spatial_reaction_systems/lattice_reaction_systems_jumps.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,62 @@ let
105105
end
106106

107107

108+
### SpatialMassActionJump Testing ###
109+
110+
# Checks that the correct structure is produced.
111+
let
112+
# Network for reference:
113+
# A, ∅ → X
114+
# 1, 2X + Y → 3X
115+
# B, X → Y
116+
# 1, X → ∅
117+
# srs = [@transport_reaction dX X]
118+
# Create LatticeReactionSystem
119+
lrs = LatticeReactionSystem(brusselator_system, brusselator_srs_1, small_3d_grid)
120+
121+
# Create JumpProblem
122+
u0 = [:X => 1, :Y => rand(1:10, lrs.num_verts)]
123+
tspan = (0.0, 100.0)
124+
ps = [:A => 1.0, :B => 5.0 .+ rand(lrs.num_verts), :dX => rand(lrs.num_edges)]
125+
dprob = DiscreteProblem(lrs, u0, tspan, ps)
126+
jprob = JumpProblem(lrs, dprob, NSM())
127+
128+
# Checks internal structures.
129+
jprob.massaction_jump.uniform_rates == [1.0, 0.5 ,10.] # 0.5 is due to combinatoric /2! in (2X + Y).
130+
jprob.massaction_jump.spatial_rates[1,:] == ps[2][2]
131+
# Test when new SII functions are ready, or we implement them in Catalyst.
132+
# @test isequal(to_int(getfield.(reactions(lrs.rs), :netstoich)), jprob.massaction_jump.net_stoch)
133+
# @test isequal(to_int(Pair.(getfield.(reactions(lrs.rs), :substrates),getfield.(reactions(lrs.rs), :substoich))), jprob.massaction_jump.net_stoch)
134+
135+
# Checks that problem can be simulated.
136+
@test SciMLBase.successful_retcode(solve(jprob, SSAStepper()))
137+
end
138+
139+
# Checks that simulations gives a correctly heterogeneous solution.
140+
let
141+
# Create model.
142+
birth_death_network = @reaction_network begin
143+
(p,d), 0 <--> X
144+
end
145+
srs = [(@transport_reaction D X)]
146+
lrs = LatticeReactionSystem(birth_death_network, srs, very_small_2d_grid)
147+
148+
# Create JumpProblem.
149+
u0 = [:X => 1]
150+
tspan = (0.0, 100.0)
151+
ps = [:p => [0.1, 1.0, 10.0, 100.0], :d => 1.0, :D => 0.0]
152+
dprob = DiscreteProblem(lrs, u0, tspan, ps)
153+
jprob = JumpProblem(lrs, dprob, NSM())
154+
155+
# Simulate model (a few repeats to ensure things don't succeed by change for uniform rates).
156+
# Check that higher p gives higher mean.
157+
for i = 1:5
158+
sol = solve(jprob, SSAStepper(); saveat = 1., seed = i*1234)
159+
@test mean(getindex.(sol.u, 1)) < mean(getindex.(sol.u, 2)) < mean(getindex.(sol.u, 3)) < mean(getindex.(sol.u, 4))
160+
end
161+
end
162+
163+
108164
### Tests taken from JumpProcesses ###
109165

110166
# ABC Model Test

0 commit comments

Comments
 (0)