Skip to content

Commit 9f3f704

Browse files
authored
Fix SBML import for parameters with piecewise initial assignments (#2980)
Parameters that are targets of parameter-dependent initial assignments are currently implemented as `w` expressions. Those symbols are currently not supported in event triggers. Previously, `w` symbols were only substituted in a subset of root functions, resulting in compilation failures due to undefined symbols for such models. Now `w` symbols are substituted in all trigger functions. Fixes PEtab v2 import for `Smith_BMCSystBiol2013`. Also, add `NULL` to the list of reserved symbols, fixing macro-redefinition compiler warnings when importing `Smith_BMCSystBiol2013`.
1 parent a666ef9 commit 9f3f704

File tree

3 files changed

+58
-31
lines changed

3 files changed

+58
-31
lines changed

python/sdist/amici/de_model.py

Lines changed: 39 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1305,18 +1305,36 @@ def parse_events(self) -> None:
13051305
and replaces the formulae of the found roots by identifiers of AMICI's
13061306
Heaviside function implementation in the right-hand side
13071307
"""
1308+
# toposorted w_sym -> w_expr for substitution of 'w' in trigger function
1309+
# do only once. `w` is not modified during this function.
1310+
w_toposorted = toposort_symbols(
1311+
dict(
1312+
zip(
1313+
[expr.get_id() for expr in self._expressions],
1314+
[expr.get_val() for expr in self._expressions],
1315+
strict=True,
1316+
)
1317+
)
1318+
)
1319+
13081320
# Track all roots functions in the right-hand side
13091321
roots = copy.deepcopy(self._events)
13101322
for state in self._differential_states:
1311-
state.set_dt(self._process_heavisides(state.get_dt(), roots))
1323+
state.set_dt(
1324+
self._process_heavisides(state.get_dt(), roots, w_toposorted)
1325+
)
13121326

13131327
for expr in self._expressions:
1314-
expr.set_val(self._process_heavisides(expr.get_val(), roots))
1328+
expr.set_val(
1329+
self._process_heavisides(expr.get_val(), roots, w_toposorted)
1330+
)
13151331

13161332
# remove all possible Heavisides from roots, which may arise from
13171333
# the substitution of `'w'` in `_collect_heaviside_roots`
13181334
for root in roots:
1319-
root.set_val(self._process_heavisides(root.get_val(), roots))
1335+
root.set_val(
1336+
self._process_heavisides(root.get_val(), roots, w_toposorted)
1337+
)
13201338

13211339
# Now add the found roots to the model components
13221340
for root in roots:
@@ -1326,6 +1344,11 @@ def parse_events(self) -> None:
13261344
# add roots of heaviside functions
13271345
self.add_component(root)
13281346

1347+
# Substitute 'w' expressions into root expressions, to avoid rewriting
1348+
# 'root.cpp' and 'stau.cpp' headers to include 'w.h'.
1349+
for event in self.events():
1350+
event.set_val(event.get_val().subs(w_toposorted))
1351+
13291352
# re-order events - first those that require root tracking, then the others
13301353
constant_syms = set(self.sym("k")) | set(self.sym("p"))
13311354
self._events = list(
@@ -2391,7 +2414,7 @@ def _expr_is_time_dependent(self, expr: sp.Expr) -> bool:
23912414
expr_syms = {str(sym) for sym in expr.free_symbols}
23922415

23932416
# Check if the time variable is in the expression.
2394-
if "t" in expr_syms:
2417+
if amici_time_symbol.name in expr_syms:
23952418
return True
23962419

23972420
# Check if any time-dependent states are in the expression.
@@ -2464,33 +2487,11 @@ def _collect_heaviside_roots(
24642487

24652488
return root_funs
24662489

2467-
def _substitute_w_in_roots(
2468-
self,
2469-
root_funs: list[tuple[sp.Expr, sp.Expr]],
2470-
) -> list[tuple[sp.Expr, sp.Expr]]:
2471-
"""
2472-
Substitute 'w' expressions into root expressions, to avoid rewriting
2473-
'root.cpp' and 'stau.cpp' headers to include 'w.h'.
2474-
"""
2475-
w_sorted = toposort_symbols(
2476-
dict(
2477-
zip(
2478-
[expr.get_id() for expr in self._expressions],
2479-
[expr.get_val() for expr in self._expressions],
2480-
strict=True,
2481-
)
2482-
)
2483-
)
2484-
root_funs = [
2485-
(r[0].subs(w_sorted), r[1].subs(w_sorted)) for r in root_funs
2486-
]
2487-
2488-
return root_funs
2489-
24902490
def _process_heavisides(
24912491
self,
24922492
dxdt: sp.Expr,
24932493
roots: list[Event],
2494+
w_toposorted: dict[sp.Symbol, sp.Expr],
24942495
) -> sp.Expr:
24952496
"""
24962497
Parses the RHS of a state variable, checks for Heaviside functions,
@@ -2502,7 +2503,8 @@ def _process_heavisides(
25022503
right-hand side of state variable
25032504
:param roots:
25042505
list of known root functions with identifier
2505-
2506+
:param w_toposorted:
2507+
`w` symbols->expressions sorted in topological order
25062508
:returns:
25072509
dxdt with Heaviside functions replaced by amici helper variables
25082510
"""
@@ -2511,7 +2513,15 @@ def _process_heavisides(
25112513
heavisides = []
25122514
# run through the expression tree and get the roots
25132515
tmp_roots_old = self._collect_heaviside_roots((dxdt,))
2514-
tmp_roots_old = self._substitute_w_in_roots(tmp_roots_old)
2516+
# substitute 'w' symbols in the root expression by their equations,
2517+
# because currently,
2518+
# 1) root functions must not depend on 'w'
2519+
# 2) the check for time-dependence currently assumes only state
2520+
# variables are implicitly time-dependent
2521+
tmp_roots_old = [
2522+
(a.subs(w_toposorted), b.subs(w_toposorted))
2523+
for a, b in tmp_roots_old
2524+
]
25152525
for tmp_root_old, tmp_x0_old in unique_preserve_order(tmp_roots_old):
25162526
# we want unique identifiers for the roots
25172527
tmp_root_new = self._get_unique_root(tmp_root_old, roots)

python/sdist/amici/import_utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,17 @@
1717
from sympy.logic.boolalg import BooleanAtom
1818
from toposort import toposort
1919

20-
RESERVED_SYMBOLS = ["x", "k", "p", "y", "w", "h", "t", "AMICI_EMPTY_BOLUS"]
20+
RESERVED_SYMBOLS = [
21+
"x",
22+
"k",
23+
"p",
24+
"y",
25+
"w",
26+
"h",
27+
"t",
28+
"AMICI_EMPTY_BOLUS",
29+
"NULL",
30+
]
2131

2232
try:
2333
import pysb

python/tests/test_bngl.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import os
22

3-
import amici
43
import numpy as np
54
import pytest
65

@@ -10,6 +9,7 @@
109
from amici.testing import TemporaryDirectoryWinSafe, skip_on_valgrind
1110
from pysb.importers.bngl import model_from_bngl
1211
from pysb.simulator import ScipyOdeSimulator
12+
from contextlib import suppress
1313

1414
tests = [
1515
"CaOscillate_Func",
@@ -39,6 +39,13 @@
3939
@skip_on_valgrind
4040
@pytest.mark.parametrize("example", tests)
4141
def test_compare_to_pysb_simulation(example):
42+
import amici.import_utils
43+
44+
# allow "NULL" as model symbol
45+
# (used in CaOscillate_Func and Repressilator examples)
46+
with suppress(ValueError):
47+
amici.import_utils.RESERVED_SYMBOLS.remove("NULL")
48+
4249
atol = 1e-8
4350
rtol = 1e-8
4451

0 commit comments

Comments
 (0)