Skip to content

Commit 621d664

Browse files
committed
use generated rate law function
1 parent 4341962 commit 621d664

File tree

3 files changed

+12
-6
lines changed

3 files changed

+12
-6
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
1616
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
1717
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1818
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
19+
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
1920
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2021
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
2122
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"

src/Catalyst.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ const MT = ModelingToolkit
1616
using Unitful
1717
@reexport using ModelingToolkit
1818
using Symbolics
19+
20+
using RuntimeGeneratedFunctions
21+
RuntimeGeneratedFunctions.init(@__MODULE__)
22+
1923
import Symbolics: BasicSymbolic
2024
import SymbolicUtils
2125
using ModelingToolkit: Symbolic, value, istree, get_states, get_ps, get_iv, get_systems,

src/spatial_reaction_systems/utility.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -204,16 +204,17 @@ end
204204
# Else a vector with each value corresponding to the rate at one specific edge.
205205
function compute_transport_rates(rate_law::Num,
206206
p_val_dict::Dict{SymbolicUtils.BasicSymbolic{Real}, Vector{Float64}}, num_edges::Int64)
207-
relevant_ps = Symbolics.get_variables(rate_law)
208-
209-
# If all these parameters are spatially uniform. `rates` becomes a vector with 1 value.
210-
if all(length(p_val_dict[P]) == 1 for P in relevant_ps)
211-
rates = [substitute(rate_law, Dict(p => p_val_dict[p][1] for p in relevant_ps))]
207+
# Finds parameters involved in rate and create a function evaluating teh rate law.
208+
relevant_ps = Symbolics.get_variables(rate_law)
209+
rate_law_func = drop_expr(@RuntimeGeneratedFunction(build_function(rate_law, relevant_ps...)))
212210

211+
# If all these parameters are spatially uniform. `rates` becomes a vector with 1 value.
212+
if all(length(p_val_dict[P]) == 1 for P in relevant_ps)
213+
rates = [rate_law_func([p_val_dict[p][1] for p in relevant_ps]...)]
213214
# If at least on parameter the rate depends on have a value varying across all edges,
214215
# we have to compute one rate value for each edge.
215216
else
216-
rates = [substitute(rate_law, Dict(p => get_component_value(p_val_dict[p], idxE) for p in relevant_ps))
217+
rates = [rate_law_func([get_component_value(p_val_dict[p], idxE) for p in relevant_ps]...)
217218
for idxE in 1:num_edges]
218219
end
219220
return Symbolics.value.(rates)

0 commit comments

Comments
 (0)