Skip to content

Commit b12c68a

Browse files
authored
Fix incorrect initialization for non-constant initial assignments (#2939)
Fixes a bug where targets of initial assignments are not initialized correctly if the initial assignment expression is implicitly time-dependent. Fixes #2936. * eliminate state variables from x0 after replacing rate-of expressions * implement supposed-to-be static expressions that have non-constant entities in their initial condition as state variables instead of expressions.
1 parent 032c21a commit b12c68a

File tree

4 files changed

+111
-18
lines changed

4 files changed

+111
-18
lines changed

pytest.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@ filterwarnings =
3030
ignore:jax.* is deprecated:DeprecationWarning
3131

3232

33-
norecursedirs = .git amici_models build doc documentation matlab models ThirdParty amici sdist examples
33+
norecursedirs = .git amici_models build doc documentation matlab models ThirdParty amici sdist examples *build*

python/sdist/amici/de_model.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -454,14 +454,6 @@ def get_rate(symbol: sp.Symbol):
454454
)
455455
self._eqs["xdot"] = smart_subs_dict(self.eq("xdot"), subs)
456456

457-
# replace rateOf-instances in x0 by xdot equation
458-
for i_state in range(len(self.eq("x0"))):
459-
new, replacement = self._eqs["x0"][i_state].replace(
460-
rate_of_func, get_rate, map=True
461-
)
462-
if replacement:
463-
self._eqs["x0"][i_state] = new
464-
465457
# replace rateOf-instances in w by xdot equation
466458
# here we may need toposort, as xdot may depend on w
467459
made_substitutions = False
@@ -509,6 +501,30 @@ def get_rate(symbol: sp.Symbol):
509501
self._syms["w"] = sp.Matrix(topo_expr_syms)
510502
self._eqs["w"] = sp.Matrix(list(w_sorted.values()))
511503

