Skip to content

Commit 5036194

Browse files
committed
fix for LatticeReactionSystems
1 parent 816de1b commit 5036194

File tree

6 files changed

+241
-8
lines changed

6 files changed

+241
-8
lines changed

src/spatial_reaction_systems/lattice_reaction_systems.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,17 @@ struct LatticeReactionSystem{S,T} # <: MT.AbstractTimeDependentSystem
2525
All parameters related to the lattice reaction system
2626
(both with spatial and non-spatial effects).
2727
"""
28-
parameters::Vector{BasicSymbolic{Real}}
28+
parameters::Vector{Any}
2929
"""
3030
Parameters which values are tied to vertexes (adjacencies),
3131
e.g. (possibly) have an unique value at each vertex of the system.
3232
"""
33-
vertex_parameters::Vector{BasicSymbolic{Real}}
33+
vertex_parameters::Vector{Any}
3434
"""
3535
Parameters which values are tied to edges (adjacencies),
3636
e.g. (possibly) have an unique value at each edge of the system.
3737
"""
38-
edge_parameters::Vector{BasicSymbolic{Real}}
38+
edge_parameters::Vector{Any}
3939

4040
function LatticeReactionSystem(rs::ReactionSystem{S}, spatial_reactions::Vector{T},
4141
lattice::DiGraph; init_digraph = true) where {S, T}
@@ -52,7 +52,7 @@ struct LatticeReactionSystem{S,T} # <: MT.AbstractTimeDependentSystem
5252
num_species = length(unique([species(rs); spat_species]))
5353
rs_edge_parameters = filter(isedgeparameter, parameters(rs))
5454
if isempty(spatial_reactions)
55-
srs_edge_parameters = Vector{BasicSymbolic{Real}}[]
55+
srs_edge_parameters = Vector{Any}[]
5656
else
5757
srs_edge_parameters = setdiff(reduce(vcat, [parameters(sr) for sr in spatial_reactions]), parameters(rs))
5858
end

src/spatial_reaction_systems/spatial_reactions.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ function check_spatial_reaction_validity(rs::ReactionSystem, tr::TransportReacti
9595
if any([isequal(tr.species, s) && !isequivalent(tr.species, s) for s in species(rs)])
9696
error("A transport reaction used a species, $(tr.species), with metadata not matching its lattice reaction system. Please fetch this species from the reaction system and used in transport reaction creation.")
9797
end
98-
if any([isequal(rs_p, tr_p) && !equivalent_metadata(rs_p, tr_p)
98+
if any([isequal(rs_p, tr_p) && !isequivalent(rs_p, tr_p)
9999
for rs_p in parameters(rs), tr_p in Symbolics.get_variables(tr.rate)])
100100
error("A transport reaction used a parameter with metadata not matching its lattice reaction system. Please fetch this parameter from the reaction system and used in transport reaction creation.")
101101
end
@@ -105,15 +105,16 @@ function check_spatial_reaction_validity(rs::ReactionSystem, tr::TransportReacti
105105
error("Edge paramter(s) were found as a rate of a non-spatial reaction.")
106106
end
107107
end
108-
equivalent_metadata(p1, p2) = isempty(setdiff(p1.metadata, p2.metadata, [Catalyst.EdgeParameter => true]))
109108

110109
# Since MTK's "isequal" ignores metadata, we have to use a special function that accounts for this.
111110
# This is important because whether something is an edge parameter is defined in metadata.
112111
function isequivalent(sym1, sym2)
113-
!isequal(sym1, sym2) && (return false)
114-
(sym1.metadata != sym2.metadata) && (return false)
112+
isequal(sym1, sym2) || (return false)
113+
(ignore_ep_metadata(sym1.metadata) != ignore_ep_metadata(sym2.metadata)) && (return false)
114+
(typeof(sym1) != typeof(sym2)) && (return false)
115115
return true
116116
end
117+
ignore_ep_metadata(metadata) = setdiff(metadata, [Catalyst.EdgeParameter => true])
117118

118119
# Implements custom `==`.
119120
"""

test/dsl/dsl_options.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
# Fetch packages.
66
using Catalyst, ModelingToolkit, OrdinaryDiffEq, Plots, Test
7+
using Symbolics: unwrap
78

89
# Sets the default `t` to use.
910
t = default_t()
@@ -344,6 +345,40 @@ let
344345
@test merge(ModelingToolkit.defaults(rn28), defs28) == ModelingToolkit.defaults(rn27)
345346
end
346347

