Skip to content

Commit 102537e

Browse files
committed
finish internal remake. Permit non-Float64 ints
1 parent d7cd5c2 commit 102537e

File tree

6 files changed

+60
-43
lines changed

6 files changed

+60
-43
lines changed

src/spatial_reaction_systems/lattice_reaction_systems.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,17 @@ struct LatticeReactionSystem{Q,R,S,T} <: MT.AbstractTimeDependentSystem
2929
All parameters related to the lattice reaction system
3030
(both those whose values are tied to vertices and edges).
3131
"""
32-
parameters::Vector{BasicSymbolic{Real}}
32+
parameters::Vector{Any}
3333
"""
3434
Parameters which values are tied to vertices,
3535
e.g. that possibly could have unique values at each vertex of the system.
3636
"""
37-
vertex_parameters::Vector{BasicSymbolic{Real}}
37+
vertex_parameters::Vector{Any}
3838
"""
3939
Parameters whose values are tied to edges (adjacencies),
4040
e.g. that possibly could have unique values at each edge of the system.
4141
"""
42-
edge_parameters::Vector{BasicSymbolic{Real}}
42+
edge_parameters::Vector{Any}
4343
"""
4444
An iterator over all the lattice's edges. Currently, the format is always a Vector{Pair{Int64,Int64}}.
4545
However, in the future, different types could potentially be used for different types of lattice

src/spatial_reaction_systems/spatial_ODE_systems.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ struct LatticeTransportODEFunction{P,Q,R,S,T}
6262

6363
function LatticeTransportODEFunction(ofunc::P, ps::Vector{<:Pair},
6464
lrs::LatticeReactionSystem, transport_rates::Vector{Pair{Int64, SparseMatrixCSC{S, Int64}}},
65-
jac_transport::Union{Nothing, Matrix{S}, SparseMatrixCSC{S, Int64}}, sparse::Bool) where {P,S}
65+
jac_transport::Union{Nothing, Matrix{S}, SparseMatrixCSC{S, Int64}}) where {P,S}
6666

