diff --git a/doc/rtd_requirements.txt b/doc/rtd_requirements.txt index ce3e1d1688..1e58136155 100644 --- a/doc/rtd_requirements.txt +++ b/doc/rtd_requirements.txt @@ -8,7 +8,6 @@ setuptools>=67.7.2 # for building the documentation, we don't care whether this fully works git+https://github.com/pysb/pysb@0afeaab385e9a1d813ecf6fdaf0153f4b91358af # For forward type definition in generate_equinox -git+https://github.com/PEtab-dev/petab_sciml.git@727d177fd3f85509d0bdcc278b672e9eeafd2384#subdirectory=src/python matplotlib>=3.7.1 optax nbsphinx diff --git a/python/sdist/amici/de_model.py b/python/sdist/amici/de_model.py index 030999704b..8a0adee9f6 100644 --- a/python/sdist/amici/de_model.py +++ b/python/sdist/amici/de_model.py @@ -2,7 +2,6 @@ from __future__ import annotations -import contextlib import copy import itertools import logging @@ -47,7 +46,6 @@ from .exporters.sundials.cxxcodeprinter import csc_matrix from .importers.utils import ( ObservableTransformation, - SBMLException, _default_simplify, amici_time_symbol, smart_subs_dict, @@ -340,6 +338,10 @@ def algebraic_states(self) -> list[AlgebraicState]: """Get all algebraic states.""" return self._algebraic_states + def algebraic_equations(self) -> list[AlgebraicEquation]: + """Get all algebraic equations.""" + return self._algebraic_equations + def observables(self) -> list[Observable]: """Get all observables.""" return self._observables @@ -404,165 +406,6 @@ def states(self) -> list[State]: """Get all states.""" return self._differential_states + self._algebraic_states - def _process_sbml_rate_of(self) -> None: - """Substitute any SBML-rateOf constructs in the model equations""" - from sbmlmath import rate_of as rate_of_func - - species_sym_to_xdot = dict( - zip(self.sym("x"), self.sym("xdot"), strict=True) - ) - species_sym_to_idx = {x: i for i, x in enumerate(self.sym("x"))} - - def get_rate(symbol: sp.Symbol): - """Get rate of change of the given symbol""" - if symbol.find(rate_of_func): - raise SBMLException("Nesting rateOf() is not allowed.") - - # Replace all rateOf(some_species) by their respective xdot equation - with contextlib.suppress(KeyError): - return self._eqs["xdot"][species_sym_to_idx[symbol]] - - # For anything other than a state, rateOf(.) is 0 or invalid - return 0 - - # replace rateOf-instances in xdot by xdot symbols - made_substitutions = False - for i_state in range(len(self.eq("xdot"))): - if rate_ofs := self._eqs["xdot"][i_state].find(rate_of_func): - self._eqs["xdot"][i_state] = self._eqs["xdot"][i_state].subs( - { - # either the rateOf argument is a state, or it's 0 - rate_of: species_sym_to_xdot.get(rate_of.args[0], 0) - for rate_of in rate_ofs - } - ) - made_substitutions = True - - if made_substitutions: - # substitute in topological order - subs = toposort_symbols( - dict(zip(self.sym("xdot"), self.eq("xdot"), strict=True)) - ) - self._eqs["xdot"] = smart_subs_dict(self.eq("xdot"), subs) - - # replace rateOf-instances in w by xdot equation - # here we may need toposort, as xdot may depend on w - made_substitutions = False - for i_expr in range(len(self.eq("w"))): - new, replacement = self._eqs["w"][i_expr].replace( - rate_of_func, get_rate, map=True - ) - if replacement: - self._eqs["w"][i_expr] = new - made_substitutions = True - - if made_substitutions: - # Sort expressions in self._expressions, w symbols, and w equations - # in topological order. Ideally, this would already happen before - # adding the expressions to the model, but at that point, we don't - # have access to xdot yet. - # NOTE: elsewhere, conservations law expressions are expected to - # occur before any other w expressions, so we must maintain their - # position - # toposort everything but conservation law expressions, - # then prepend conservation laws - w_sorted = toposort_symbols( - dict( - zip( - self.sym("w")[self.num_cons_law() :, :], - self.eq("w")[self.num_cons_law() :, :], - strict=True, - ) - ) - ) - w_sorted = ( - dict( - zip( - self.sym("w")[: self.num_cons_law(), :], - self.eq("w")[: self.num_cons_law(), :], - strict=True, - ) - ) - | w_sorted - ) - old_syms = tuple(self._syms["w"]) - topo_expr_syms = tuple(w_sorted.keys()) - new_order = [old_syms.index(s) for s in topo_expr_syms] - self._expressions = [self._expressions[i] for i in new_order] - self._syms["w"] = sp.Matrix(topo_expr_syms) - self._eqs["w"] = sp.Matrix(list(w_sorted.values())) - - # replace rateOf-instances in x0 by xdot equation - # indices of state variables whose x0 was modified - changed_indices = [] - for i_state in range(len(self.eq("x0"))): - new, replacement = self._eqs["x0"][i_state].replace( - rate_of_func, get_rate, map=True - ) - if replacement: - self._eqs["x0"][i_state] = new - changed_indices.append(i_state) - if changed_indices: - # Replace any newly introduced state variables - # by their x0 expressions. - # Also replace any newly introduced `w` symbols by their - # expressions (after `w` was toposorted above). - subs = toposort_symbols( - dict(zip(self.sym("x_rdata"), self.eq("x0"), strict=True)) - ) - subs = dict(zip(self._syms["w"], self.eq("w"), strict=True)) | subs - for i_state in changed_indices: - self._eqs["x0"][i_state] = smart_subs_dict( - self._eqs["x0"][i_state], subs - ) - - for component in chain( - self.observables(), - self.events(), - self._algebraic_equations, - ): - if rate_ofs := component.get_val().find(rate_of_func): - if isinstance(component, Event): - # TODO froot(...) can currently not depend on `w`, so this substitution fails for non-zero rates - # see, e.g., sbml test case 01293 - raise SBMLException( - "AMICI does currently not support rateOf(.) inside event trigger functions." - ) - - if isinstance(component, AlgebraicEquation): - # TODO IDACalcIC fails with - # "The linesearch algorithm failed: step too small or too many backtracks." - # see, e.g., sbml test case 01482 - raise SBMLException( - "AMICI does currently not support rateOf(.) inside AlgebraicRules." - ) - - component.set_val( - component.get_val().subs( - { - rate_of: get_rate(rate_of.args[0]) - for rate_of in rate_ofs - } - ) - ) - - for event in self.events(): - state_update = event.get_state_update( - x=self.sym("x"), x_old=self.sym("x") - ) - if state_update is None: - continue - - for i_state in range(len(state_update)): - if rate_ofs := state_update[i_state].find(rate_of_func): - raise SBMLException( - "AMICI does currently not support rateOf(.) inside event state updates." - ) - # TODO here we need xdot sym, not eqs - # event._state_update[i_state] = event._state_update[i_state].subs( - # {rate_of: get_rate(rate_of.args[0]) for rate_of in rate_ofs} - # ) - def add_component( self, component: ModelQuantity, insert_first: bool | None = False ) -> None: @@ -1271,9 +1114,7 @@ def generate_basic_variables(self) -> None: Generates the symbolic identifiers for all variables in ``DEModel._variable_prototype`` """ - # We need to process events and Heaviside functions in the ``DEModel`, - # before adding it to DEExporter - self.parse_events() + self._reorder_events() for var in self._variable_prototype: if var not in self._syms: @@ -1335,7 +1176,11 @@ def parse_events(self) -> None: for event in self.events(): event.set_val(event.get_val().subs(w_toposorted)) - # re-order events - first those that require root tracking, then the others + def _reorder_events(self) -> None: + """ + Re-order events - first those that require root tracking, + then the others. + """ constant_syms = set(self.sym("k")) | set(self.sym("p")) self._events = list( chain( @@ -2694,22 +2539,7 @@ def _process_hybridization(self, hybridization: dict) -> None: self._observables = [self._observables[i] for i in new_order] if added_expressions: - # toposort expressions - w_sorted = toposort_symbols( - dict( - zip( - self.sym("w"), - self.eq("w"), - strict=True, - ) - ) - ) - old_syms = tuple(self._syms["w"]) - topo_expr_syms = tuple(w_sorted.keys()) - new_order = [old_syms.index(s) for s in topo_expr_syms] - self._expressions = [self._expressions[i] for i in new_order] - self._syms["w"] = sp.Matrix(topo_expr_syms) - self._eqs["w"] = sp.Matrix(list(w_sorted.values())) + self.toposort_expressions() def get_explicit_roots(self) -> set[sp.Expr]: """ @@ -2752,3 +2582,51 @@ def has_event_assignments(self) -> bool: boolean indicating if event assignments are present """ return any(event.updates_state for event in self._events) + + def toposort_expressions(self) -> dict[sp.Symbol, sp.Expr]: + """ + Sort expressions in topological order. + + :return: + dict of expression symbols to expressions in topological order + """ + # ensure no symbols or equations that depend on `w` have been generated + # yet, otherwise the re-ordering might break dependencies + if ( + generated := set(self._syms) + | set(self._eqs) + | set(self._sparsesyms) + | set(self._sparseeqs) + ) - {"w", "p", "k", "x", "x_rdata"}: + raise AssertionError( + "This function must be called before computing any " + "derivatives. The following symbols/equations are already " + f"generated: {generated}" + ) + + # NOTE: elsewhere, conservations law expressions are expected to + # occur before any other w expressions, so we must maintain their + # position. + # toposort everything but conservation law expressions, + # then prepend conservation laws + + w_toposorted = toposort_symbols( + { + e.get_sym(): e.get_val() + for e in self.expressions()[self.num_cons_law() :] + } + ) + + w_toposorted = { + e.get_sym(): e.get_val() + for e in self.expressions()[: self.num_cons_law()] + } | w_toposorted + + old_syms = tuple(e.get_sym() for e in self.expressions()) + topo_expr_syms = tuple(w_toposorted) + new_order = [old_syms.index(s) for s in topo_expr_syms] + self._expressions = [self._expressions[i] for i in new_order] + self._syms["w"] = sp.Matrix(topo_expr_syms) + self._eqs["w"] = sp.Matrix(list(w_toposorted.values())) + + return w_toposorted diff --git a/python/sdist/amici/importers/pysb/__init__.py b/python/sdist/amici/importers/pysb/__init__.py index cb44744cc1..939bf1a13e 100644 --- a/python/sdist/amici/importers/pysb/__init__.py +++ b/python/sdist/amici/importers/pysb/__init__.py @@ -397,6 +397,7 @@ def ode_model_from_pysb_importer( _process_stoichiometric_matrix(model, ode, fixed_parameters) + ode.parse_events() ode.generate_basic_variables() return ode diff --git a/python/sdist/amici/importers/sbml/__init__.py b/python/sdist/amici/importers/sbml/__init__.py index 685836a830..d025edcd80 100644 --- a/python/sdist/amici/importers/sbml/__init__.py +++ b/python/sdist/amici/importers/sbml/__init__.py @@ -5,6 +5,7 @@ in the `Systems Biology Markup Language (SBML) `_. """ +import contextlib import copy import itertools as itt import logging @@ -14,6 +15,7 @@ import warnings import xml.etree.ElementTree as ET from collections.abc import Callable, Iterable, Sequence +from itertools import chain from pathlib import Path from typing import ( Any, @@ -30,7 +32,12 @@ from amici import get_model_dir, has_clibs from amici.constants import SymbolId from amici.de_model import DEModel -from amici.de_model_components import Expression, symbol_to_type +from amici.de_model_components import ( + DifferentialState, + Event, + Expression, + symbol_to_type, +) from amici.importers.sbml.splines import AbstractSpline from amici.importers.sbml.utils import SBMLException from amici.importers.utils import ( @@ -682,12 +689,14 @@ def _build_ode_model( if hybridization: ode_model._process_hybridization(hybridization) + ode_model.parse_events() + # substitute SBML-rateOf constructs + # must be done after parse_events, but before generate_basic_variables + self._process_sbml_rate_of(ode_model) + # fill in 'self._sym' based on prototypes and components in ode_model ode_model.generate_basic_variables() - # substitute SBML-rateOf constructs - ode_model._process_sbml_rate_of() - return ode_model @log_execution_time("importing SBML", logger) @@ -3014,6 +3023,126 @@ def _transform_dxdt_to_concentration( return dxdt / v + def _process_sbml_rate_of(self, de_model: DEModel) -> None: + """Substitute any SBML-rateOf constructs in the model equations""" + from sbmlmath import rate_of as rate_of_func + + # DAE models in general are supported, but not if they include rateOf() + # The is_ode check below will not trigger in case of a DAE model with + # rateOf() only in algebraic equations, so we need to check for that. + for ae in de_model.algebraic_equations(): + if ae.get_val().find(rate_of_func): + # Not implemented + raise SBMLException( + "rateOf() is not supported in algebraic equations. " + f"Used in: {ae.get_id()}: {ae.get_val()}" + ) + + sym_to_state_var: dict[sp.Symbol, DifferentialState] = { + state.get_sym(): state for state in de_model.differential_states() + } + + is_ode = de_model.is_ode() + + def get_rate(symbol: sp.Symbol): + """Get rate of change of the given symbol + (symbol is argument of rateOf)""" + if not is_ode: + # Not implemented for DAE models + # sbml test case 01482 + raise SBMLException( + "rateOf() is only supported in ODE models." + ) + + if symbol.find(rate_of_func): + raise SBMLException("Nesting rateOf() is not allowed.") + + # Replace all rateOf(some_state_var) by their respective + # xdot equation + with contextlib.suppress(KeyError): + return sym_to_state_var[symbol].get_dt() + + # For anything other than a state variable, + # rateOf(.) is 0 or invalid + return sp.Integer(0) + + def do_subs(expr, rate_ofs) -> sp.Expr: + """Substitute rateOf(...) in expr by appropriate expressions""" + expr = expr.subs( + {rate_of: get_rate(rate_of.args[0]) for rate_of in rate_ofs} + ) + if rate_ofs := expr.find(rate_of_func): + # recursively substitute until no rateOf remains + expr = do_subs(expr, rate_ofs) + return expr + + # replace rateOf-instances in xdot + for state in de_model.differential_states(): + if rate_ofs := state.get_dt().find(rate_of_func): + state.set_dt(do_subs(state.get_dt(), rate_ofs)) + + # replace rateOf-instances in expressions which we will need for + # substitutions later + for expr in de_model.expressions(): + if rate_ofs := expr.get_val().find(rate_of_func): + expr.set_val(do_subs(expr.get_val(), rate_ofs)) + + w_toposorted = de_model.toposort_expressions() + + # replace rateOf-instances in x0 + # indices of state variables whose x0 was modified + changed_indices = [] + for i_state, state in enumerate(de_model.differential_states()): + new, replacement = state.get_val().replace( + rate_of_func, get_rate, map=True + ) + if replacement: + state.set_val(new) + changed_indices.append(i_state) + if changed_indices: + # Replace any newly introduced state variables + # by their x0 expressions. + subs = w_toposorted | toposort_symbols( + { + state.get_x_rdata(): state.get_val() + for state in de_model.differential_states() + } + ) + states = de_model.differential_states() + for i_state in changed_indices: + states[i_state].set_val( + smart_subs_dict(states[i_state].get_val(), subs) + ) + + for component in chain( + de_model.observables(), + de_model.events(), + ): + if rate_ofs := component.get_val().find(rate_of_func): + component.set_val(do_subs(component.get_val(), rate_ofs)) + + if isinstance(component, Event): + if rate_ofs: + # currently, `root` cannot depend on `w`. + # this could be changed, but for now, + # we just flatten out w expressions + component.set_val( + smart_subs_dict(component.get_val(), w_toposorted) + ) + + # for events, also substitute in state updates + for target, assignment in component._assignments.items(): + if rate_ofs := assignment.find(rate_of_func): + new_assignment = do_subs(assignment, rate_ofs) + # currently, deltax cannot depend on `w`. + # this could be changed, or we could use xdot + # symbols, but for now, we just flatten out w + # expressions + new_assignment = smart_subs_dict( + new_assignment, w_toposorted + ) + component._assignments[target] = new_assignment + def _check_lib_sbml_errors( sbml_doc: libsbml.SBMLDocument, show_warnings: bool = False