Skip to content

Commit 972b514

Browse files
committed
Refactor rateOf handling
Pull rateOf-handling out of DEModel and keep it along other SBML processing where it belongs. This is easier to follow and prevents some lingering issues with the old approach due to the xdot / w interdependencies. This also handles rateOf expressions in some additional, previously unsupported places like event assignments.
1 parent 7bf3b76 commit 972b514

File tree

2 files changed

+169
-181
lines changed

2 files changed

+169
-181
lines changed

python/sdist/amici/de_model.py

Lines changed: 49 additions & 177 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from __future__ import annotations
44

5-
import contextlib
65
import copy
76
import itertools
87
import logging
@@ -47,7 +46,6 @@
4746
from .exporters.sundials.cxxcodeprinter import csc_matrix
4847
from .importers.utils import (
4948
ObservableTransformation,
50-
SBMLException,
5149
_default_simplify,
5250
amici_time_symbol,
5351
smart_subs_dict,
@@ -404,165 +402,6 @@ def states(self) -> list[State]:
404402
"""Get all states."""
405403
return self._differential_states + self._algebraic_states
406404

407-
def _process_sbml_rate_of(self) -> None:
408-
"""Substitute any SBML-rateOf constructs in the model equations"""
409-
from sbmlmath import rate_of as rate_of_func
410-
411-
species_sym_to_xdot = dict(
412-
zip(self.sym("x"), self.sym("xdot"), strict=True)
413-
)
414-
species_sym_to_idx = {x: i for i, x in enumerate(self.sym("x"))}
415-
416-
def get_rate(symbol: sp.Symbol):
417-
"""Get rate of change of the given symbol"""
418-
if symbol.find(rate_of_func):
419-
raise SBMLException("Nesting rateOf() is not allowed.")
420-
421-
# Replace all rateOf(some_species) by their respective xdot equation
422-
with contextlib.suppress(KeyError):
423-
return self._eqs["xdot"][species_sym_to_idx[symbol]]
424-
425-
# For anything other than a state, rateOf(.) is 0 or invalid
426-
return 0
427-
428-
# replace rateOf-instances in xdot by xdot symbols
429-
made_substitutions = False
430-
for i_state in range(len(self.eq("xdot"))):
431-
if rate_ofs := self._eqs["xdot"][i_state].find(rate_of_func):
432-
self._eqs["xdot"][i_state] = self._eqs["xdot"][i_state].subs(
433-
{
434-
# either the rateOf argument is a state, or it's 0
435-
rate_of: species_sym_to_xdot.get(rate_of.args[0], 0)
436-
for rate_of in rate_ofs
437-
}
438-
)
439-
made_substitutions = True
440-
441-
if made_substitutions:
442-
# substitute in topological order
443-
subs = toposort_symbols(
444-
dict(zip(self.sym("xdot"), self.eq("xdot"), strict=True))
445-
)
446-
self._eqs["xdot"] = smart_subs_dict(self.eq("xdot"), subs)
447-
448-
# replace rateOf-instances in w by xdot equation
449-
# here we may need toposort, as xdot may depend on w
450-
made_substitutions = False
451-
for i_expr in range(len(self.eq("w"))):
452-
new, replacement = self._eqs["w"][i_expr].replace(
453-
rate_of_func, get_rate, map=True
454-
)
455-
if replacement:
456-
self._eqs["w"][i_expr] = new
457-
made_substitutions = True
458-
459-
if made_substitutions:
460-
# Sort expressions in self._expressions, w symbols, and w equations
461-
# in topological order. Ideally, this would already happen before
462-
# adding the expressions to the model, but at that point, we don't
463-
# have access to xdot yet.
464-
# NOTE: elsewhere, conservations law expressions are expected to
465-
# occur before any other w expressions, so we must maintain their
466-
# position
467-
# toposort everything but conservation law expressions,
468-
# then prepend conservation laws
469-
w_sorted = toposort_symbols(
470-
dict(
471-
zip(
472-
self.sym("w")[self.num_cons_law() :, :],
473-
self.eq("w")[self.num_cons_law() :, :],
474-
strict=True,
475-
)
476-
)
477-
)
478-
w_sorted = (
479-
dict(
480-
zip(
481-
self.sym("w")[: self.num_cons_law(), :],
482-
self.eq("w")[: self.num_cons_law(), :],
483-
strict=True,
484-
)
485-
)
486-
| w_sorted
487-
)
488-
old_syms = tuple(self._syms["w"])
489-
topo_expr_syms = tuple(w_sorted.keys())
490-
new_order = [old_syms.index(s) for s in topo_expr_syms]
491-
self._expressions = [self._expressions[i] for i in new_order]
492-
self._syms["w"] = sp.Matrix(topo_expr_syms)
493-
self._eqs["w"] = sp.Matrix(list(w_sorted.values()))
494-
495-
# replace rateOf-instances in x0 by xdot equation
496-
# indices of state variables whose x0 was modified
497-
changed_indices = []
498-
for i_state in range(len(self.eq("x0"))):
499-
new, replacement = self._eqs["x0"][i_state].replace(
500-
rate_of_func, get_rate, map=True
501-
)
502-
if replacement:
503-
self._eqs["x0"][i_state] = new
504-
changed_indices.append(i_state)
505-
if changed_indices:
506-
# Replace any newly introduced state variables
507-
# by their x0 expressions.
508-
# Also replace any newly introduced `w` symbols by their
509-
# expressions (after `w` was toposorted above).
510-
subs = toposort_symbols(
511-
dict(zip(self.sym("x_rdata"), self.eq("x0"), strict=True))
512-
)
513-
subs = dict(zip(self._syms["w"], self.eq("w"), strict=True)) | subs
514-
for i_state in changed_indices:
515-
self._eqs["x0"][i_state] = smart_subs_dict(
516-
self._eqs["x0"][i_state], subs
517-
)
518-
519-
for component in chain(
520-
self.observables(),
521-
self.events(),
522-
self._algebraic_equations,
523-
):
524-
if rate_ofs := component.get_val().find(rate_of_func):
525-
if isinstance(component, Event):
526-
# TODO froot(...) can currently not depend on `w`, so this substitution fails for non-zero rates
527-
# see, e.g., sbml test case 01293
528-
raise SBMLException(
529-
"AMICI does currently not support rateOf(.) inside event trigger functions."
530-
)
531-
532-
if isinstance(component, AlgebraicEquation):
533-
# TODO IDACalcIC fails with
534-
# "The linesearch algorithm failed: step too small or too many backtracks."
535-
# see, e.g., sbml test case 01482
536-
raise SBMLException(
537-
"AMICI does currently not support rateOf(.) inside AlgebraicRules."
538-
)
539-
540-
component.set_val(
541-
component.get_val().subs(
542-
{
543-
rate_of: get_rate(rate_of.args[0])
544-
for rate_of in rate_ofs
545-
}
546-
)
547-
)
548-
549-
for event in self.events():
550-
state_update = event.get_state_update(
551-
x=self.sym("x"), x_old=self.sym("x")
552-
)
553-
if state_update is None:
554-
continue
555-
556-
for i_state in range(len(state_update)):
557-
if rate_ofs := state_update[i_state].find(rate_of_func):
558-
raise SBMLException(
559-
"AMICI does currently not support rateOf(.) inside event state updates."
560-
)
561-
# TODO here we need xdot sym, not eqs
562-
# event._state_update[i_state] = event._state_update[i_state].subs(
563-
# {rate_of: get_rate(rate_of.args[0]) for rate_of in rate_ofs}
564-
# )
565-
566405
def add_component(
567406
self, component: ModelQuantity, insert_first: bool | None = False
568407
) -> None:
@@ -2694,22 +2533,7 @@ def _process_hybridization(self, hybridization: dict) -> None:
26942533
self._observables = [self._observables[i] for i in new_order]
26952534

26962535
if added_expressions:
2697-
# toposort expressions
2698-
w_sorted = toposort_symbols(
2699-
dict(
2700-
zip(
2701-
self.sym("w"),
2702-
self.eq("w"),
2703-
strict=True,
2704-
)
2705-
)
2706-
)
2707-
old_syms = tuple(self._syms["w"])
2708-
topo_expr_syms = tuple(w_sorted.keys())
2709-
new_order = [old_syms.index(s) for s in topo_expr_syms]
2710-
self._expressions = [self._expressions[i] for i in new_order]
2711-
self._syms["w"] = sp.Matrix(topo_expr_syms)
2712-
self._eqs["w"] = sp.Matrix(list(w_sorted.values()))
2536+
self.toposort_expressions()
27132537

27142538
def get_explicit_roots(self) -> set[sp.Expr]:
27152539
"""
@@ -2752,3 +2576,51 @@ def has_event_assignments(self) -> bool:
27522576
boolean indicating if event assignments are present
27532577
"""
27542578
return any(event.updates_state for event in self._events)
2579+
2580+
def toposort_expressions(self) -> dict[sp.Symbol, sp.Expr]:
2581+
"""
2582+
Sort expressions in topological order.
2583+
2584+
:return:
2585+
dict of expression symbols to expressions in topological order
2586+
"""
2587+
# ensure no symbols or equations that depend on `w` have been generated
2588+
# yet, otherwise the re-ordering might break dependencies
2589+
if (
2590+
generated := set(self._syms)
2591+
| set(self._eqs)
2592+
| set(self._sparsesyms)
2593+
| set(self._sparseeqs)
2594+
) - {"w", "p", "k", "x", "x_rdata"}:
2595+
raise AssertionError(
2596+
"This function must be called before computing any "
2597+
"derivatives. The following symbols/equations are already "
2598+
f"generated: {generated}"
2599+
)
2600+
2601+
# NOTE: elsewhere, conservations law expressions are expected to
2602+
# occur before any other w expressions, so we must maintain their
2603+
# position.
2604+
# toposort everything but conservation law expressions,
2605+
# then prepend conservation laws
2606+
2607+
w_toposorted = toposort_symbols(
2608+
{
2609+
e.get_sym(): e.get_val()
2610+
for e in self.expressions()[self.num_cons_law() :]
2611+
}
2612+
)
2613+
2614+
w_toposorted = {
2615+
e.get_sym(): e.get_val()
2616+
for e in self.expressions()[: self.num_cons_law()]
2617+
} | w_toposorted
2618+
2619+
old_syms = tuple(e.get_sym() for e in self.expressions())
2620+
topo_expr_syms = tuple(w_toposorted)
2621+
new_order = [old_syms.index(s) for s in topo_expr_syms]
2622+
self._expressions = [self._expressions[i] for i in new_order]
2623+
self._syms["w"] = sp.Matrix(topo_expr_syms)
2624+
self._eqs["w"] = sp.Matrix(list(w_toposorted.values()))
2625+
2626+
return w_toposorted

0 commit comments

Comments
 (0)