Skip to content

Commit 8f8290c

Browse files
authored
Merge pull request #929 from SciML/fix_expand_funcs
[v14 - Ready] Fix bug where `expand_registered_functions` mutated original reaction system
2 parents 6ccbc1f + 2f07b71 commit 8f8290c

File tree

2 files changed

+120
-11
lines changed

2 files changed

+120
-11
lines changed

src/registered_functions.jl

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -109,15 +109,37 @@ function Symbolics.derivative(::typeof(hillar), args::NTuple{5, Any}, ::Val{5})
109109
(args[1]^args[5] + args[2]^args[5] + args[4]^args[5])^2
110110
end
111111

112+
# Tuple storing all registered function (for use in various functionalities).
113+
const registered_funcs = (mm, mmr, hill, hillr, hillar)
114+
112115
### Custom CRN FUnction-related Functions ###
113116

114117
"""
115-
expand_registered_functions(expr)
118+
expand_registered_functions(in)
116119
117-
Takes an expression, and expands registered function expressions. E.g. `mm(X,v,K)` is replaced with v*X/(X+K). Currently supported functions: `mm`, `mmr`, `hill`, `hillr`, and `hill`.
120+
Takes an expression, and expands registered function expressions. E.g. `mm(X,v,K)` is replaced
121+
with v*X/(X+K). Currently supported functions: `mm`, `mmr`, `hill`, `hillr`, and `hill`. Can
122+
be applied to a reaction system, a reaction, an equation, or a symbolic expression. The input
123+
is not modified, while an output with any functions expanded is returned. If applied to a
124+
reaction system model, any cached network properties are reset.
118125
"""
119126
function expand_registered_functions(expr)
120-
iscall(expr) || return expr
127+
if hasnode(is_catalyst_function, expr)
128+
expr = replacenode(expr, expand_catalyst_function)
129+
end
130+
return expr
131+
end
132+
133+
# Checks whether an expression corresponds to a catalyst function call (e.g. `mm(X,v,K)`).
134+
function is_catalyst_function(expr)
135+
iscall(expr) || (return false)
136+
return operation(expr) in registered_funcs
137+
end
138+
139+
# If the input expression corresponds to a catalyst function call (e.g. `mm(X,v,K)`), returns
140+
# it in its expanded form. If not, returns the input expression.
141+
function expand_catalyst_function(expr)
142+
is_catalyst_function(expr) || (return expr)
121143
args = arguments(expr)
122144
if operation(expr) == Catalyst.mm
123145
return args[2] * args[1] / (args[1] + args[3])
@@ -131,23 +153,50 @@ function expand_registered_functions(expr)
131153
return args[3] * (args[1]^args[5]) /
132154
((args[1])^args[5] + (args[2])^args[5] + (args[4])^args[5])
133155
end
134-
for i in 1:length(args)
135-
args[i] = expand_registered_functions(args[i])
136-
end
137-
return expr
138156
end
157+
139158
# If applied to a Reaction, return a reaction with its rate modified.
140159
function expand_registered_functions(rx::Reaction)
141160
Reaction(expand_registered_functions(rx.rate), rx.substrates, rx.products,
142161
rx.substoich, rx.prodstoich, rx.netstoich, rx.only_use_rate, rx.metadata)
143162
end
144-
# If applied to a Equation, returns it with it applied to lhs and rhs
163+
164+
# If applied to a Equation, returns it with it applied to lhs and rhs.
145165
function expand_registered_functions(eq::Equation)
146166
return expand_registered_functions(eq.lhs) ~ expand_registered_functions(eq.rhs)
147167
end
168+
169+
# If applied to a continuous event, returns it applied to eqs and affect.
170+
function expand_registered_functions(ce::ModelingToolkit.SymbolicContinuousCallback)
171+
eqs = expand_registered_functions(ce.eqs)
172+
affect = expand_registered_functions(ce.affect)
173+
return ModelingToolkit.SymbolicContinuousCallback(eqs, affect)
174+
end
175+
176+
# If applied to a discrete event, returns it applied to condition and affects.
177+
function expand_registered_functions(de::ModelingToolkit.SymbolicDiscreteCallback)
178+
condition = expand_registered_functions(de.condition)
179+
affects = expand_registered_functions(de.affects)
180+
return ModelingToolkit.SymbolicDiscreteCallback(condition, affects)
181+
end
182+
183+
# If applied to a vector, applies it to every element in the vector.
184+
function expand_registered_functions(vec::Vector)
185+
return [Catalyst.expand_registered_functions(element) for element in vec]
186+
end
187+
148188
# If applied to a ReactionSystem, applied function to all Reactions and other Equations, and return updated system.
189+
# Currently, `ModelingToolkit.has_X_events` returns `true` even if event vector is empty (hence
190+
# this function cannot be used).
149191
function expand_registered_functions(rs::ReactionSystem)
150-
@set! rs.eqs = [Catalyst.expand_registered_functions(eq) for eq in get_eqs(rs)]
151-
@set! rs.rxs = [Catalyst.expand_registered_functions(rx) for rx in get_rxs(rs)]
192+
@set! rs.eqs = Catalyst.expand_registered_functions(get_eqs(rs))
193+
@set! rs.rxs = Catalyst.expand_registered_functions(get_rxs(rs))
194+
if !isempty(ModelingToolkit.get_continuous_events(rs))
195+
@set! rs.continuous_events = Catalyst.expand_registered_functions(ModelingToolkit.get_continuous_events(rs))
196+
end
197+
if !isempty(ModelingToolkit.get_discrete_events(rs))
198+
@set! rs.discrete_events = Catalyst.expand_registered_functions(ModelingToolkit.get_discrete_events(rs))
199+
end
200+
reset_networkproperties!(rs)
152201
return rs
153202
end

