From 853104cea5f6f251e933b2bab4794a2fe8c2cb82 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Sat, 15 Nov 2025 09:59:28 +0100 Subject: [PATCH 1/2] doc: Skip petab-sciml for now GitHub Action runners run out of disk space when installing petab-sciml with all its huge dependencies. Don't install that for now. So far, it's not used anywhere for the documentation build as far as I can see. This won't prevent enabling intersphinx later on. --- doc/rtd_requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/doc/rtd_requirements.txt b/doc/rtd_requirements.txt index ce3e1d1688..1e58136155 100644 --- a/doc/rtd_requirements.txt +++ b/doc/rtd_requirements.txt @@ -8,7 +8,6 @@ setuptools>=67.7.2 # for building the documentation, we don't care whether this fully works git+https://github.com/pysb/pysb@0afeaab385e9a1d813ecf6fdaf0153f4b91358af # For forward type definition in generate_equinox -git+https://github.com/PEtab-dev/petab_sciml.git@727d177fd3f85509d0bdcc278b672e9eeafd2384#subdirectory=src/python matplotlib>=3.7.1 optax nbsphinx From e7b489667bc5db416e8a6a4e6c3525321b9dc7a6 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Fri, 14 Nov 2025 22:29:15 +0100 Subject: [PATCH 2/2] Refactor rateOf handling Pull rateOf-handling out of DEModel and keep it along other SBML processing where it belongs. This is easier to follow and prevents some lingering issues with the old approach due to the xdot / w interdependencies. This also handles rateOf expressions in some additional, previously unsupported places like event assignments. --- python/sdist/amici/de_model.py | 240 +++++------------- python/sdist/amici/importers/pysb/__init__.py | 1 + python/sdist/amici/importers/sbml/__init__.py | 137 +++++++++- 3 files changed, 193 insertions(+), 185 deletions(-) diff --git a/python/sdist/amici/de_model.py b/python/sdist/amici/de_model.py index 030999704b..8a0adee9f6 100644 --- a/python/sdist/amici/de_model.py +++ b/python/sdist/amici/de_model.py @@ -2,7 +2,6 @@ from __future__ import annotations -import contextlib import copy import itertools import logging @@ -47,7 +46,6 @@ from .exporters.sundials.cxxcodeprinter import csc_matrix from .importers.utils import ( ObservableTransformation, - SBMLException, _default_simplify, amici_time_symbol, smart_subs_dict, @@ -340,6 +338,10 @@ def algebraic_states(self) -> list[AlgebraicState]: """Get all algebraic states.""" return self._algebraic_states + def algebraic_equations(self) -> list[AlgebraicEquation]: + """Get all algebraic equations.""" + return self._algebraic_equations + def observables(self) -> list[Observable]: """Get all observables.""" return self._observables @@ -404,165 +406,6 @@ def states(self) -> list[State]: """Get all states.""" return self._differential_states + self._algebraic_states - def _process_sbml_rate_of(self) -> None: - """Substitute any SBML-rateOf constructs in the model equations""" - from sbmlmath import rate_of as rate_of_func - - species_sym_to_xdot = dict( - zip(self.sym("x"), self.sym("xdot"), strict=True) - ) - species_sym_to_idx = {x: i for i, x in enumerate(self.sym("x"))} - - def get_rate(symbol: sp.Symbol): - """Get rate of change of the given symbol""" - if symbol.find(rate_of_func): - raise SBMLException("Nesting rateOf() is not allowed.") - - # Replace all rateOf(some_species) by their respective xdot equation - with contextlib.suppress(KeyError): - return self._eqs["xdot"][species_sym_to_idx[symbol]] - - # For anything other than a state, rateOf(.) is 0 or invalid - return 0 - - # replace rateOf-instances in xdot by xdot symbols - made_substitutions = False - for i_state in range(len(self.eq("xdot"))): - if rate_ofs := self._eqs["xdot"][i_state].find(rate_of_func): - self._eqs["xdot"][i_state] = self._eqs["xdot"][i_state].subs( - { - # either the rateOf argument is a state, or it's 0 - rate_of: species_sym_to_xdot.get(rate_of.args[0], 0) - for rate_of in rate_ofs - } - ) - made_substitutions = True - - if made_substitutions: - # substitute in topological order - subs = toposort_symbols( - dict(zip(self.sym("xdot"), self.eq("xdot"), strict=True)) - ) - self._eqs["xdot"] = smart_subs_dict(self.eq("xdot"), subs) - - # replace rateOf-instances in w by xdot equation - # here we may need toposort, as xdot may depend on w - made_substitutions = False - for i_expr in range(len(self.eq("w"))): - new, replacement = self._eqs["w"][i_expr].replace( - rate_of_func, get_rate, map=True - ) - if replacement: - self._eqs["w"][i_expr] = new - made_substitutions = True - - if made_substitutions: - # Sort expressions in self._expressions, w symbols, and w equations - # in topological order. Ideally, this would already happen before - # adding the expressions to the model, but at that point, we don't - # have access to xdot yet. - # NOTE: elsewhere, conservations law expressions are expected to - # occur before any other w expressions, so we must maintain their - # position - # toposort everything but conservation law expressions, - # then prepend conservation laws - w_sorted = toposort_symbols( - dict( - zip( - self.sym("w")[self.num_cons_law() :, :], - self.eq("w")[self.num_cons_law() :, :], - strict=True, - ) - ) - ) - w_sorted = ( - dict( - zip( - self.sym("w")[: self.num_cons_law(), :], - self.eq("w")[: self.num_cons_law(), :], - strict=True, - ) - ) - | w_sorted - ) - old_syms = tuple(self._syms["w"]) - topo_expr_syms = tuple(w_sorted.keys()) - new_order = [old_syms.index(s) for s in topo_expr_syms] - self._expressions = [self._expressions[i] for i in new_order] - self._syms["w"] = sp.Matrix(topo_expr_syms) - self._eqs["w"] = sp.Matrix(list(w_sorted.values())) - - # replace rateOf-instances in x0 by xdot equation - # indices of state variables whose x0 was modified - changed_indices = [] - for i_state in range(len(self.eq("x0"))): - new, replacement = self._eqs["x0"][i_state].replace( - rate_of_func, get_rate, map=True - ) - if replacement: - self._eqs["x0"][i_state] = new - changed_indices.append(i_state) - if changed_indices: - # Replace any newly introduced state variables - # by their x0 expressions. - # Also replace any newly introduced `w` symbols by their - # expressions (after `w` was toposorted above). - subs = toposort_symbols( - dict(zip(self.sym("x_rdata"), self.eq("x0"), strict=True)) - ) - subs = dict(zip(self._syms["w"], self.eq("w"), strict=True)) | subs - for i_state in changed_indices: - self._eqs["x0"][i_state] = smart_subs_dict( - self._eqs["x0"][i_state], subs - ) - - for component in chain( - self.observables(), - self.events(), - self._algebraic_equations, - ): - if rate_ofs := component.get_val().find(rate_of_func): - if isinstance(component, Event): - # TODO froot(...) can currently not depend on `w`, so this substitution fails for non-zero rates - # see, e.g., sbml test case 01293 - raise SBMLException( - "AMICI does currently not support rateOf(.) inside event trigger functions." - ) - - if isinstance(component, AlgebraicEquation): - # TODO IDACalcIC fails with - # "The linesearch algorithm failed: step too small or too many backtracks." - # see, e.g., sbml test case 01482 - raise SBMLException( - "AMICI does currently not support rateOf(.) inside AlgebraicRules." - ) - - component.set_val( - component.get_val().subs( - { - rate_of: get_rate(rate_of.args[0]) - for rate_of in rate_ofs - } - ) - ) - - for event in self.events(): - state_update = event.get_state_update( - x=self.sym("x"), x_old=self.sym("x") - ) - if state_update is None: - continue - - for i_state in range(len(state_update)): - if rate_ofs := state_update[i_state].find(rate_of_func): - raise SBMLException( - "AMICI does currently not support rateOf(.) inside event state updates." - ) - # TODO here we need xdot sym, not eqs - # event._state_update[i_state] = event._state_update[i_state].subs( - # {rate_of: get_rate(rate_of.args[0]) for rate_of in rate_ofs} - # ) - def add_component( self, component: ModelQuantity, insert_first: bool | None = False ) -> None: @@ -1271,9 +1114,7 @@ def generate_basic_variables(self) -> None: Generates the symbolic identifiers for all variables in ``DEModel._variable_prototype`` """ - # We need to process events and Heaviside functions in the ``DEModel`, - # before adding it to DEExporter - self.parse_events() + self._reorder_events() for var in self._variable_prototype: if var not in self._syms: @@ -1335,7 +1176,11 @@ def parse_events(self) -> None: for event in self.events(): event.set_val(event.get_val().subs(w_toposorted)) - # re-order events - first those that require root tracking, then the others + def _reorder_events(self) -> None: + """ + Re-order events - first those that require root tracking, + then the others. + """ constant_syms = set(self.sym("k")) | set(self.sym("p")) self._events = list( chain( @@ -2694,22 +2539,7 @@ def _process_hybridization(self, hybridization: dict) -> None: self._observables = [self._observables[i] for i in new_order] if added_expressions: - # toposort expressions - w_sorted = toposort_symbols( - dict( - zip( - self.sym("w"), - self.eq("w"), - strict=True, - ) - ) - ) - old_syms = tuple(self._syms["w"]) - topo_expr_syms = tuple(w_sorted.keys()) - new_order = [old_syms.index(s) for s in topo_expr_syms] - self._expressions = [self._expressions[i] for i in new_order] - self._syms["w"] = sp.Matrix(topo_expr_syms) - self._eqs["w"] = sp.Matrix(list(w_sorted.values())) + self.toposort_expressions() def get_explicit_roots(self) -> set[sp.Expr]: """ @@ -2752,3 +2582,51 @@ def has_event_assignments(self) -> bool: boolean indicating if event assignments are present """ return any(event.updates_state for event in self._events) + + def toposort_expressions(self) -> dict[sp.Symbol, sp.Expr]: + """ + Sort expressions in topological order. + + :return: + dict of expression symbols to expressions in topological order + """ + # ensure no symbols or equations that depend on `w` have been generated + # yet, otherwise the re-ordering might break dependencies + if ( + generated := set(self._syms) + | set(self._eqs) + | set(self._sparsesyms) + | set(self._sparseeqs) + ) - {"w", "p", "k", "x", "x_rdata"}: + raise AssertionError( + "This function must be called before computing any " + "derivatives. The following symbols/equations are already " + f"generated: {generated}" + ) + + # NOTE: elsewhere, conservations law expressions are expected to + # occur before any other w expressions, so we must maintain their + # position. + # toposort everything but conservation law expressions, + # then prepend conservation laws + + w_toposorted = toposort_symbols( + { + e.get_sym(): e.get_val() + for e in self.expressions()[self.num_cons_law() :] + } + ) + + w_toposorted = { + e.get_sym(): e.get_val() + for e in self.expressions()[: self.num_cons_law()] + } | w_toposorted + + old_syms = tuple(e.get_sym() for e in self.expressions()) + topo_expr_syms = tuple(w_toposorted) + new_order = [old_syms.index(s) for s in topo_expr_syms] + self._expressions = [self._expressions[i] for i in new_order] + self._syms["w"] = sp.Matrix(topo_expr_syms) + self._eqs["w"] = sp.Matrix(list(w_toposorted.values())) + + return w_toposorted diff --git a/python/sdist/amici/importers/pysb/__init__.py b/python/sdist/amici/importers/pysb/__init__.py index cb44744cc1..939bf1a13e 100644 --- a/python/sdist/amici/importers/pysb/__init__.py +++ b/python/sdist/amici/importers/pysb/__init__.py @@ -397,6 +397,7 @@ def ode_model_from_pysb_importer( _process_stoichiometric_matrix(model, ode, fixed_parameters) + ode.parse_events() ode.generate_basic_variables() return ode diff --git a/python/sdist/amici/importers/sbml/__init__.py b/python/sdist/amici/importers/sbml/__init__.py index 685836a830..d025edcd80 100644 --- a/python/sdist/amici/importers/sbml/__init__.py +++ b/python/sdist/amici/importers/sbml/__init__.py @@ -5,6 +5,7 @@ in the `Systems Biology Markup Language (SBML) `_. """ +import contextlib import copy import itertools as itt import logging @@ -14,6 +15,7 @@ import warnings import xml.etree.ElementTree as ET from collections.abc import Callable, Iterable, Sequence +from itertools import chain from pathlib import Path from typing import ( Any, @@ -30,7 +32,12 @@ from amici import get_model_dir, has_clibs from amici.constants import SymbolId from amici.de_model import DEModel -from amici.de_model_components import Expression, symbol_to_type +from amici.de_model_components import ( + DifferentialState, + Event, + Expression, + symbol_to_type, +) from amici.importers.sbml.splines import AbstractSpline from amici.importers.sbml.utils import SBMLException from amici.importers.utils import ( @@ -682,12 +689,14 @@ def _build_ode_model( if hybridization: ode_model._process_hybridization(hybridization) + ode_model.parse_events() + # substitute SBML-rateOf constructs + # must be done after parse_events, but before generate_basic_variables + self._process_sbml_rate_of(ode_model) + # fill in 'self._sym' based on prototypes and components in ode_model ode_model.generate_basic_variables() - # substitute SBML-rateOf constructs - ode_model._process_sbml_rate_of() - return ode_model @log_execution_time("importing SBML", logger) @@ -3014,6 +3023,126 @@ def _transform_dxdt_to_concentration( return dxdt / v + def _process_sbml_rate_of(self, de_model: DEModel) -> None: + """Substitute any SBML-rateOf constructs in the model equations""" + from sbmlmath import rate_of as rate_of_func + + # DAE models in general are supported, but not if they include rateOf() + # The is_ode check below will not trigger in case of a DAE model with + # rateOf() only in algebraic equations, so we need to check for that. + for ae in de_model.algebraic_equations(): + if ae.get_val().find(rate_of_func): + # Not implemented + raise SBMLException( + "rateOf() is not supported in algebraic equations. " + f"Used in: {ae.get_id()}: {ae.get_val()}" + ) + + sym_to_state_var: dict[sp.Symbol, DifferentialState] = { + state.get_sym(): state for state in de_model.differential_states() + } + + is_ode = de_model.is_ode() + + def get_rate(symbol: sp.Symbol): + """Get rate of change of the given symbol + (symbol is argument of rateOf)""" + if not is_ode: + # Not implemented for DAE models + # sbml test case 01482 + raise SBMLException( + "rateOf() is only supported in ODE models." + ) + + if symbol.find(rate_of_func): + raise SBMLException("Nesting rateOf() is not allowed.") + + # Replace all rateOf(some_state_var) by their respective + # xdot equation + with contextlib.suppress(KeyError): + return sym_to_state_var[symbol].get_dt() + + # For anything other than a state variable, + # rateOf(.) is 0 or invalid + return sp.Integer(0) + + def do_subs(expr, rate_ofs) -> sp.Expr: + """Substitute rateOf(...) in expr by appropriate expressions""" + expr = expr.subs( + {rate_of: get_rate(rate_of.args[0]) for rate_of in rate_ofs} + ) + if rate_ofs := expr.find(rate_of_func): + # recursively substitute until no rateOf remains + expr = do_subs(expr, rate_ofs) + return expr + + # replace rateOf-instances in xdot + for state in de_model.differential_states(): + if rate_ofs := state.get_dt().find(rate_of_func): + state.set_dt(do_subs(state.get_dt(), rate_ofs)) + + # replace rateOf-instances in expressions which we will need for + # substitutions later + for expr in de_model.expressions(): + if rate_ofs := expr.get_val().find(rate_of_func): + expr.set_val(do_subs(expr.get_val(), rate_ofs)) + + w_toposorted = de_model.toposort_expressions() + + # replace rateOf-instances in x0 + # indices of state variables whose x0 was modified + changed_indices = [] + for i_state, state in enumerate(de_model.differential_states()): + new, replacement = state.get_val().replace( + rate_of_func, get_rate, map=True + ) + if replacement: + state.set_val(new) + changed_indices.append(i_state) + if changed_indices: + # Replace any newly introduced state variables + # by their x0 expressions. + subs = w_toposorted | toposort_symbols( + { + state.get_x_rdata(): state.get_val() + for state in de_model.differential_states() + } + ) + states = de_model.differential_states() + for i_state in changed_indices: + states[i_state].set_val( + smart_subs_dict(states[i_state].get_val(), subs) + ) + + for component in chain( + de_model.observables(), + de_model.events(), + ): + if rate_ofs := component.get_val().find(rate_of_func): + component.set_val(do_subs(component.get_val(), rate_ofs)) + + if isinstance(component, Event): + if rate_ofs: + # currently, `root` cannot depend on `w`. + # this could be changed, but for now, + # we just flatten out w expressions + component.set_val( + smart_subs_dict(component.get_val(), w_toposorted) + ) + + # for events, also substitute in state updates + for target, assignment in component._assignments.items(): + if rate_ofs := assignment.find(rate_of_func): + new_assignment = do_subs(assignment, rate_ofs) + # currently, deltax cannot depend on `w`. + # this could be changed, or we could use xdot + # symbols, but for now, we just flatten out w + # expressions + new_assignment = smart_subs_dict( + new_assignment, w_toposorted + ) + component._assignments[target] = new_assignment + def _check_lib_sbml_errors( sbml_doc: libsbml.SBMLDocument, show_warnings: bool = False