6767
# Creates a vector with the heterogeneous vertex parameters' indexes in the full parameter vector.
6868
p_dict = Dict(ps)
@@ -118,7 +118,7 @@ function (lt_ofun::LatticeTransportODEFunction)(du::AbstractVector, u, p, t)
118118
for e in lt_ofun.edge_iterator
119119
idx_src = get_index(e[1], s, lt_ofun.num_species)
120120
idx_dst = get_index(e[2], s, lt_ofun.num_species)
121-
du[idx_dst] += get_transport_rate(s_idx, lt_ofun, e) * u[idx_src]
121+
du[idx_dst] += get_transport_rate(rates, e, lt_ofun.t_rate_idx_types[s_idx]) * u[idx_src]
122122
end
123123
end
124124
end
@@ -182,10 +182,10 @@ function DiffEqBase.ODEProblem(lrs::LatticeReactionSystem, u0_in, tspan,
182182
end
183183

184184
# Builds an ODEFunction for a spatial ODEProblem.
185-
function build_odefunction(lrs::LatticeReactionSystem, vert_ps::Vector{Pair{BasicSymbolic{Real},Vector{T}}},
186-
edge_ps::Vector{Pair{BasicSymbolic{Real},SparseMatrixCSC{T, Int64}}},
185+
function build_odefunction(lrs::LatticeReactionSystem, vert_ps::Vector{Pair{R,Vector{T}}},
186+
edge_ps::Vector{Pair{S,SparseMatrixCSC{T, Int64}}},
187187
jac::Bool, sparse::Bool, name, include_zero_odes, combinatoric_ratelaws,
188-
remove_conserved, checks) where {T}
188+
remove_conserved, checks) where {R,S,T}
189189
# Error check.
190190
if remove_conserved
191191
error("Removal of conserved quantities is currently not supported for `LatticeReactionSystem`s")
@@ -209,7 +209,7 @@ function build_odefunction(lrs::LatticeReactionSystem, vert_ps::Vector{Pair{Basi
209209
end
210210

211211
# Creates the `LatticeTransportODEFunction` functor (if `jac`, sets it as the Jacobian as well).
212-
f = LatticeTransportODEFunction(ofunc_dense, [vert_ps; edge_ps], lrs, transport_rates, jac_transport, sparse)
212+
f = LatticeTransportODEFunction(ofunc_dense, [vert_ps; edge_ps], lrs, transport_rates, jac_transport)
213213
J = (jac ? f : nothing)
214214

215215
# Extracts the `Symbol` form for species and parameters. Creates and returns the `ODEFunction`.
@@ -268,7 +268,7 @@ function build_jac_prototype(ns_jac_prototype::SparseMatrixCSC{Float64, Int64},
268268
end
269269

270270
# Create a sparse Jacobian prototype with 0-valued entries.
271-
jac_prototype = sparse(i_idxs, j_idxs, zeros(num_entries))
271+
jac_prototype = sparse(i_idxs, j_idxs, zeros(T, num_entries))
272272

273273
# Set element values.
274274
if set_nonzero

src/spatial_reaction_systems/utility.jl

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ end
1515

1616
# From u0 input, extract their values and store them in the internal format.
1717
# Internal format: a vector on the form [spec 1 at vert 1, spec 2 at vert 1, ..., spec 1 at vert 2, ...]).
18-
function lattice_process_u0(u0_in, u0_syms::Vector{BasicSymbolic{Real}}, lrs::LatticeReactionSystem)
18+
function lattice_process_u0(u0_in, u0_syms::Vector, lrs::LatticeReactionSystem)
1919
# u0 values can be given in various forms. This converts it to a Vector{Pair{Symbolics,...}} form.
2020
# Top-level vector: Maps each species to its value(s).
2121
u0 = lattice_process_input(u0_in, u0_syms)
@@ -32,8 +32,8 @@ end
3232

3333
# From a parameter input, split it into vertex parameters and edge parameters.
3434
# Store these in the desired internal format.
35-
function lattice_process_p(ps_in, ps_vertex_syms::Vector{BasicSymbolic{Real}},
36-
ps_edge_syms::Vector{BasicSymbolic{Real}}, lrs::LatticeReactionSystem)
35+
function lattice_process_p(ps_in, ps_vertex_syms::Vector,
36+
ps_edge_syms::Vector, lrs::LatticeReactionSystem)
3737
# p values can be given in various forms. This converts it to a Vector{Pair{Symbolics,...}} form.
3838
# Top-level vector: Maps each parameter to its value(s).
3939
# Second-level: Contains either a vector (vertex parameters) or a sparse matrix (edge parameters).
@@ -52,7 +52,7 @@ end
5252
# The input (parameters or initial conditions) may either be a dictionary (symbolics to value(s).)
5353
# or a map (in vector or tuple form) from symbolics to value(s). This converts the input to a
5454
# (Vector) map from symbolics to value(s), where the entries have the same order as `syms`.
55-
function lattice_process_input(input::Dict{BasicSymbolic{Real}, T}, syms::Vector{BasicSymbolic{Real}}) where {T}
55+
function lattice_process_input(input::Dict{<:Any, T}, syms::Vector) where {T}
5656
# Error checks
5757
if !isempty(setdiff(keys(input), syms))
5858
error("You have provided values for the following unrecognised parameters/initial conditions: $(setdiff(keys(input), syms)).")
@@ -63,16 +63,15 @@ function lattice_process_input(input::Dict{BasicSymbolic{Real}, T}, syms::Vector
6363

6464
return [sym => input[sym] for sym in syms]
6565
end
66-
function lattice_process_input(input, syms::Vector{BasicSymbolic{Real}})
66+
function lattice_process_input(input, syms::Vector)
6767
if ((input isa Vector) || (input isa Tuple)) && all(entry isa Pair for entry in input)
6868
return lattice_process_input(Dict(input), syms)
6969
end
7070
error("Input parameters/initial conditions have the wrong format ($(typeof(input))). These should either be a Dictionary, or a Tuple or a Vector (where each entry is a Pair taking a parameter/species to its value).")
7171
end
7272

7373
# Splits parameters into vertex and edge parameters.
74-
# function split_parameters(ps::Vector{<: Pair}, p_vertex_syms::Vector, p_edge_syms::Vector)
75-
function split_parameters(ps, p_vertex_syms::Vector{BasicSymbolic{Real}}, p_edge_syms::Vector{BasicSymbolic{Real}})
74+
function split_parameters(ps, p_vertex_syms::Vector, p_edge_syms::Vector)
7675
vert_ps = [p for p in ps if any(isequal(p[1]), p_vertex_syms)]
7776
edge_ps = [p for p in ps if any(isequal(p[1]), p_edge_syms)]
7877
return vert_ps, edge_ps
@@ -86,7 +85,7 @@ function vertex_value_map(values, lrs::LatticeReactionSystem)
8685
end
8786

8887
# Converts the values for an individual species/vertex parameter to its correct vector form.
89-
function vertex_value_form(values, lrs::LatticeReactionSystem, sym::BasicSymbolic{Real})
88+
function vertex_value_form(values, lrs::LatticeReactionSystem, sym::BasicSymbolic)
9089
# If the value is a scalar (i.e. uniform across the lattice), return it in vector form.
9190
(values isa AbstractArray) || (return [values])
9291

@@ -112,7 +111,7 @@ end
112111

113112
# Converts values to the correct vector form for a Cartesian grid lattice.
114113
function vertex_value_form(values::AbstractArray, num_verts::Int64, lattice::CartesianGridRej{N,T},
115-
sym::BasicSymbolic{Real}) where {N,T}
114+
sym::BasicSymbolic) where {N,T}
116115
if size(values) != lattice.dims
117116
error("The values for $sym did not have the same format as the lattice. Expected a $(lattice.dims) array, got one of size $(size(values))")
118117
end
@@ -124,7 +123,7 @@ end
124123

125124
# Converts values to the correct vector form for a masked grid lattice.
126125
function vertex_value_form(values::AbstractArray, num_verts::Int64, lattice::Array{Bool,T},
127-
sym::BasicSymbolic{Real}) where {T}
126+
sym::BasicSymbolic) where {T}
128127
if size(values) != size(lattice)
129128
error("The values for $sym did not have the same format as the lattice. Expected a $(size(lattice)) array, got one of size $(size(values))")
130129
end
@@ -174,9 +173,9 @@ end
174173
# The species is represented by its index (in species(lrs).
175174
# If the rate is uniform across all edges, the transportation rate will be a size (1,1) sparse matrix.
176175
# Else, the rate will be a size (num_verts,num_verts) sparse matrix.
177-
function make_sidxs_to_transrate_map(vert_ps::Vector{Pair{BasicSymbolic{Real},Vector{T}}},
178-
edge_ps::Vector{Pair{BasicSymbolic{Real},SparseMatrixCSC{T, Int64}}},
179-
lrs::LatticeReactionSystem) where {T}
176+
function make_sidxs_to_transrate_map(vert_ps::Vector{Pair{R,Vector{T}}},
177+
edge_ps::Vector{Pair{S,SparseMatrixCSC{T, Int64}}},
178+
lrs::LatticeReactionSystem) where {R,S,T}
180179
# Creates a dictionary with each parameter's value(s).
181180
p_val_dict = Dict(vcat(vert_ps, edge_ps))
182181

@@ -203,7 +202,7 @@ end
203202
# and the values of all our parameters, compute the transport rate(s).
204203
# If all parameters that the rate depends on are uniform across all edges, this becomes a length-1 vector.
205204
# Else it becomes a vector where each value corresponds to the rate at one specific edge.
206-
function compute_transport_rates(s::BasicSymbolic{Real}, p_val_dict, lrs::LatticeReactionSystem)
205+
function compute_transport_rates(s::BasicSymbolic, p_val_dict, lrs::LatticeReactionSystem)
207206
# Find parameters involved in the rate and create a function evaluating the rate law.
208207
rate_law = get_transport_rate_law(s, lrs)
209208
relevant_ps = Symbolics.get_variables(rate_law)
@@ -228,7 +227,7 @@ end
228227
# For a species, retrieve the symbolic expression for its transportation rate
229228
# (likely only a single parameter, such as `D`, but could be e.g. L*D, where L and D are parameters).
230229
# If there are several transportation reactions for the species, their sum is used.
231-
function get_transport_rate_law(s::BasicSymbolic{Real}, lrs::LatticeReactionSystem)
230+
function get_transport_rate_law(s::BasicSymbolic, lrs::LatticeReactionSystem)
232231
rates = filter(sr -> isequal(s, sr.species), spatial_reactions(lrs))
233232
return sum(getfield.(rates, :rate))
234233
end

test/spatial_modelling/lattice_reaction_systems.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,10 +223,11 @@ let
223223
tr_3 = TransportReaction(dZ, Z)
224224
tr_macro_1 = @transport_reaction $dX X
225225
tr_macro_2 = @transport_reaction $(rate2) Y
226+
@test_broken false
226227
# tr_macro_3 = @transport_reaction dZ $species3 # Currently does not work, something with meta programming.
227228

228229
@test isequal(tr_1, tr_macro_1)
229-
@test isequal(tr_2, tr_macro_2) # Unsure why these fails, since for components equality hold: `isequal(tr_1.species, tr_macro_1.species)` and `isequal(tr_1.rate, tr_macro_1.rate)` are both true.
230+
@test isequal(tr_2, tr_macro_2)
230231
# @test isequal(tr_3, tr_macro_3)
231232
end
232233

test/spatial_modelling/lattice_reaction_systems_ODEs.jl

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -85,24 +85,23 @@ end
8585

8686
### Tests Simulation Correctness ###
8787

88-
# Checks that non-spatial brusselator simulation is identical to all on an unconnected lattice.
88+
# Tests with non-Float64 parameter values.
8989
let
90-
lrs = LatticeReactionSystem(brusselator_system, brusselator_srs_1, unconnected_graph)
91-
u0 = [:X => 2.0 + 2.0 * rand(rng), :Y => 10.0 * (1.0 * rand(rng))]
92-
pV = brusselator_p
93-
pE = [:dX => 0.2]
94-
oprob_nonspatial = ODEProblem(brusselator_system, u0, (0.0, 100.0), pV)
95-
oprob_spatial = ODEProblem(lrs, u0, (0.0, 100.0), [pV; pE])
96-
sol_nonspatial = solve(oprob_nonspatial, QNDF(); abstol = 1e-12, reltol = 1e-12)
97-
sol_spatial = solve(oprob_spatial, QNDF(); abstol = 1e-12, reltol = 1e-12)
98-
99-
for i in 1:nv(unconnected_graph)
100-
@test all(isapprox.(sol_nonspatial.u[end],
101-
sol_spatial.u[end][((i - 1) * 2 + 1):((i - 1) * 2 + 2)]))
90+
lrs = LatticeReactionSystem(SIR_system, SIR_srs_2, very_small_2d_cartesian_grid)
91+
u0 = [:S => 990.0, :I => rand_v_vals(lrs), :R => 0.0]
92+
ps_1 = [ => 0.1, => 0.01, :dS => 0.01, :dI => 0.01, :dR => 0.01]
93+
ps_2 = [ => 1//10, => 1//100, :dS => 1//100, :dI => 1//100, :dR => 1//100]
94+
ps_3 = [ => 1//10, => 0.01, :dS => 0.01, :dI => 1//100, :dR => 0.01]
95+
sol_base = solve(ODEProblem(lrs, u0, (0.0, 100.0), ps_1), Rosenbrock23(); saveat = 0.1)
96+
for ps in [ps_1, ps_2, ps_3]
97+
for jac in [true, false], sparse in [true, false]
98+
oprob = ODEProblem(lrs, u0, (0.0, 100.0), ps; jac, sparse)
99+
@test sol_base solve(oprob, Rosenbrock23(); saveat = 0.1)
100+
end
102101
end
103102
end
104103

105-
# Compares Jacobian and forcing functions of spatial system to analytically computed on.
104+
# Compares Jacobian and forcing functions of spatial system to analytically computed ones.
106105
let
107106
# Creates LatticeReactionNetwork ODEProblem.
108107
rs = @reaction_network begin
@@ -119,7 +118,7 @@ let
119118
D_vals[1,2] = 0.2; D_vals[2,1] = 0.2;
120119
D_vals[2,3] = 0.3; D_vals[3,2] = 0.3;
121120
u0 = [:X => [1.0, 2.0, 3.0], :Y => 1.0]
122-
ps = [:pX => [2.0, 2.5, 3.0], :pY => 0.5, :d => 0.1, :D => D_vals]
121+
ps = [:pX => [2.0, 2.5, 3.0], :d => 0.1, :pY => 0.5, :D => D_vals]
123122
oprob = ODEProblem(lrs, u0, (0.0, 0.0), ps; jac=true, sparse=true)
124123

125124
# Creates manual f and jac functions.
@@ -172,7 +171,7 @@ let
172171

173172
# Sets test input values.
174173
u = rand(rng, 6)
175-
p = [rand(rng, 3), rand(rng, 1), rand(rng, 1)]
174+
p = [rand(rng, 3), ps[2][2], ps[3][2]]
176175

177176
# Tests forcing function.
178177
du1 = fill(0.0, 6)
@@ -614,6 +613,22 @@ let
614613
end
615614
end
616615

616+
# Tests with non-Int64 parameter values.
617+
let
618+
lrs = LatticeReactionSystem(SIR_system, SIR_srs_2, very_small_2d_cartesian_grid)
619+
u0 = [:S => 990.0, :I => rand_v_vals(lrs), :R => 0.0]
620+
ps_1 = [ => 0.1, => 0.01, :dS => 0.01, :dI => 0.01, :dR => 0.01]
621+
ps_2 = [ => Float32(0.1), => Float32(0.01), :dS => Float32(0.01), :dI => Float32(0.01), :dR => Float32(0.01)]
622+
ps_3 = [ => 1//10, => 0.01, :dS => 0.01, :dI => 1//100, :dR => Float32(0.01)]
623+
sol_base = solve(ODEProblem(lrs, u0, (0.0, 100.0), ps_1), Rosenbrock23(); savetat = 0.1)
624+
for ps in [ps_1, ps_2, ps_3]
625+
for jac in [true, false], sparse in [true, false]
626+
oprob = ODEProblem(lrs, u0, (0.0, 100.0), ps; jac, sparse)
627+
@test sol_base solve(oprob, Rosenbrock23(); savetat = 0.1)
628+
end
629+
end
630+
end
631+
617632
# Tests various types of numbers for initial conditions/parameters (e.g. Real numbers, Float32, etc.).
618633
let
619634
# Declare u0 versions.

test/spatial_modelling/lattice_reaction_systems_jumps.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,9 @@ let
6464

6565
# Prepare various (diffusion) parameter input types.
6666
pE_1 = [:dI => 0.02, :dS => 0.01, :dR => 0.03]
67-
pE_2 = [:dI => 0.02, :dS => uniform_e_vals(lrs, 0.01), :dR => 0.03]
67+
dS_vals = spzeros(num_verts(lrs), num_verts(lrs))
68+
foreach(e -> (dS_vals[e[1], e[2]] = 0.01), edge_iterator(lrs))
69+
pE_2 = [:dI => 0.02, :dS => dS_vals, :dR => 0.03]
6870

6971
# Checks hopping rates and u0 are correct.
7072
true_u0 = [fill(1.0, 1, 25); fill(2.0, 1, 25); fill(3.0, 1, 25)]

0 commit comments

Comments
 (0)