Skip to content

Commit a666ef9

Browse files
authored
Speed up DEModel._collect_heaviside_roots (#2977)
Avoid unnecessary repeated toposorting of `w` during `_collect_heaviside_root`. Sort and substitute only once after all roots have been collected. This saves a couple of seconds for models with heavily nested piecewise functions.
1 parent 11781f5 commit a666ef9

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

python/sdist/amici/de_model.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2462,12 +2462,16 @@ def _collect_heaviside_roots(
24622462
elif arg.has(sp.Heaviside):
24632463
root_funs.extend(self._collect_heaviside_roots(arg.args))
24642464

2465-
if not root_funs:
2466-
return []
2465+
return root_funs
24672466

2468-
# substitute 'w' expressions into root expressions now, to avoid
2469-
# rewriting 'root.cpp' and 'stau.cpp' headers
2470-
# to include 'w.h'
2467+
def _substitute_w_in_roots(
2468+
self,
2469+
root_funs: list[tuple[sp.Expr, sp.Expr]],
2470+
) -> list[tuple[sp.Expr, sp.Expr]]:
2471+
"""
2472+
Substitute 'w' expressions into root expressions, to avoid rewriting
2473+
'root.cpp' and 'stau.cpp' headers to include 'w.h'.
2474+
"""
24712475
w_sorted = toposort_symbols(
24722476
dict(
24732477
zip(
@@ -2507,6 +2511,7 @@ def _process_heavisides(
25072511
heavisides = []
25082512
# run through the expression tree and get the roots
25092513
tmp_roots_old = self._collect_heaviside_roots((dxdt,))
2514+
tmp_roots_old = self._substitute_w_in_roots(tmp_roots_old)
25102515
for tmp_root_old, tmp_x0_old in unique_preserve_order(tmp_roots_old):
25112516
# we want unique identifiers for the roots
25122517
tmp_root_new = self._get_unique_root(tmp_root_old, roots)

0 commit comments

Comments
 (0)