Skip to content

Commit bb6370d

Browse files
committed
save progress
1 parent ae077c2 commit bb6370d

File tree

1 file changed

+60
-3
lines changed

1 file changed

+60
-3
lines changed

src/spatial_reaction_systems/spatial_ODE_systems.jl

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ struct LatticeTransportODEFunction{R,S,T}
7474

7575
function LatticeTransportODEFunction(ofunc::S, vert_ps::Vector{Pair{BasicSymbolic{Real},Vector{T}}}, edge_ps,
7676
transport_rates::Vector{Pair{Int64, SparseMatrixCSC{T, Int64}}},
77-
lrs::LatticeReactionSystem) where {S,T}
77+
lrs::LatticeReactionSystem, jac_transport::Union{Nothing, SparseMatrixCSC{Float64, Int64}},
78+
sparse::Bool) where {S,T}
7879
# Records which parameters and rates are uniform and which are not.
7980
v_ps_idx_types = map(vp -> length(vp[2]) == 1, vert_ps)
8081
t_rate_idx_types = map(tr -> size(tr[2]) == (1,1), transport_rates)
@@ -108,12 +109,58 @@ struct LatticeTransportODEFunction{R,S,T}
108109
# Declares `work_ps` (used as storage during computation) and the edge iterator.
109110
work_ps = zeros(length(parameters(lrs)))
110111
edge_iterator = Catalyst.edge_iterator(lrs)
111-
new{S,T}(ofunc, num_verts(lrs), num_species(lrs), vert_p_idxs, edge_p_idxs, mtk_ps, p_setters,
112+
new{typeof(jac_transport),S,T}(ofunc, num_verts(lrs), num_species(lrs), vert_p_idxs, edge_p_idxs, mtk_ps, p_setters,
112113
nonspatial_rs_p_idxs, vert_ps, work_ps, v_ps_idx_types, transport_rates,
113-
t_rate_idx_types, leaving_rates, edge_iterator)
114+
t_rate_idx_types, leaving_rates, edge_iterator, jac_transport)
115+
end
116+
end
117+
118+
119+
# Defines the forcing functor's effect on the (spatial) ODE system.
120+
function (f_func::LatticeTransportODEFunction)(du::AbstractVector, u, p, t)
121+
# Updates for non-spatial reactions.
122+
for vert_i in 1:(f_func.num_verts)
123+
# Gets the indices of all the species at vertex i.
124+
idxs = get_indexes(vert_i, f_func.num_species)
125+
126+
# Updates the work vector to contain the vertex parameter values for vertex vert_i.
127+
update_work_vert_ps!(f_func, p, vert_i)
128+
129+
# Evaluate reaction contributions to du at vert_i.
130+
f_func.ofunc((@view du[idxs]), (@view u[idxs]), f_func.mtk_ps, t)
131+
end
132+
133+
# s_idx is the species index among transport species, s is the index among all species.
134+
# rates are the species' transport rates.
135+
for (s_idx, (s, rates)) in enumerate(f_func.transport_rates)
136+
# Rate for leaving source vertex vert_i.
137+
for vert_i in 1:(f_func.num_verts)
138+
idx_src = get_index(vert_i, s, f_func.num_species)
139+
du[idx_src] -= f_func.leaving_rates[s_idx, vert_i] * u[idx_src]
140+
end
141+
# Add rates for entering a destination vertex via an incoming edge.
142+
for e in f_func.edge_iterator
143+
idx_src = get_index(e[1], s, f_func.num_species)
144+
idx_dst = get_index(e[2], s, f_func.num_species)
145+
du[idx_dst] += get_transport_rate(s_idx, f_func, e) * u[idx_src]
146+
end
114147
end
115148
end
116149

150+
# Defines the Jacobian functor's effect on the (spatial) ODE system.
151+
function (jac_func::LatticeTransportODEFunction)(J::AbstractMatrix, u, p, t)
152+
J .= 0.0
153+
154+
# Update the Jacobian from non-spatial reaction terms.
155+
for vert_i in 1:(jac_func.num_verts)
156+
idxs = get_indexes(vert_i, jac_func.num_species)
157+
update_work_vert_ps!(jac_func, p, vert_i)
158+
jac_func.ofunc.jac((@view J[idxs, idxs]), (@view u[idxs]), jac_func.mtk_ps, t)
159+
end
160+
161+
# Updates for the spatial reactions (adds the Jacobian values from the transportation reactions).
162+
J .+= jac_func.jac_transport
163+
end
117164

118165

119166

@@ -399,6 +446,16 @@ function build_odefunction(lrs::LatticeReactionSystem, vert_ps::Vector{Pair{Basi
399446
J = nothing
400447
end
401448

449+
ofunc_dense = ODEFunction(osys; jac = true, sparse = false)
450+
ofunc_sparse = ODEFunction(osys; jac = true, sparse = true)
451+
jac_transport = build_jac_prototype(ofunc_sparse.jac_prototype, transport_rates, lrs; set_nonzero = true)
452+
f = LatticeTransportODEFunction(ofunc_sparse, vert_ps, edge_ps, transport_rates, lrs, jac_transport, sparse)
453+
if jac
454+
J = f
455+
else
456+
J = nothing
457+
end
458+
402459
return ODEFunction(f; jac = J, jac_prototype = jac_prototype)
403460
end
404461

0 commit comments

Comments
 (0)