1
1
# ## Spatial ODE Functor Structure ###
2
2
3
- # Functor with information about a spatial Lattice Reaction ODE;s forcing and Jacobian functions.
3
+ # Functor with information about a spatial Lattice Reaction ODEs forcing and Jacobian functions.
4
4
# Also used as ODE Function input to corresponding `ODEProblem`.
5
5
struct LatticeTransportODEFunction{P,Q,R,S,T}
6
6
"""
@@ -59,40 +59,60 @@ struct LatticeTransportODEFunction{P,Q,R,S,T}
59
59
used).
60
60
"""
61
61
jac_transport:: T
62
+ """ Whether sparse jacobian representation is used. """
63
+ sparse:: Bool
64
+ """ Remove when we add this as problem metadata"""
65
+ lrs:: LatticeReactionSystem
62
66
63
67
function LatticeTransportODEFunction (ofunc:: P , ps:: Vector{<:Pair} ,
64
68
lrs:: LatticeReactionSystem , transport_rates:: Vector{Pair{Int64, SparseMatrixCSC{S, Int64}}} ,
65
- jac_transport:: Union{Nothing, Matrix{S}, SparseMatrixCSC{S, Int64}} ) where {P,S}
66
-
67
- # Creates a vector with the heterogeneous vertex parameters' indexes in the full parameter vector.
68
- p_dict = Dict (ps)
69
- heterogeneous_vert_p_idxs = findall ((p_dict[p] isa Vector) && (length (p_dict[p]) > 1 )
70
- for p in parameters (lrs))
71
-
72
- # Creates the MTKParameters structure and `p_setters` vector (which are used to manage
73
- # the vertex parameter values during the simulations).
74
- nonspatial_osys = complete (convert (ODESystem, reactionsystem (lrs)))
75
- p_init = [p => p_dict[p][1 ] for p in parameters (nonspatial_osys)]
76
- mtk_ps = MT. MTKParameters (nonspatial_osys, p_init)
77
- p_setters = [MT. setp (nonspatial_osys, p) for p in parameters (lrs)[heterogeneous_vert_p_idxs]]
78
-
79
- # Computes the transport rate type vector and leaving rate matrix.
80
- t_rate_idx_types = [size (tr[2 ]) == (1 ,1 ) for tr in transport_rates]
81
- leaving_rates = zeros (length (transport_rates), num_verts (lrs))
82
- for (s_idx, tr_pair) in enumerate (transport_rates)
83
- for e in Catalyst. edge_iterator (lrs)
84
- # Updates the exit rate for species s_idx from vertex e.src.
85
- leaving_rates[s_idx, e[1 ]] += get_transport_rate (tr_pair[2 ], e, t_rate_idx_types[s_idx])
86
- end
87
- end
69
+ jac_transport:: Union{Nothing, Matrix{S}, SparseMatrixCSC{S, Int64}} , sparse) where {P,S}
70
+ # Computes `LatticeTransportODEFunction` functor fields.
71
+ heterogeneous_vert_p_idxs = make_heterogeneous_vert_p_idxs (ps, lrs)
72
+ mtk_ps, p_setters = make_mtk_ps_structs (ps, lrs, heterogeneous_vert_p_idxs)
73
+ t_rate_idx_types, leaving_rates = make_t_types_and_leaving_rates (transport_rates, lrs)
88
74
89
75
# Creates and returns the `LatticeTransportODEFunction` functor.
90
76
new {P,typeof(mtk_ps),typeof(p_setters),S,typeof(jac_transport)} (ofunc, num_verts (lrs),
91
77
num_species (lrs), heterogeneous_vert_p_idxs, mtk_ps, p_setters, transport_rates,
92
- t_rate_idx_types, leaving_rates, Catalyst. edge_iterator (lrs), jac_transport)
78
+ t_rate_idx_types, leaving_rates, Catalyst. edge_iterator (lrs), jac_transport, sparse, lrs)
79
+ end
80
+ end
81
+
82
+ # `LatticeTransportODEFunction` helper functions (re used by rebuild function later on).
83
+
84
+ # Creates a vector with the heterogeneous vertex parameters' indexes in the full parameter vector.
85
+ function make_heterogeneous_vert_p_idxs (ps, lrs)
86
+ p_dict = Dict (ps)
87
+ return findall ((p_dict[p] isa Vector) && (length (p_dict[p]) > 1 ) for p in parameters (lrs))
88
+ end
89
+
90
+ # Creates the MTKParameters structure and `p_setters` vector (which are used to manage
91
+ # the vertex parameter values during the simulations).
92
+ function make_mtk_ps_structs (ps, lrs, heterogeneous_vert_p_idxs)
93
+ p_dict = Dict (ps)
94
+ nonspatial_osys = complete (convert (ODESystem, reactionsystem (lrs)))
95
+ p_init = [p => p_dict[p][1 ] for p in parameters (nonspatial_osys)]
96
+ mtk_ps = MT. MTKParameters (nonspatial_osys, p_init)
97
+ p_setters = [MT. setp (nonspatial_osys, p) for p in parameters (lrs)[heterogeneous_vert_p_idxs]]
98
+ return mtk_ps, p_setters
99
+ end
100
+
101
+ # Computes the transport rate type vector and leaving rate matrix.
102
+ function make_t_types_and_leaving_rates (transport_rates, lrs)
103
+ t_rate_idx_types = [size (tr[2 ]) == (1 ,1 ) for tr in transport_rates]
104
+ leaving_rates = zeros (length (transport_rates), num_verts (lrs))
105
+ for (s_idx, tr_pair) in enumerate (transport_rates)
106
+ for e in Catalyst. edge_iterator (lrs)
107
+ # Updates the exit rate for species s_idx from vertex e.src.
108
+ leaving_rates[s_idx, e[1 ]] += get_transport_rate (tr_pair[2 ], e, t_rate_idx_types[s_idx])
109
+ end
93
110
end
111
+ return t_rate_idx_types, leaving_rates
94
112
end
95
113
114
+ # ## Spatial ODE Functor Functions ###
115
+
96
116
# Defines the functor's effect when applied as a forcing function.
97
117
function (lt_ofun:: LatticeTransportODEFunction )(du:: AbstractVector , u, p, t)
98
118
# Updates for non-spatial reactions.
@@ -198,7 +218,7 @@ function build_odefunction(lrs::LatticeReactionSystem, vert_ps::Vector{Pair{R,Ve
198
218
transport_rates = make_sidxs_to_transrate_map (vert_ps, edge_ps, lrs)
199
219
200
220
# Depending on Jacobian and sparsity options, computes the Jacobian transport matrix and prototype.
201
- if sparse && ! jac
221
+ if ! sparse && ! jac
202
222
jac_transport = nothing
203
223
jac_prototype = nothing
204
224
else
@@ -209,7 +229,7 @@ function build_odefunction(lrs::LatticeReactionSystem, vert_ps::Vector{Pair{R,Ve
209
229
end
210
230
211
231
# Creates the `LatticeTransportODEFunction` functor (if `jac`, sets it as the Jacobian as well).
212
- f = LatticeTransportODEFunction (ofunc_dense, [vert_ps; edge_ps], lrs, transport_rates, jac_transport)
232
+ f = LatticeTransportODEFunction (ofunc_dense, [vert_ps; edge_ps], lrs, transport_rates, jac_transport, sparse )
213
233
J = (jac ? f : nothing )
214
234
215
235
# Extracts the `Symbol` form for species and parameters. Creates and returns the `ODEFunction`.
@@ -267,23 +287,95 @@ function build_jac_prototype(ns_jac_prototype::SparseMatrixCSC{Float64, Int64},
267
287
end
268
288
end
269
289
270
- # Create a sparse Jacobian prototype with 0-valued entries.
290
+ # Create a sparse Jacobian prototype with 0-valued entries. If requested,
291
+ # updates values with non-zero entries.
271
292
jac_prototype = sparse (i_idxs, j_idxs, zeros (T, num_entries))
293
+ set_nonzero && set_jac_transport_values! (jac_prototype, transport_rates, lrs)
272
294
273
- # Set element values.
274
- if set_nonzero
275
- for (s, rates) in transport_rates, e in edge_iterator (lrs)
276
- idx_src = get_index (e[1 ], s, num_species (lrs))
277
- idx_dst = get_index (e[2 ], s, num_species (lrs))
278
- val = get_transport_rate (rates, e, size (rates)== (1 ,1 ))
295
+ return jac_prototype
296
+ end
279
297
280
- # Term due to species leaving source vertex.
281
- jac_prototype[idx_src, idx_src] -= val
298
+ # For a Jacobian prototype with zero-valued entries. Set entry values according to a set of
299
+ # transport reaction values.
300
+ function set_jac_transport_values! (jac_prototype, transport_rates, lrs)
301
+ for (s, rates) in transport_rates, e in edge_iterator (lrs)
302
+ idx_src = get_index (e[1 ], s, num_species (lrs))
303
+ idx_dst = get_index (e[2 ], s, num_species (lrs))
304
+ val = get_transport_rate (rates, e, size (rates)== (1 ,1 ))
282
305
283
- # Term due to species arriving to destination vertex.
284
- jac_prototype[idx_src, idx_dst] += val
285
- end
306
+ # Term due to species leaving source vertex.
307
+ jac_prototype[idx_src, idx_src] -= val
308
+
309
+ # Term due to species arriving to destination vertex.
310
+ jac_prototype[idx_src, idx_dst] += val
286
311
end
312
+ end
287
313
288
- return jac_prototype
314
+ # ## Functor Updating Functionality ###
315
+
316
+ # Function for rebuilding a `LatticeReactionSystem` `ODEProblem` after it has been updated.
317
+ function rebuild_lat_internals! (oprob:: ODEProblem )
318
+ rebuild_lat_internals! (oprob. f. f, oprob. p, oprob. f. f. lrs)
319
+ end
320
+
321
+ # Function for rebuilding a `LatticeReactionSystem` integrator after it has been updated.
322
+ # We could specify `integrator`'s type, but that required adding OrdinaryDiffEq as a direct
323
+ # dependency of Catalyst.
324
+ function rebuild_lat_internals! (integrator)
325
+ rebuild_lat_internals! (integrator. f. f, integrator. p, integrator. f. f. lrs)
289
326
end
327
+
328
+ # Function which rebuilds a `LatticeTransportODEFunction` functor for a new parameter set.
329
+ function rebuild_lat_internals! (lt_ofun:: LatticeTransportODEFunction , ps_new, lrs:: LatticeReactionSystem )
330
+ # Computes Jacobian properties.
331
+ jac = ! isnothing (lt_ofun. jac_transport)
332
+ sparse = lt_ofun. sparse
333
+
334
+ # Recreates the new parameters on the requisite form.
335
+ ps_new = [(length (p) == 1 ) ? p[1 ] : p for p in deepcopy (ps_new)]
336
+ ps_new = [p => p_val for (p, p_val) in zip (parameters (lrs), deepcopy (ps_new))]
337
+ vert_ps, edge_ps = lattice_process_p (ps_new, vertex_parameters (lrs), edge_parameters (lrs), lrs)
338
+ ps_new = [vert_ps; edge_ps]
339
+
340
+ # Creates the new transport rates and transport Jacobian part.
341
+ transport_rates = make_sidxs_to_transrate_map (vert_ps, edge_ps, lrs)
342
+ if ! isnothing (lt_ofun. jac_transport)
343
+ lt_ofun. jac_transport .= 0.0
344
+ set_jac_transport_values! (lt_ofun. jac_transport, transport_rates, lrs)
345
+ end
346
+
347
+ # Computes new field values.
348
+ heterogeneous_vert_p_idxs = make_heterogeneous_vert_p_idxs (ps_new, lrs)
349
+ mtk_ps, p_setters = make_mtk_ps_structs (ps_new, lrs, heterogeneous_vert_p_idxs)
350
+ t_rate_idx_types, leaving_rates = make_t_types_and_leaving_rates (transport_rates, lrs)
351
+
352
+ # Updates functor fields.
353
+ replace_vec! (lt_ofun. heterogeneous_vert_p_idxs, heterogeneous_vert_p_idxs)
354
+ replace_vec! (lt_ofun. p_setters, p_setters)
355
+ replace_vec! (lt_ofun. transport_rates, transport_rates)
356
+ replace_vec! (lt_ofun. t_rate_idx_types, t_rate_idx_types)
357
+ lt_ofun. leaving_rates .= leaving_rates
358
+
359
+ # Updating the `MTKParameters` structure is a bit more complicated.
360
+ p_dict = Dict (ps_new)
361
+ osys = complete (convert (ODESystem, reactionsystem (lrs)))
362
+ for p in parameters (osys)
363
+ MT. setp (osys, p)(lt_ofun. mtk_ps, (p_dict[p] isa Number) ? p_dict[p] : p_dict[p][1 ])
364
+ end
365
+
366
+ return nothing
367
+ end
368
+
369
+ # Specialised function which replaced one vector in another in a mutating way.
370
+ # Required to update the vectors in the `LatticeTransportODEFunction` functor.
371
+ function replace_vec! (vec1, vec2)
372
+ l1 = length (vec1)
373
+ l2 = length (vec2)
374
+
375
+ # Updates the fields, then deletes superfluous fields, or additional ones.
376
+ for (i, v) in enumerate (vec2[1 : min (l1, l2)])
377
+ vec1[i] = v
378
+ end
379
+ foreach (idx -> deleteat! (vec1, idx), l1: - 1 : (l2 + 1 ))
380
+ foreach (val -> push! (vec1, val), vec2[l1+ 1 : l2])
381
+ end
0 commit comments