Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion doc/rtd_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ setuptools>=67.7.2
# for building the documentation, we don't care whether this fully works
git+https://github.com/pysb/pysb@0afeaab385e9a1d813ecf6fdaf0153f4b91358af
# For forward type definition in generate_equinox
git+https://github.com/PEtab-dev/petab_sciml.git@727d177fd3f85509d0bdcc278b672e9eeafd2384#subdirectory=src/python
matplotlib>=3.7.1
optax
nbsphinx
Expand Down
240 changes: 59 additions & 181 deletions python/sdist/amici/de_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

import contextlib
import copy
import itertools
import logging
Expand Down Expand Up @@ -47,7 +46,6 @@
from .exporters.sundials.cxxcodeprinter import csc_matrix
from .importers.utils import (
ObservableTransformation,
SBMLException,
_default_simplify,
amici_time_symbol,
smart_subs_dict,
Expand Down Expand Up @@ -340,6 +338,10 @@ def algebraic_states(self) -> list[AlgebraicState]:
"""Get all algebraic states."""
return self._algebraic_states

def algebraic_equations(self) -> list[AlgebraicEquation]:
"""Get all algebraic equations."""
return self._algebraic_equations

def observables(self) -> list[Observable]:
"""Get all observables."""
return self._observables
Expand Down Expand Up @@ -404,165 +406,6 @@ def states(self) -> list[State]:
"""Get all states."""
return self._differential_states + self._algebraic_states

def _process_sbml_rate_of(self) -> None:
"""Substitute any SBML-rateOf constructs in the model equations"""
from sbmlmath import rate_of as rate_of_func

species_sym_to_xdot = dict(
zip(self.sym("x"), self.sym("xdot"), strict=True)
)
species_sym_to_idx = {x: i for i, x in enumerate(self.sym("x"))}

def get_rate(symbol: sp.Symbol):
"""Get rate of change of the given symbol"""
if symbol.find(rate_of_func):
raise SBMLException("Nesting rateOf() is not allowed.")

# Replace all rateOf(some_species) by their respective xdot equation
with contextlib.suppress(KeyError):
return self._eqs["xdot"][species_sym_to_idx[symbol]]

# For anything other than a state, rateOf(.) is 0 or invalid
return 0

# replace rateOf-instances in xdot by xdot symbols
made_substitutions = False
for i_state in range(len(self.eq("xdot"))):
if rate_ofs := self._eqs["xdot"][i_state].find(rate_of_func):
self._eqs["xdot"][i_state] = self._eqs["xdot"][i_state].subs(
{
# either the rateOf argument is a state, or it's 0
rate_of: species_sym_to_xdot.get(rate_of.args[0], 0)
for rate_of in rate_ofs
}
)
made_substitutions = True

if made_substitutions:
# substitute in topological order
subs = toposort_symbols(
dict(zip(self.sym("xdot"), self.eq("xdot"), strict=True))
)
self._eqs["xdot"] = smart_subs_dict(self.eq("xdot"), subs)

# replace rateOf-instances in w by xdot equation
# here we may need toposort, as xdot may depend on w
made_substitutions = False
for i_expr in range(len(self.eq("w"))):
new, replacement = self._eqs["w"][i_expr].replace(
rate_of_func, get_rate, map=True
)
if replacement:
self._eqs["w"][i_expr] = new
made_substitutions = True

if made_substitutions:
# Sort expressions in self._expressions, w symbols, and w equations
# in topological order. Ideally, this would already happen before
# adding the expressions to the model, but at that point, we don't
# have access to xdot yet.
# NOTE: elsewhere, conservations law expressions are expected to
# occur before any other w expressions, so we must maintain their
# position
# toposort everything but conservation law expressions,
# then prepend conservation laws
w_sorted = toposort_symbols(
dict(
zip(
self.sym("w")[self.num_cons_law() :, :],
self.eq("w")[self.num_cons_law() :, :],
strict=True,
)
)
)
w_sorted = (
dict(
zip(
self.sym("w")[: self.num_cons_law(), :],
self.eq("w")[: self.num_cons_law(), :],
strict=True,
)
)
| w_sorted
)
old_syms = tuple(self._syms["w"])
topo_expr_syms = tuple(w_sorted.keys())
new_order = [old_syms.index(s) for s in topo_expr_syms]
self._expressions = [self._expressions[i] for i in new_order]
self._syms["w"] = sp.Matrix(topo_expr_syms)
self._eqs["w"] = sp.Matrix(list(w_sorted.values()))

# replace rateOf-instances in x0 by xdot equation
# indices of state variables whose x0 was modified
changed_indices = []
for i_state in range(len(self.eq("x0"))):
new, replacement = self._eqs["x0"][i_state].replace(
rate_of_func, get_rate, map=True
)
if replacement:
self._eqs["x0"][i_state] = new
changed_indices.append(i_state)
if changed_indices:
# Replace any newly introduced state variables
# by their x0 expressions.
# Also replace any newly introduced `w` symbols by their
# expressions (after `w` was toposorted above).
subs = toposort_symbols(
dict(zip(self.sym("x_rdata"), self.eq("x0"), strict=True))
)
subs = dict(zip(self._syms["w"], self.eq("w"), strict=True)) | subs
for i_state in changed_indices:
self._eqs["x0"][i_state] = smart_subs_dict(
self._eqs["x0"][i_state], subs
)

for component in chain(
self.observables(),
self.events(),
self._algebraic_equations,
):
if rate_ofs := component.get_val().find(rate_of_func):
if isinstance(component, Event):
# TODO froot(...) can currently not depend on `w`, so this substitution fails for non-zero rates
# see, e.g., sbml test case 01293
raise SBMLException(
"AMICI does currently not support rateOf(.) inside event trigger functions."
)

if isinstance(component, AlgebraicEquation):
# TODO IDACalcIC fails with
# "The linesearch algorithm failed: step too small or too many backtracks."
# see, e.g., sbml test case 01482
raise SBMLException(
"AMICI does currently not support rateOf(.) inside AlgebraicRules."
)

component.set_val(
component.get_val().subs(
{
rate_of: get_rate(rate_of.args[0])
for rate_of in rate_ofs
}
)
)

for event in self.events():
state_update = event.get_state_update(
x=self.sym("x"), x_old=self.sym("x")
)
if state_update is None:
continue

for i_state in range(len(state_update)):
if rate_ofs := state_update[i_state].find(rate_of_func):
raise SBMLException(
"AMICI does currently not support rateOf(.) inside event state updates."
)
# TODO here we need xdot sym, not eqs
# event._state_update[i_state] = event._state_update[i_state].subs(
# {rate_of: get_rate(rate_of.args[0]) for rate_of in rate_ofs}
# )

def add_component(
self, component: ModelQuantity, insert_first: bool | None = False
) -> None:
Expand Down Expand Up @@ -1271,9 +1114,7 @@ def generate_basic_variables(self) -> None:
Generates the symbolic identifiers for all variables in
``DEModel._variable_prototype``
"""
# We need to process events and Heaviside functions in the ``DEModel`,
# before adding it to DEExporter
self.parse_events()
self._reorder_events()

for var in self._variable_prototype:
if var not in self._syms:
Expand Down Expand Up @@ -1335,7 +1176,11 @@ def parse_events(self) -> None:
for event in self.events():
event.set_val(event.get_val().subs(w_toposorted))

# re-order events - first those that require root tracking, then the others
def _reorder_events(self) -> None:
"""
Re-order events - first those that require root tracking,
then the others.
"""
constant_syms = set(self.sym("k")) | set(self.sym("p"))
self._events = list(
chain(
Expand Down Expand Up @@ -2694,22 +2539,7 @@ def _process_hybridization(self, hybridization: dict) -> None:
self._observables = [self._observables[i] for i in new_order]

if added_expressions:
# toposort expressions
w_sorted = toposort_symbols(
dict(
zip(
self.sym("w"),
self.eq("w"),
strict=True,
)
)
)
old_syms = tuple(self._syms["w"])
topo_expr_syms = tuple(w_sorted.keys())
new_order = [old_syms.index(s) for s in topo_expr_syms]
self._expressions = [self._expressions[i] for i in new_order]
self._syms["w"] = sp.Matrix(topo_expr_syms)
self._eqs["w"] = sp.Matrix(list(w_sorted.values()))
self.toposort_expressions()

def get_explicit_roots(self) -> set[sp.Expr]:
"""
Expand Down Expand Up @@ -2752,3 +2582,51 @@ def has_event_assignments(self) -> bool:
boolean indicating if event assignments are present
"""
return any(event.updates_state for event in self._events)

def toposort_expressions(self) -> dict[sp.Symbol, sp.Expr]:
"""
Sort expressions in topological order.

:return:
dict of expression symbols to expressions in topological order
"""
# ensure no symbols or equations that depend on `w` have been generated
# yet, otherwise the re-ordering might break dependencies
if (
generated := set(self._syms)
| set(self._eqs)
| set(self._sparsesyms)
| set(self._sparseeqs)
) - {"w", "p", "k", "x", "x_rdata"}:
raise AssertionError(
"This function must be called before computing any "
"derivatives. The following symbols/equations are already "
f"generated: {generated}"
)

# NOTE: elsewhere, conservations law expressions are expected to
# occur before any other w expressions, so we must maintain their
# position.
# toposort everything but conservation law expressions,
# then prepend conservation laws

w_toposorted = toposort_symbols(
{
e.get_sym(): e.get_val()
for e in self.expressions()[self.num_cons_law() :]
}
)

w_toposorted = {
e.get_sym(): e.get_val()
for e in self.expressions()[: self.num_cons_law()]
} | w_toposorted

old_syms = tuple(e.get_sym() for e in self.expressions())
topo_expr_syms = tuple(w_toposorted)
new_order = [old_syms.index(s) for s in topo_expr_syms]
self._expressions = [self._expressions[i] for i in new_order]
self._syms["w"] = sp.Matrix(topo_expr_syms)
self._eqs["w"] = sp.Matrix(list(w_toposorted.values()))

return w_toposorted
1 change: 1 addition & 0 deletions python/sdist/amici/importers/pysb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,7 @@ def ode_model_from_pysb_importer(

_process_stoichiometric_matrix(model, ode, fixed_parameters)

ode.parse_events()
ode.generate_basic_variables()

return ode
Expand Down
Loading
Loading