|
9 | 9 | import re |
10 | 10 | from collections.abc import Callable, Sequence |
11 | 11 | from itertools import chain |
| 12 | +from operator import itemgetter |
12 | 13 | from typing import TYPE_CHECKING |
13 | 14 |
|
14 | 15 | import numpy as np |
@@ -798,12 +799,12 @@ def num_events(self) -> int: |
798 | 799 |
|
799 | 800 | def num_events_solver(self) -> int: |
800 | 801 | """ |
801 | | - Number of Events. |
| 802 | + Number of Events that rely on numerical root-finding. |
802 | 803 |
|
803 | 804 | :return: |
804 | 805 | number of event symbols (length of the root vector in AMICI) |
805 | 806 | """ |
806 | | - constant_syms = set(self.sym("k")) | set(self.sym("p")) |
| 807 | + constant_syms = self._static_symbols(["k", "p", "w"]) |
807 | 808 | return sum( |
808 | 809 | not event.has_explicit_trigger_times(constant_syms) |
809 | 810 | for event in self.events() |
@@ -1090,6 +1091,29 @@ def static_indices(self, name: str) -> list[int]: |
1090 | 1091 |
|
1091 | 1092 | raise NotImplementedError(name) |
1092 | 1093 |
|
| 1094 | + def _static_symbols(self, names: list[str]) -> set[sp.Symbol]: |
| 1095 | + """ |
| 1096 | + Return the static symbols among the given model entities. |
| 1097 | +
|
| 1098 | + E.g., `static_symbols(["p", "w"])` returns all symbols in `p` and `w` |
| 1099 | + that do not depend on time, neither directly nor indirectly. |
| 1100 | + """ |
| 1101 | + result = set() |
| 1102 | + |
| 1103 | + for name in names: |
| 1104 | + if name in ("k", "p"): |
| 1105 | + result |= set(self.sym(name)) |
| 1106 | + elif name == "w": |
| 1107 | + static_indices = self.static_indices("w") |
| 1108 | + if len(static_indices) == 1: |
| 1109 | + result.add(self.sym("w")[static_indices[0]]) |
| 1110 | + elif len(static_indices) > 1: |
| 1111 | + result |= set(itemgetter(*static_indices)(self.sym("w"))) |
| 1112 | + else: |
| 1113 | + raise ValueError(name) |
| 1114 | + |
| 1115 | + return result |
| 1116 | + |
1093 | 1117 | def dynamic_indices(self, name: str) -> list[int]: |
1094 | 1118 | """ |
1095 | 1119 | Return the indices of dynamic expressions in the given model entity. |
@@ -1307,13 +1331,8 @@ def parse_events(self) -> None: |
1307 | 1331 | # add roots of heaviside functions |
1308 | 1332 | self.add_component(root) |
1309 | 1333 |
|
1310 | | - # # Substitute 'w' expressions into root expressions, to avoid rewriting |
1311 | | - # # 'root.cpp' and 'stau.cpp' headers to include 'w.h'. |
1312 | | - # for event in self.events(): |
1313 | | - # event.set_val(event.get_val().subs(w_toposorted)) |
1314 | | - |
1315 | 1334 | # re-order events - first those that require root tracking, then the others |
1316 | | - constant_syms = set(self.sym("k")) | set(self.sym("p")) |
| 1335 | + constant_syms = self._static_symbols(["k", "p", "w"]) |
1317 | 1336 | self._events = list( |
1318 | 1337 | chain( |
1319 | 1338 | itertools.filterfalse( |
@@ -2506,11 +2525,6 @@ def _process_heavisides( |
2506 | 2525 | heavisides = [] |
2507 | 2526 | # run through the expression tree and get the roots |
2508 | 2527 | tmp_roots_old = self._collect_heaviside_roots((dxdt,)) |
2509 | | - # TODO remove: substitute 'w' symbols in the root expression by their equations, |
2510 | | - # because currently, |
2511 | | - # 1) root functions must not depend on 'w' |
2512 | | - # FIXME 2) the check for time-dependence currently assumes only state |
2513 | | - # variables are implicitly time-dependent |
2514 | 2528 | for tmp_root_old, tmp_x0_old in unique_preserve_order(tmp_roots_old): |
2515 | 2529 | # we want unique identifiers for the roots |
2516 | 2530 | tmp_root_new = self._get_unique_root(tmp_root_old, roots) |
|
0 commit comments