Skip to content

Commit 5ccb05a

Browse files
committed
Refactor rateOf handling
Pull rateOf-handling out of DEModel and keep it along other SBML processing where it belongs. This is easier to follow and prevents some lingering issues with the old approach due to the xdot / w interdependencies. This also handles rateOf expressions in some additional, previously unsupported places like event assignments.
1 parent 7bf3b76 commit 5ccb05a

File tree

2 files changed

+154
-165
lines changed

2 files changed

+154
-165
lines changed

python/sdist/amici/de_model.py

Lines changed: 34 additions & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from __future__ import annotations
44

5-
import contextlib
65
import copy
76
import itertools
87
import logging
@@ -47,7 +46,6 @@
4746
from .exporters.sundials.cxxcodeprinter import csc_matrix
4847
from .importers.utils import (
4948
ObservableTransformation,
50-
SBMLException,
5149
_default_simplify,
5250
amici_time_symbol,
5351
smart_subs_dict,
@@ -404,165 +402,6 @@ def states(self) -> list[State]:
404402
"""Get all states."""
405403
return self._differential_states + self._algebraic_states
406404

407-
def _process_sbml_rate_of(self) -> None:
408-
"""Substitute any SBML-rateOf constructs in the model equations"""
409-
from sbmlmath import rate_of as rate_of_func
410-
411-
species_sym_to_xdot = dict(
412-
zip(self.sym("x"), self.sym("xdot"), strict=True)
413-
)
414-
species_sym_to_idx = {x: i for i, x in enumerate(self.sym("x"))}
415-
416-
def get_rate(symbol: sp.Symbol):
417-
"""Get rate of change of the given symbol"""
418-
if symbol.find(rate_of_func):
419-
raise SBMLException("Nesting rateOf() is not allowed.")
420-
421-
# Replace all rateOf(some_species) by their respective xdot equation
422-
with contextlib.suppress(KeyError):
423-
return self._eqs["xdot"][species_sym_to_idx[symbol]]
424-
425-
# For anything other than a state, rateOf(.) is 0 or invalid
426-
return 0
427-
428-
# replace rateOf-instances in xdot by xdot symbols
429-
made_substitutions = False
430-
for i_state in range(len(self.eq("xdot"))):
431-
if rate_ofs := self._eqs["xdot"][i_state].find(rate_of_func):
432-
self._eqs["xdot"][i_state] = self._eqs["xdot"][i_state].subs(
433-
{
434-
# either the rateOf argument is a state, or it's 0
435-
rate_of: species_sym_to_xdot.get(rate_of.args[0], 0)
436-
for rate_of in rate_ofs
437-
}
438-
)
439-
made_substitutions = True
440-
441-
if made_substitutions:
442-
# substitute in topological order
443-
subs = toposort_symbols(
444-
dict(zip(self.sym("xdot"), self.eq("xdot"), strict=True))
445-
)
446-
self._eqs["xdot"] = smart_subs_dict(self.eq("xdot"), subs)
447-
448-
# replace rateOf-instances in w by xdot equation
449-
# here we may need toposort, as xdot may depend on w
450-
made_substitutions = False
451-
for i_expr in range(len(self.eq("w"))):
452-
new, replacement = self._eqs["w"][i_expr].replace(
453-
rate_of_func, get_rate, map=True
454-
)
455-
if replacement:
456-
self._eqs["w"][i_expr] = new
457-
made_substitutions = True
458-
459-
if made_substitutions:
460-
# Sort expressions in self._expressions, w symbols, and w equations
461-
# in topological order. Ideally, this would already happen before
462-
# adding the expressions to the model, but at that point, we don't
463-
# have access to xdot yet.
464-
# NOTE: elsewhere, conservations law expressions are expected to
465-
# occur before any other w expressions, so we must maintain their
466-
# position
467-
# toposort everything but conservation law expressions,
468-
# then prepend conservation laws
469-
w_sorted = toposort_symbols(
470-
dict(
471-
zip(
472-
self.sym("w")[self.num_cons_law() :, :],
473-
self.eq("w")[self.num_cons_law() :, :],
474-
strict=True,
475-
)
476-
)
477-
)
478-
w_sorted = (
479-
dict(
480-
zip(
481-
self.sym("w")[: self.num_cons_law(), :],
482-
self.eq("w")[: self.num_cons_law(), :],
483-
strict=True,
484-
)
485-
)
486-
| w_sorted
487-
)
488-
old_syms = tuple(self._syms["w"])
489-
topo_expr_syms = tuple(w_sorted.keys())
490-
new_order = [old_syms.index(s) for s in topo_expr_syms]
491-
self._expressions = [self._expressions[i] for i in new_order]
492-
self._syms["w"] = sp.Matrix(topo_expr_syms)
493-
self._eqs["w"] = sp.Matrix(list(w_sorted.values()))
494-
495-
# replace rateOf-instances in x0 by xdot equation
496-
# indices of state variables whose x0 was modified
497-
changed_indices = []
498-
for i_state in range(len(self.eq("x0"))):
499-
new, replacement = self._eqs["x0"][i_state].replace(
500-
rate_of_func, get_rate, map=True
501-
)
502-
if replacement:
503-
self._eqs["x0"][i_state] = new
504-
changed_indices.append(i_state)
505-
if changed_indices:
506-
# Replace any newly introduced state variables
507-
# by their x0 expressions.
508-
# Also replace any newly introduced `w` symbols by their
509-
# expressions (after `w` was toposorted above).
510-
subs = toposort_symbols(
511-
dict(zip(self.sym("x_rdata"), self.eq("x0"), strict=True))
512-
)
513-
subs = dict(zip(self._syms["w"], self.eq("w"), strict=True)) | subs
514-
for i_state in changed_indices:
515-
self._eqs["x0"][i_state] = smart_subs_dict(
516-
self._eqs["x0"][i_state], subs
517-
)
518-
519-
for component in chain(
520-
self.observables(),
521-
self.events(),
522-
self._algebraic_equations,
523-
):
524-
if rate_ofs := component.get_val().find(rate_of_func):
525-
if isinstance(component, Event):
526-
# TODO froot(...) can currently not depend on `w`, so this substitution fails for non-zero rates
527-
# see, e.g., sbml test case 01293
528-
raise SBMLException(
529-
"AMICI does currently not support rateOf(.) inside event trigger functions."
530-
)
531-
532-
if isinstance(component, AlgebraicEquation):
533-
# TODO IDACalcIC fails with
534-
# "The linesearch algorithm failed: step too small or too many backtracks."
535-
# see, e.g., sbml test case 01482
536-
raise SBMLException(
537-
"AMICI does currently not support rateOf(.) inside AlgebraicRules."
538-
)
539-
540-
component.set_val(
541-
component.get_val().subs(
542-
{
543-
rate_of: get_rate(rate_of.args[0])
544-
for rate_of in rate_ofs
545-
}
546-
)
547-
)
548-
549-
for event in self.events():
550-
state_update = event.get_state_update(
551-
x=self.sym("x"), x_old=self.sym("x")
552-
)
553-
if state_update is None:
554-
continue
555-
556-
for i_state in range(len(state_update)):
557-
if rate_ofs := state_update[i_state].find(rate_of_func):
558-
raise SBMLException(
559-
"AMICI does currently not support rateOf(.) inside event state updates."
560-
)
561-
# TODO here we need xdot sym, not eqs
562-
# event._state_update[i_state] = event._state_update[i_state].subs(
563-
# {rate_of: get_rate(rate_of.args[0]) for rate_of in rate_ofs}
564-
# )
565-
566405
def add_component(
567406
self, component: ModelQuantity, insert_first: bool | None = False
568407
) -> None:
@@ -2752,3 +2591,37 @@ def has_event_assignments(self) -> bool:
27522591
boolean indicating if event assignments are present
27532592
"""
27542593
return any(event.updates_state for event in self._events)
2594+
2595+
def toposort_expressions(self) -> dict[sp.Symbol, sp.Expr]:
2596+
"""
2597+
Sort expressions in topological order.
2598+
2599+
:return:
2600+
dict of expression symbols to expressions in topological order
2601+
"""
2602+
# NOTE: elsewhere, conservations law expressions are expected to
2603+
# occur before any other w expressions, so we must maintain their
2604+
# position.
2605+
# toposort everything but conservation law expressions,
2606+
# then prepend conservation laws
2607+
if self._syms or self._eqs:
2608+
raise AssertionError(
2609+
"This function must be called before generating any symbols "
2610+
"or equations."
2611+
)
2612+
w_toposorted = toposort_symbols(
2613+
{
2614+
e.get_sym(): e.get_val()
2615+
for e in self.expressions()[self.num_cons_law() :]
2616+
}
2617+
)
2618+
2619+
w_toposorted = {
2620+
e.get_sym(): e.get_val()
2621+
for e in self.expressions()[: self.num_cons_law()]
2622+
} | w_toposorted
2623+
old_syms = [e.get_sym() for e in self.expressions()]
2624+
self._expressions = [
2625+
self._expressions[old_syms.index(s)] for s in w_toposorted
2626+
]
2627+
return w_toposorted

python/sdist/amici/importers/sbml/__init__.py

Lines changed: 120 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
in the `Systems Biology Markup Language (SBML) <https://sbml.org/>`_.
66
"""
77

8+
import contextlib
89
import copy
910
import itertools as itt
1011
import logging
@@ -14,6 +15,7 @@
1415
import warnings
1516
import xml.etree.ElementTree as ET
1617
from collections.abc import Callable, Iterable, Sequence
18+
from itertools import chain
1719
from pathlib import Path
1820
from typing import (
1921
Any,
@@ -30,7 +32,12 @@
3032
from amici import get_model_dir, has_clibs
3133
from amici.constants import SymbolId
3234
from amici.de_model import DEModel
33-
from amici.de_model_components import Expression, symbol_to_type
35+
from amici.de_model_components import (
36+
DifferentialState,
37+
Event,
38+
Expression,
39+
symbol_to_type,
40+
)
3441
from amici.importers.sbml.splines import AbstractSpline
3542
from amici.importers.sbml.utils import SBMLException
3643
from amici.importers.utils import (
@@ -682,12 +689,12 @@ def _build_ode_model(
682689
if hybridization:
683690
ode_model._process_hybridization(hybridization)
684691

692+
# substitute SBML-rateOf constructs
693+
self._process_sbml_rate_of(ode_model)
694+
685695
# fill in 'self._sym' based on prototypes and components in ode_model
686696
ode_model.generate_basic_variables()
687697

688-
# substitute SBML-rateOf constructs
689-
ode_model._process_sbml_rate_of()
690-
691698
return ode_model
692699

693700
@log_execution_time("importing SBML", logger)
@@ -3014,6 +3021,115 @@ def _transform_dxdt_to_concentration(
30143021

30153022
return dxdt / v
30163023

3024+
def _process_sbml_rate_of(self, de_model: DEModel) -> None:
3025+
"""Substitute any SBML-rateOf constructs in the model equations"""
3026+
from sbmlmath import rate_of as rate_of_func
3027+
3028+
sym_to_state_var: dict[sp.Symbol, DifferentialState] = {
3029+
state.get_sym(): state for state in de_model.differential_states()
3030+
}
3031+
3032+
is_ode = de_model.is_ode()
3033+
3034+
def get_rate(symbol: sp.Symbol):
3035+
"""Get rate of change of the given symbol
3036+
(symbol is argument of rateOf)"""
3037+
if not is_ode:
3038+
# Not implemented for DAE models
3039+
# sbml test case 01482
3040+
raise SBMLException(
3041+
"rateOf() is only supported in ODE models."
3042+
)
3043+
3044+
if symbol.find(rate_of_func):
3045+
raise SBMLException("Nesting rateOf() is not allowed.")
3046+
3047+
# Replace all rateOf(some_state_var) by their respective
3048+
# xdot equation
3049+
with contextlib.suppress(KeyError):
3050+
return sym_to_state_var[symbol].get_dt()
3051+
3052+
# For anything other than a state variable,
3053+
# rateOf(.) is 0 or invalid
3054+
return sp.Integer(0)
3055+
3056+
def do_subs(expr, rate_ofs) -> sp.Expr:
3057+
"""Substitute rateOf(...) in expr by appropriate expressions"""
3058+
expr = expr.subs(
3059+
{rate_of: get_rate(rate_of.args[0]) for rate_of in rate_ofs}
3060+
)
3061+
if rate_ofs := expr.find(rate_of_func):
3062+
# recursively substitute until no rateOf remains
3063+
expr = do_subs(expr, rate_ofs)
3064+
return expr
3065+
3066+
# replace rateOf-instances in xdot
3067+
for state in de_model.differential_states():
3068+
if rate_ofs := state.get_dt().find(rate_of_func):
3069+
state.set_dt(do_subs(state.get_dt(), rate_ofs))
3070+
3071+
# replace rateOf-instances in expressions which we will need for
3072+
# substitutions later
3073+
for expr in de_model.expressions():
3074+
if rate_ofs := expr.get_val().find(rate_of_func):
3075+
expr.set_val(do_subs(expr.get_val(), rate_ofs))
3076+
3077+
w_toposorted = de_model.toposort_expressions()
3078+
3079+
# replace rateOf-instances in x0
3080+
# indices of state variables whose x0 was modified
3081+
changed_indices = []
3082+
for i_state, state in enumerate(de_model.differential_states()):
3083+
new, replacement = state.get_val().replace(
3084+
rate_of_func, get_rate, map=True
3085+
)
3086+
if replacement:
3087+
state.set_val(new)
3088+
changed_indices.append(i_state)
3089+
if changed_indices:
3090+
# Replace any newly introduced state variables
3091+
# by their x0 expressions.
3092+
subs = w_toposorted | toposort_symbols(
3093+
{
3094+
state.get_x_rdata(): state.get_val()
3095+
for state in de_model.differential_states()
3096+
}
3097+
)
3098+
states = de_model.differential_states()
3099+
for i_state in changed_indices:
3100+
states[i_state].set_val(
3101+
smart_subs_dict(states[i_state].get_val(), subs)
3102+
)
3103+
3104+
for component in chain(
3105+
de_model.observables(),
3106+
de_model.events(),
3107+
):
3108+
if rate_ofs := component.get_val().find(rate_of_func):
3109+
component.set_val(do_subs(component.get_val(), rate_ofs))
3110+
3111+
if isinstance(component, Event):
3112+
if rate_ofs:
3113+
# currently, `root` cannot depend on `w`.
3114+
# this could be changed, but for now,
3115+
# we just flatten out w expressions
3116+
component.set_val(
3117+
smart_subs_dict(component.get_val(), w_toposorted)
3118+
)
3119+
3120+
# for events, also substitute in state updates
3121+
for target, assignment in component._assignments.items():
3122+
if rate_ofs := assignment.find(rate_of_func):
3123+
new_assignment = do_subs(assignment, rate_ofs)
3124+
# currently, deltax cannot depend on `w`.
3125+
# this could be changed, or we could use xdot
3126+
# symbols, but for now, we just flatten out w
3127+
# expressions
3128+
new_assignment = smart_subs_dict(
3129+
new_assignment, w_toposorted
3130+
)
3131+
component._assignments[target] = new_assignment
3132+
30173133

30183134
def _check_lib_sbml_errors(
30193135
sbml_doc: libsbml.SBMLDocument, show_warnings: bool = False

0 commit comments

Comments
 (0)