Skip to content

Commit 37c193e

Browse files
committed
..
1 parent ca6dca0 commit 37c193e

File tree

1 file changed

+27
-13
lines changed

1 file changed

+27
-13
lines changed

python/sdist/amici/de_model.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import re
1010
from collections.abc import Callable, Sequence
1111
from itertools import chain
12+
from operator import itemgetter
1213
from typing import TYPE_CHECKING
1314

1415
import numpy as np
@@ -798,12 +799,12 @@ def num_events(self) -> int:
798799

799800
def num_events_solver(self) -> int:
800801
"""
801-
Number of Events.
802+
Number of Events that rely on numerical root-finding.
802803
803804
:return:
804805
number of event symbols (length of the root vector in AMICI)
805806
"""
806-
constant_syms = set(self.sym("k")) | set(self.sym("p"))
807+
constant_syms = self._static_symbols(["k", "p", "w"])
807808
return sum(
808809
not event.has_explicit_trigger_times(constant_syms)
809810
for event in self.events()
@@ -1090,6 +1091,29 @@ def static_indices(self, name: str) -> list[int]:
10901091

10911092
raise NotImplementedError(name)
10921093

1094+
def _static_symbols(self, names: list[str]) -> set[sp.Symbol]:
1095+
"""
1096+
Return the static symbols among the given model entities.
1097+
1098+
E.g., `static_symbols(["p", "w"])` returns all symbols in `p` and `w`
1099+
that do not depend on time, neither directly nor indirectly.
1100+
"""
1101+
result = set()
1102+
1103+
for name in names:
1104+
if name in ("k", "p"):
1105+
result |= set(self.sym(name))
1106+
elif name == "w":
1107+
static_indices = self.static_indices("w")
1108+
if len(static_indices) == 1:
1109+
result.add(self.sym("w")[static_indices[0]])
1110+
elif len(static_indices) > 1:
1111+
result |= set(itemgetter(*static_indices)(self.sym("w")))
1112+
else:
1113+
raise ValueError(name)
1114+
1115+
return result
1116+
10931117
def dynamic_indices(self, name: str) -> list[int]:
10941118
"""
10951119
Return the indices of dynamic expressions in the given model entity.
@@ -1307,13 +1331,8 @@ def parse_events(self) -> None:
13071331
# add roots of heaviside functions
13081332
self.add_component(root)
13091333

1310-
# # Substitute 'w' expressions into root expressions, to avoid rewriting
1311-
# # 'root.cpp' and 'stau.cpp' headers to include 'w.h'.
1312-
# for event in self.events():
1313-
# event.set_val(event.get_val().subs(w_toposorted))
1314-
13151334
# re-order events - first those that require root tracking, then the others
1316-
constant_syms = set(self.sym("k")) | set(self.sym("p"))
1335+
constant_syms = self._static_symbols(["k", "p", "w"])
13171336
self._events = list(
13181337
chain(
13191338
itertools.filterfalse(
@@ -2506,11 +2525,6 @@ def _process_heavisides(
25062525
heavisides = []
25072526
# run through the expression tree and get the roots
25082527
tmp_roots_old = self._collect_heaviside_roots((dxdt,))
2509-
# TODO remove: substitute 'w' symbols in the root expression by their equations,
2510-
# because currently,
2511-
# 1) root functions must not depend on 'w'
2512-
# FIXME 2) the check for time-dependence currently assumes only state
2513-
# variables are implicitly time-dependent
25142528
for tmp_root_old, tmp_x0_old in unique_preserve_order(tmp_roots_old):
25152529
# we want unique identifiers for the roots
25162530
tmp_root_new = self._get_unique_root(tmp_root_old, roots)

0 commit comments

Comments
 (0)