test/reactionsystem_core/custom_crn_functions.jl

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
# Fetch packages.
44
using Catalyst, Test
5+
using ModelingToolkit: get_continuous_events, get_discrete_events
56
using Symbolics: derivative
67

78
# Sets stable rng number.
@@ -154,4 +155,63 @@ let
154155
@test isequal(Catalyst.expand_registered_functions(eq3), 0 ~ V * (X^N) / (X^N + K^N))
155156
@test isequal(Catalyst.expand_registered_functions(eq4), 0 ~ V * (K^N) / (X^N + K^N))
156157
@test isequal(Catalyst.expand_registered_functions(eq5), 0 ~ V * (X^N) / (X^N + Y^N + K^N))
157-
end
158+
end
159+
160+
# Ensures that original system is not modified.
161+
let
162+
# Create model with a registered function.
163+
@species X(t)
164+
@variables V(t)
165+
@parameters v K
166+
eqs = [
167+
Reaction(mm(X,v,K), [], [X]),
168+
mm(V,v,K) ~ V + 1
169+
]
170+
@named rs = ReactionSystem(eqs, t)
171+
172+
# Check that `expand_registered_functions` does not mutate original model.
173+
rs_expanded_funcs = Catalyst.expand_registered_functions(rs)
174+
@test isequal(only(Catalyst.get_rxs(rs)).rate, Catalyst.mm(X,v,K))
175+
@test isequal(only(Catalyst.get_rxs(rs_expanded_funcs)).rate, v*X/(X + K))
176+
@test isequal(last(Catalyst.get_eqs(rs)).lhs, Catalyst.mm(V,v,K))
177+
@test isequal(last(Catalyst.get_eqs(rs_expanded_funcs)).lhs, v*V/(V + K))
178+
end
179+
180+
# Tests on model with events.
181+
let
182+
# Creates a model, saves it, and creates an expanded version.
183+
rs = @reaction_network begin
184+
@continuous_events begin
185+
[mm(X,v,K) ~ 1.0] => [X ~ X]
186+
end
187+
@discrete_events begin
188+
[1.0] => [X ~ mmr(X,v,K) + Y*(v + K)]
189+
1.0 => [X ~ X]
190+
(hill(X,v,K,n) > 1000.0) => [X ~ hillr(X,v,K,n) + 2]
191+
end
192+
v0 + hillar(X,Y,v,K,n), X --> Y
193+
end
194+
rs_saved = deepcopy(rs)
195+
rs_expanded = Catalyst.expand_registered_functions(rs)
196+
197+
# Checks that the original model is unchanged (equality currently does not consider events).
198+
@test rs == rs_saved
199+
@test get_continuous_events(rs) == get_continuous_events(rs_saved)
200+
@test get_discrete_events(rs) == get_discrete_events(rs_saved)
201+
202+
# Checks that the new system is expanded.
203+
@unpack v0, X, Y, v, K, n = rs
204+
continuous_events = [
205+
[v*X/(X + K) ~ 1.0] => [X ~ X]
206+
]
207+
discrete_events = [
208+
[1.0] => [X ~ v*K/(X + K) + Y*(v + K)]
209+
1.0 => [X ~ X]
210+
(v * (X^n) / (X^n + K^n) > 1000.0) => [X ~ v * (K^n) / (X^n + K^n) + 2]
211+
]
212+
continuous_events = ModelingToolkit.SymbolicContinuousCallback.(continuous_events)
213+
discrete_events = ModelingToolkit.SymbolicDiscreteCallback.(discrete_events)
214+
@test isequal(only(Catalyst.get_rxs(rs_expanded)).rate, v0 + v * (X^n) / (X^n + Y^n + K^n))
215+
@test isequal(get_continuous_events(rs_expanded), continuous_events)
216+
@test isequal(get_discrete_events(rs_expanded), discrete_events)
217+
end

0 commit comments

Comments
 (0)