348+
# Tests that parameter type designation works.
349+
let
350+
rn = @reaction_network begin
351+
@parameters begin
352+
k1
353+
l1
354+
k2::Float64 = 2.0
355+
l2::Float64
356+
k3::Int64 = 2, [description="A parameter"]
357+
l3::Int64
358+
k4::Float32, [description="Another parameter"]
359+
l4::Float32
360+
k5::Rational{Int64}
361+
l5::Rational{Int64}
362+
end
363+
(k1,l1), X1 <--> Y1
364+
(k2,l2), X2 <--> Y2
365+
(k3,l3), X3 <--> Y3
366+
(k4,l4), X4 <--> Y4
367+
(k5,l5), X5 <--> Y5
368+
end
369+
370+
@test unwrap(rn.k1) isa SymbolicUtils.BasicSymbolic{Real}
371+
@test unwrap(rn.l1) isa SymbolicUtils.BasicSymbolic{Real}
372+
@test unwrap(rn.k2) isa SymbolicUtils.BasicSymbolic{Float64}
373+
@test unwrap(rn.l2) isa SymbolicUtils.BasicSymbolic{Float64}
374+
@test unwrap(rn.k3) isa SymbolicUtils.BasicSymbolic{Int64}
375+
@test unwrap(rn.l3) isa SymbolicUtils.BasicSymbolic{Int64}
376+
@test unwrap(rn.k4) isa SymbolicUtils.BasicSymbolic{Float32}
377+
@test unwrap(rn.l4) isa SymbolicUtils.BasicSymbolic{Float32}
378+
@test unwrap(rn.k5) isa SymbolicUtils.BasicSymbolic{Rational{Int64}}
379+
@test unwrap(rn.l5) isa SymbolicUtils.BasicSymbolic{Rational{Int64}}
380+
end
381+
347382
### Observables ###
348383

