@@ -74,7 +74,8 @@ struct LatticeTransportODEFunction{R,S,T}
74
74
75
75
function LatticeTransportODEFunction (ofunc:: S , vert_ps:: Vector{Pair{BasicSymbolic{Real},Vector{T}}} , edge_ps,
76
76
transport_rates:: Vector{Pair{Int64, SparseMatrixCSC{T, Int64}}} ,
77
- lrs:: LatticeReactionSystem ) where {S,T}
77
+ lrs:: LatticeReactionSystem , jac_transport:: Union{Nothing, SparseMatrixCSC{Float64, Int64}} ,
78
+ sparse:: Bool ) where {S,T}
78
79
# Records which parameters and rates are uniform and which are not.
79
80
v_ps_idx_types = map (vp -> length (vp[2 ]) == 1 , vert_ps)
80
81
t_rate_idx_types = map (tr -> size (tr[2 ]) == (1 ,1 ), transport_rates)
@@ -108,12 +109,58 @@ struct LatticeTransportODEFunction{R,S,T}
108
109
# Declares `work_ps` (used as storage during computation) and the edge iterator.
109
110
work_ps = zeros (length (parameters (lrs)))
110
111
edge_iterator = Catalyst. edge_iterator (lrs)
111
- new {S,T} (ofunc, num_verts (lrs), num_species (lrs), vert_p_idxs, edge_p_idxs, mtk_ps, p_setters,
112
+ new {typeof(jac_transport), S,T} (ofunc, num_verts (lrs), num_species (lrs), vert_p_idxs, edge_p_idxs, mtk_ps, p_setters,
112
113
nonspatial_rs_p_idxs, vert_ps, work_ps, v_ps_idx_types, transport_rates,
113
- t_rate_idx_types, leaving_rates, edge_iterator)
114
+ t_rate_idx_types, leaving_rates, edge_iterator, jac_transport)
115
+ end
116
+ end
117
+
118
+
119
+ # Defines the forcing functor's effect on the (spatial) ODE system.
120
+ function (f_func:: LatticeTransportODEFunction )(du:: AbstractVector , u, p, t)
121
+ # Updates for non-spatial reactions.
122
+ for vert_i in 1 : (f_func. num_verts)
123
+ # Gets the indices of all the species at vertex i.
124
+ idxs = get_indexes (vert_i, f_func. num_species)
125
+
126
+ # Updates the work vector to contain the vertex parameter values for vertex vert_i.
127
+ update_work_vert_ps! (f_func, p, vert_i)
128
+
129
+ # Evaluate reaction contributions to du at vert_i.
130
+ f_func. ofunc ((@view du[idxs]), (@view u[idxs]), f_func. mtk_ps, t)
131
+ end
132
+
133
+ # s_idx is the species index among transport species, s is the index among all species.
134
+ # rates are the species' transport rates.
135
+ for (s_idx, (s, rates)) in enumerate (f_func. transport_rates)
136
+ # Rate for leaving source vertex vert_i.
137
+ for vert_i in 1 : (f_func. num_verts)
138
+ idx_src = get_index (vert_i, s, f_func. num_species)
139
+ du[idx_src] -= f_func. leaving_rates[s_idx, vert_i] * u[idx_src]
140
+ end
141
+ # Add rates for entering a destination vertex via an incoming edge.
142
+ for e in f_func. edge_iterator
143
+ idx_src = get_index (e[1 ], s, f_func. num_species)
144
+ idx_dst = get_index (e[2 ], s, f_func. num_species)
145
+ du[idx_dst] += get_transport_rate (s_idx, f_func, e) * u[idx_src]
146
+ end
114
147
end
115
148
end
116
149
150
+ # Defines the Jacobian functor's effect on the (spatial) ODE system.
151
+ function (jac_func:: LatticeTransportODEFunction )(J:: AbstractMatrix , u, p, t)
152
+ J .= 0.0
153
+
154
+ # Update the Jacobian from non-spatial reaction terms.
155
+ for vert_i in 1 : (jac_func. num_verts)
156
+ idxs = get_indexes (vert_i, jac_func. num_species)
157
+ update_work_vert_ps! (jac_func, p, vert_i)
158
+ jac_func. ofunc. jac ((@view J[idxs, idxs]), (@view u[idxs]), jac_func. mtk_ps, t)
159
+ end
160
+
161
+ # Updates for the spatial reactions (adds the Jacobian values from the transportation reactions).
162
+ J .+ = jac_func. jac_transport
163
+ end
117
164
118
165
119
166
@@ -399,6 +446,16 @@ function build_odefunction(lrs::LatticeReactionSystem, vert_ps::Vector{Pair{Basi
399
446
J = nothing
400
447
end
401
448
449
+ ofunc_dense = ODEFunction (osys; jac = true , sparse = false )
450
+ ofunc_sparse = ODEFunction (osys; jac = true , sparse = true )
451
+ jac_transport = build_jac_prototype (ofunc_sparse. jac_prototype, transport_rates, lrs; set_nonzero = true )
452
+ f = LatticeTransportODEFunction (ofunc_sparse, vert_ps, edge_ps, transport_rates, lrs, jac_transport, sparse)
453
+ if jac
454
+ J = f
455
+ else
456
+ J = nothing
457
+ end
458
+
402
459
return ODEFunction (f; jac = J, jac_prototype = jac_prototype)
403
460
end
404
461
0 commit comments