@@ -1305,18 +1305,36 @@ def parse_events(self) -> None:
13051305 and replaces the formulae of the found roots by identifiers of AMICI's
13061306 Heaviside function implementation in the right-hand side
13071307 """
1308+ # toposorted w_sym -> w_expr for substitution of 'w' in trigger function
1309+ # do only once. `w` is not modified during this function.
1310+ w_toposorted = toposort_symbols (
1311+ dict (
1312+ zip (
1313+ [expr .get_id () for expr in self ._expressions ],
1314+ [expr .get_val () for expr in self ._expressions ],
1315+ strict = True ,
1316+ )
1317+ )
1318+ )
1319+
13081320 # Track all roots functions in the right-hand side
13091321 roots = copy .deepcopy (self ._events )
13101322 for state in self ._differential_states :
1311- state .set_dt (self ._process_heavisides (state .get_dt (), roots ))
1323+ state .set_dt (
1324+ self ._process_heavisides (state .get_dt (), roots , w_toposorted )
1325+ )
13121326
13131327 for expr in self ._expressions :
1314- expr .set_val (self ._process_heavisides (expr .get_val (), roots ))
1328+ expr .set_val (
1329+ self ._process_heavisides (expr .get_val (), roots , w_toposorted )
1330+ )
13151331
13161332 # remove all possible Heavisides from roots, which may arise from
13171333 # the substitution of `'w'` in `_collect_heaviside_roots`
13181334 for root in roots :
1319- root .set_val (self ._process_heavisides (root .get_val (), roots ))
1335+ root .set_val (
1336+ self ._process_heavisides (root .get_val (), roots , w_toposorted )
1337+ )
13201338
13211339 # Now add the found roots to the model components
13221340 for root in roots :
@@ -1326,6 +1344,11 @@ def parse_events(self) -> None:
13261344 # add roots of heaviside functions
13271345 self .add_component (root )
13281346
1347+ # Substitute 'w' expressions into root expressions, to avoid rewriting
1348+ # 'root.cpp' and 'stau.cpp' headers to include 'w.h'.
1349+ for event in self .events ():
1350+ event .set_val (event .get_val ().subs (w_toposorted ))
1351+
13291352 # re-order events - first those that require root tracking, then the others
13301353 constant_syms = set (self .sym ("k" )) | set (self .sym ("p" ))
13311354 self ._events = list (
@@ -2391,7 +2414,7 @@ def _expr_is_time_dependent(self, expr: sp.Expr) -> bool:
23912414 expr_syms = {str (sym ) for sym in expr .free_symbols }
23922415
23932416 # Check if the time variable is in the expression.
2394- if "t" in expr_syms :
2417+ if amici_time_symbol . name in expr_syms :
23952418 return True
23962419
23972420 # Check if any time-dependent states are in the expression.
@@ -2464,33 +2487,11 @@ def _collect_heaviside_roots(
24642487
24652488 return root_funs
24662489
2467- def _substitute_w_in_roots (
2468- self ,
2469- root_funs : list [tuple [sp .Expr , sp .Expr ]],
2470- ) -> list [tuple [sp .Expr , sp .Expr ]]:
2471- """
2472- Substitute 'w' expressions into root expressions, to avoid rewriting
2473- 'root.cpp' and 'stau.cpp' headers to include 'w.h'.
2474- """
2475- w_sorted = toposort_symbols (
2476- dict (
2477- zip (
2478- [expr .get_id () for expr in self ._expressions ],
2479- [expr .get_val () for expr in self ._expressions ],
2480- strict = True ,
2481- )
2482- )
2483- )
2484- root_funs = [
2485- (r [0 ].subs (w_sorted ), r [1 ].subs (w_sorted )) for r in root_funs
2486- ]
2487-
2488- return root_funs
2489-
24902490 def _process_heavisides (
24912491 self ,
24922492 dxdt : sp .Expr ,
24932493 roots : list [Event ],
2494+ w_toposorted : dict [sp .Symbol , sp .Expr ],
24942495 ) -> sp .Expr :
24952496 """
24962497 Parses the RHS of a state variable, checks for Heaviside functions,
@@ -2502,7 +2503,8 @@ def _process_heavisides(
25022503 right-hand side of state variable
25032504 :param roots:
25042505 list of known root functions with identifier
2505-
2506+ :param w_toposorted:
2507+ `w` symbols->expressions sorted in topological order
25062508 :returns:
25072509 dxdt with Heaviside functions replaced by amici helper variables
25082510 """
@@ -2511,7 +2513,15 @@ def _process_heavisides(
25112513 heavisides = []
25122514 # run through the expression tree and get the roots
25132515 tmp_roots_old = self ._collect_heaviside_roots ((dxdt ,))
2514- tmp_roots_old = self ._substitute_w_in_roots (tmp_roots_old )
2516+ # substitute 'w' symbols in the root expression by their equations,
2517+ # because currently,
2518+ # 1) root functions must not depend on 'w'
2519+ # 2) the check for time-dependence currently assumes only state
2520+ # variables are implicitly time-dependent
2521+ tmp_roots_old = [
2522+ (a .subs (w_toposorted ), b .subs (w_toposorted ))
2523+ for a , b in tmp_roots_old
2524+ ]
25152525 for tmp_root_old , tmp_x0_old in unique_preserve_order (tmp_roots_old ):
25162526 # we want unique identifiers for the roots
25172527 tmp_root_new = self ._get_unique_root (tmp_root_old , roots )
0 commit comments