349384
# Test basic functionality.
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
### Fetch Packages and Set Global Variables ###
2+
3+
# Fetch packages.
4+
using Catalyst, JumpProcesses, NonlinearSolve, OrdinaryDiffEq, StochasticDiffEq, Test
5+
using Symbolics: unwrap
6+
7+
# Sets stable rng number.
8+
using StableRNGs
9+
rng = StableRNG(12345)
10+
seed = rand(rng, 1:100)
11+
12+
# Sets the default `t` to use.
13+
t = default_t()
14+
15+
### Basic Tests ###
16+
17+
# Declares a simple model to run tests on.
18+
begin
19+
t = default_t()
20+
@parameters p1 p2::Float64 p3::Int64 p4::Float32 p5::Rational{Int64}
21+
@parameters d1 d2::Float64 = 1.2 d3::Int64 = 2 [description = "A parameter"] d4::Rational{Int64} d5::Float32
22+
@species X1(t) X2(t) X3(t) X4(t) X5(t)
23+
24+
rxs = [
25+
Reaction(p1, nothing, [X1]),
26+
Reaction(p2, nothing, [X2]),
27+
Reaction(p3, nothing, [X3]),
28+
Reaction(p4, nothing, [X4]),
29+
Reaction(p5, nothing, [X5]),
30+
Reaction(d1, [X1], nothing),
31+
Reaction(d2, [X2], nothing),
32+
Reaction(d3, [X3], nothing),
33+
Reaction(d4, [X4], nothing),
34+
Reaction(d5, [X5], nothing)
35+
]
36+
@named rs = ReactionSystem(rxs)
37+
rs = complete(rs)
38+
39+
# Declares initial condition and potential parameter sets.
40+
u0 = [X1 => 0.1, X2 => 0.2, X3 => 0.3, X4 => 0.4, X5 => 0.5]
41+
p_alts = [
42+
[p1 => 1.0, d1 => 1.0, p2 => 1.2, p3 => 2, p4 => 0.5, d4 => 1//2, p5 => 3//2, d5 => 1.5],
43+
(p1 => 1.0, d1 => 1.0, p2 => 1.2, p3 => 2, p4 => 0.5, d4 => 1//2, p5 => 3//2, d5 => 1.5),
44+
Dict([p1 => 1.0, d1 => 1.0, p2 => 1.2, p3 => 2, p4 => 0.5, d4 => 1//2, p5 => 3//2, d5 => 1.5])
45+
]
46+
end
47+
48+
# Tests that parameters stored in the system have the correct type.
49+
let
50+
@test Symbolics.unwrap(rs.p1) isa SymbolicUtils.BasicSymbolic{Real}
51+
@test Symbolics.unwrap(rs.d1) isa SymbolicUtils.BasicSymbolic{Real}
52+
@test Symbolics.unwrap(rs.p2) isa SymbolicUtils.BasicSymbolic{Float64}
53+
@test Symbolics.unwrap(rs.d2) isa SymbolicUtils.BasicSymbolic{Float64}
54+
@test Symbolics.unwrap(rs.p3) isa SymbolicUtils.BasicSymbolic{Int64}
55+
@test Symbolics.unwrap(rs.d3) isa SymbolicUtils.BasicSymbolic{Int64}
56+
@test Symbolics.unwrap(rs.p4) isa SymbolicUtils.BasicSymbolic{Float32}
57+
@test Symbolics.unwrap(rs.d4) isa SymbolicUtils.BasicSymbolic{Rational{Int64}}
58+
@test Symbolics.unwrap(rs.p5) isa SymbolicUtils.BasicSymbolic{Rational{Int64}}
59+
@test Symbolics.unwrap(rs.d5) isa SymbolicUtils.BasicSymbolic{Float32}
60+
end
61+
62+
# Tests that simulations with differentially typed variables yields correct results.
63+
let
64+
for p in p_alts
65+
oprob = ODEProblem(rs, u0, (0.0, 1000.0), p; abstol = 1e-10, reltol = 1e-10)
66+
sol = solve(oprob, Tsit5())
67+
@test all(sol[end] .≈ 1.0)
68+
end
69+
end
70+
71+
# Test that the various structures stores the parameters using the correct type.
72+
let
73+
# Creates problems, integrators, and solutions.
74+
oprob = ODEProblem(rs, u0, (0.0, 1.0), p_alts[1])
75+
sprob = SDEProblem(rs, u0, (0.0, 1.0), p_alts[1])
76+
dprob = DiscreteProblem(rs, u0, (0.0, 1.0), p_alts[1])
77+
jprob = JumpProblem(rs, dprob, Direct(); rng)
78+
nprob = NonlinearProblem(rs, u0, p_alts[1])
79+
80+
oinit = init(oprob, Tsit5())
81+
sinit = init(sprob, ImplicitEM())
82+
jinit = init(jprob, SSAStepper())
83+
ninit = init(nprob, NewtonRaphson())
84+
85+
osol = solve(oprob, Tsit5())
86+
ssol = solve(sprob, ImplicitEM(); seed)
87+
jsol = solve(jprob, SSAStepper(); seed)
88+
nsol = solve(nprob, NewtonRaphson())
89+
90+
# Checks the types of all stored parameter values.
91+
for mtk_struct in [oprob, sprob, dprob, jprob, nprob, oinit, sinit, jinit, osol, ssol, jsol, nsol]
92+
@test unwrap(mtk_struct.ps[p1]) isa Float64
93+
@test unwrap(mtk_struct.ps[d1]) isa Float64
94+
@test unwrap(mtk_struct.ps[p2]) isa Float64
95+
@test unwrap(mtk_struct.ps[d2]) isa Float64
96+
@test unwrap(mtk_struct.ps[p3]) isa Int64
97+
@test unwrap(mtk_struct.ps[d3]) isa Int64
98+
@test unwrap(mtk_struct.ps[p4]) isa Float32
99+
@test unwrap(mtk_struct.ps[d4]) isa Rational{Int64}
100+
@test unwrap(mtk_struct.ps[p5]) isa Rational{Int64}
101+
@test unwrap(mtk_struct.ps[d5]) isa Float32
102+
end
103+
104+
# Indexing currently broken for NonlinearSystem integrators (MTK intend to support this though).
105+
@test_broken unwrap(ninit.ps[p1]) isa Float64
106+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using SafeTestsets
88
@time @safetestset "Reactions" begin include("reactionsystem_structure/reactions.jl") end
99
@time @safetestset "ReactionSystem" begin include("reactionsystem_structure/reactionsystem.jl") end
1010
@time @safetestset "Higher Order Reactions" begin include("reactionsystem_structure/higher_order_reactions.jl") end
11+
@time @safetestset "Designation of Parameter Types" begin include("reactionsystem_structure/designating_parameter_types.jl") end
1112

1213
### Tests model creation via the @reaction_network DSL. ###
1314
@time @safetestset "Basic DSL" begin include("dsl/dsl_basics.jl") end

test/spatial_reaction_systems/lattice_reaction_systems.jl

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@
22

33
# Fetch packages.
44
using Catalyst, Graphs, Test
5+
using Symbolics: unwrap
56
t = default_t()
67

78
# Pre declares a grid.
89
grid = Graphs.grid([2, 2])
910

11+
1012
### Tests LatticeReactionSystem Getters Correctness ###
13+
1114
# Test case 1.
1215
let
1316
rs = @reaction_network begin
@@ -272,3 +275,90 @@ let
272275
@test_throws ErrorException LatticeReactionSystem(rs, [tr], grid)
273276
end
274277

278+
### Test Designation of Parameter Types ###
279+
280+
# Checks that parameter types designated in the non-spatial `ReactionSystem` is handled correctly.
281+
let
282+
# Declares LatticeReactionSystem with designated parameter types.
283+
rs = @reaction_network begin
284+
@parameters begin
285+
k1
286+
l1
287+
k2::Float64 = 2.0
288+
l2::Float64
289+
k3::Int64 = 2, [description="A parameter"]
290+
l3::Int64
291+
k4::Float32, [description="Another parameter"]
292+
l4::Float32
293+
k5::Rational{Int64}
294+
l5::Rational{Int64}
295+
D1::Float32
296+
D2, [edgeparameter=true]
297+
D3::Rational{Int64}, [edgeparameter=true]
298+
end
299+
(k1,l1), X1 <--> Y1
300+
(k2,l2), X2 <--> Y2
301+
(k3,l3), X3 <--> Y3
302+
(k4,l4), X4 <--> Y4
303+
(k5,l5), X5 <--> Y5
304+
end
305+
tr1 = @transport_reaction $(rs.D1) X1
306+
tr2 = @transport_reaction $(rs.D2) X2
307+
tr3 = @transport_reaction $(rs.D3) X3
308+
lrs = LatticeReactionSystem(rs, [tr1, tr2, tr3], grid)
309+
310+
# Loops through all parameters, ensuring that they have the correct type
311+
p_types = Dict([ModelingToolkit.nameof(p) => typeof(unwrap(p)) for p in parameters(lrs)])
312+
@test p_types[:k1] == SymbolicUtils.BasicSymbolic{Real}
313+
@test p_types[:l1] == SymbolicUtils.BasicSymbolic{Real}
314+
@test p_types[:k2] == SymbolicUtils.BasicSymbolic{Float64}
315+
@test p_types[:l2] == SymbolicUtils.BasicSymbolic{Float64}
316+
@test p_types[:k3] == SymbolicUtils.BasicSymbolic{Int64}
317+
@test p_types[:l3] == SymbolicUtils.BasicSymbolic{Int64}
318+
@test p_types[:k4] == SymbolicUtils.BasicSymbolic{Float32}
319+
@test p_types[:l4] == SymbolicUtils.BasicSymbolic{Float32}
320+
@test p_types[:k5] == SymbolicUtils.BasicSymbolic{Rational{Int64}}
321+
@test p_types[:l5] == SymbolicUtils.BasicSymbolic{Rational{Int64}}
322+
@test p_types[:D1] == SymbolicUtils.BasicSymbolic{Float32}
323+
@test p_types[:D2] == SymbolicUtils.BasicSymbolic{Real}
324+
@test p_types[:D3] == SymbolicUtils.BasicSymbolic{Rational{Int64}}
325+
end
326+
327+
# Checks that programmatically declared parameters (with types) can be used in `TransportReaction`s.
328+
# Checks that LatticeReactionSystem with non-default parameter types can be simulated.
329+
let
330+
rs = @reaction_network begin
331+
@parameters p::Float32
332+
(p,d), 0 <--> X
333+
end
334+
@parameters D::Rational{Int64}
335+
tr = TransportReaction(D, rs.X)
336+
lrs = LatticeReactionSystem(rs, [tr], grid)
337+
338+
p_types = Dict([ModelingToolkit.nameof(p) => typeof(unwrap(p)) for p in parameters(lrs)])
339+
@test p_types[:p] == SymbolicUtils.BasicSymbolic{Float32}
340+
@test p_types[:d] == SymbolicUtils.BasicSymbolic{Real}
341+
@test p_types[:D] == SymbolicUtils.BasicSymbolic{Rational{Int64}}
342+
343+
u0 = [:X => [0.25, 0.5, 2.0, 4.0]]
344+
ps = [rs.p => 2.0, rs.d => 1.0, D => 1//2]
345+
346+
# Currently broken. This requires some non-trivial reworking of internals.
347+
# However, spatial internals have already been reworked (and greatly improved) in an unmerged PR.
348+
# This will be sorted out once that has finished.
349+
@test_broken false
350+
# oprob = ODEProblem(lrs, u0, (0.0, 10.0), ps)
351+
# sol = solve(oprob, Tsit5())
352+
# @test sol[end] == [1.0, 1.0, 1.0, 1.0]
353+
end
354+
355+
# Tests that LatticeReactionSystem cannot be generated where transport reactions depend on parameters
356+
# that have a type designated in the non-spatial `ReactionSystem`.
357+
let
358+
rs = @reaction_network begin
359+
@parameters D::Int64
360+
(p,d), 0 <--> X
361+
end
362+
tr = @transport_reaction D X
363+
@test_throws Exception LatticeReactionSystem(rs, tr, grid)
364+
end

0 commit comments

Comments
 (0)