504+
# replace rateOf-instances in x0 by xdot equation
505+
# indices of state variables whose x0 was modified
506+
changed_indices = []
507+
for i_state in range(len(self.eq("x0"))):
508+
new, replacement = self._eqs["x0"][i_state].replace(
509+
rate_of_func, get_rate, map=True
510+
)
511+
if replacement:
512+
self._eqs["x0"][i_state] = new
513+
changed_indices.append(i_state)
514+
if changed_indices:
515+
# Replace any newly introduced state variables
516+
# by their x0 expressions.
517+
# Also replace any newly introduced `w` symbols by their
518+
# expressions (after `w` was toposorted above).
519+
subs = toposort_symbols(
520+
dict(zip(self.sym("x_rdata"), self.eq("x0"), strict=True))
521+
)
522+
subs = dict(zip(self._syms["w"], self.eq("w"), strict=True)) | subs
523+
for i_state in changed_indices:
524+
self._eqs["x0"][i_state] = smart_subs_dict(
525+
self._eqs["x0"][i_state], subs
526+
)
527+
512528
for component in chain(
513529
self.observables(),
514530
self.events(),

python/sdist/amici/sbml_import.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1386,13 +1386,20 @@ def _process_parameters(
13861386
# Parameters that need to be turned into expressions or species
13871387
# so far, this concerns parameters with symbolic initial assignments
13881388
# (those have been skipped above) that are not rate rule targets
1389+
1390+
# Set of symbols in initial assignments that still allows handling them
1391+
# via amici expressions
1392+
syms_allowed_in_expr_ia = set(self.symbols[SymbolId.PARAMETER]) | set(
1393+
self.symbols[SymbolId.FIXED_PARAMETER]
1394+
)
1395+
13891396
for par in self.sbml.getListOfParameters():
13901397
if (
13911398
(ia := par_id_to_ia.get(par.getId())) is not None
13921399
and not ia.is_Number
13931400
and not self.is_rate_rule_target(par)
13941401
):
1395-
if not ia.has(sbml_time_symbol):
1402+
if not (ia.free_symbols - syms_allowed_in_expr_ia):
13961403
self.symbols[SymbolId.EXPRESSION][
13971404
_get_identifier_symbol(par)
13981405
] = {
@@ -1407,6 +1414,10 @@ def _process_parameters(
14071414
# We can't represent that as expression, since the
14081415
# initial simulation time is only known at the time of the
14091416
# simulation, so we can't substitute it.
1417+
# Also, any parameter with an initial assignment
1418+
# that expression that is implicitly time-dependent
1419+
# must be converted to a species to avoid re-evaluating
1420+
# the initial assignment at every time step.
14101421
self.symbols[SymbolId.SPECIES][
14111422
_get_identifier_symbol(par)
14121423
] = {
@@ -1515,13 +1526,36 @@ def _process_rules(self) -> None:
15151526
self.symbols[SymbolId.EXPRESSION], "value"
15161527
)
15171528

1518-
# expressions must not occur in definition of x0
1529+
# expressions must not occur in the definition of x0
1530+
allowed_syms = (
1531+
set(self.symbols[SymbolId.PARAMETER])
1532+
| set(self.symbols[SymbolId.FIXED_PARAMETER])
1533+
| {sbml_time_symbol}
1534+
)
15191535
for species in self.symbols[SymbolId.SPECIES].values():
1520-
species["init"] = self._make_initial(
1521-
smart_subs_dict(
1522-
species["init"], self.symbols[SymbolId.EXPRESSION], "value"
1536+
# only parameters are allowed as free symbols
1537+
while True:
1538+
species["init"] = species["init"].subs(self.compartments)
1539+
sym_math, rateof_to_dummy = _rateof_to_dummy(species["init"])
1540+
old_init = species["init"]
1541+
if (
1542+
sym_math.free_symbols
1543+
- allowed_syms
1544+
- set(rateof_to_dummy.values())
1545+
== set()
1546+
):
1547+
break
1548+
species["init"] = self._make_initial(
1549+
smart_subs_dict(
1550+
species["init"],
1551+
self.symbols[SymbolId.EXPRESSION],
1552+
"value",
1553+
)
15231554
)
1524-
)
1555+
if species["init"] == old_init:
1556+
raise AssertionError(
1557+
f"Infinite loop detected in _process_rules {species}."
1558+
)
15251559

15261560
def _process_rule_algebraic(self, rule: libsbml.AlgebraicRule):
15271561
formula = self._sympify(rule)
@@ -2359,6 +2393,10 @@ def _make_initial(
23592393
sym_math = sym_math.subs(
23602394
var, self.symbols[SymbolId.SPECIES][var]["init"]
23612395
)
2396+
elif var in self.symbols[SymbolId.ALGEBRAIC_STATE]:
2397+
sym_math = sym_math.subs(
2398+
var, self.symbols[SymbolId.ALGEBRAIC_STATE][var]["init"]
2399+
)
23622400
elif (
23632401
element := self.sbml.getElementBySId(element_id)
23642402
) and self.is_rate_rule_target(element):

python/tests/test_sbml_import.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,16 @@
55
import sys
66
from numbers import Number
77
from pathlib import Path
8-
98
import amici
109
import libsbml
1110
import numpy as np
1211
import pytest
1312
from amici.gradient_check import check_derivatives
14-
from amici.sbml_import import SbmlImporter
15-
from amici.testing import skip_on_valgrind
13+
from amici.sbml_import import SbmlImporter, SymbolId
14+
from amici.import_utils import symbol_with_assumptions
1615
from numpy.testing import assert_allclose, assert_array_equal
1716
from amici import import_model_module
17+
from amici.testing import skip_on_valgrind
1818
from amici.testing import TemporaryDirectoryWinSafe as TemporaryDirectory
1919
from conftest import MODEL_STEADYSTATE_SCALED_XML
2020
import sympy as sp
@@ -1142,3 +1142,42 @@ def test_contains_periodic_subexpression():
11421142
assert cps(sp.sin(t), t) is True
11431143
assert cps(sp.cos(t), t) is True
11441144
assert cps(t + sp.sin(t), t) is True
1145+
1146+
1147+
@skip_on_valgrind
1148+
@pytest.mark.parametrize("compute_conservation_laws", [True, False])
1149+
def test_time_dependent_initial_assignment(compute_conservation_laws: bool):
1150+
"""Check that dynamic expressions for initial assignments are only
1151+
evaluated at t=t0."""
1152+
from amici.antimony_import import antimony2sbml
1153+
1154+
ant_model = """
1155+
x1' = 1
1156+
x1 = p0
1157+
p0 = 1
1158+
p1 = x1
1159+
x2 := x1
1160+
p2 = x2
1161+
"""
1162+
1163+
sbml_model = antimony2sbml(ant_model)
1164+
print(sbml_model)
1165+
si = SbmlImporter(sbml_model, from_file=False)
1166+
de_model = si._build_ode_model(
1167+
observables={"obs_p1": {"formula": "p1"}, "obs_p2": {"formula": "p2"}},
1168+
compute_conservation_laws=compute_conservation_laws,
1169+
)
1170+
# "species", because the initial assignment expression is time-dependent
1171+
assert symbol_with_assumptions("p2") in si.symbols[SymbolId.SPECIES].keys()
1172+
# "species", because differential state
1173+
assert symbol_with_assumptions("x1") in si.symbols[SymbolId.SPECIES].keys()
1174+
1175+
assert "p0" in [str(p.get_id()) for p in de_model.parameters()]
1176+
assert "p1" not in [str(p.get_id()) for p in de_model.parameters()]
1177+
assert "p2" not in [str(p.get_id()) for p in de_model.parameters()]
1178+
1179+
assert list(de_model.sym("x_rdata")) == [
1180+
symbol_with_assumptions("p2"),
1181+
symbol_with_assumptions("x1"),
1182+
]
1183+
assert list(de_model.eq("x0")) == [symbol_with_assumptions("p0")] * 2

0 commit comments

Comments
 (0)