@@ -296,6 +296,8 @@ def __init__(
296296 ] = {
297297 "sroot" : {
298298 "eq" : "root" ,
299+ # TODO?
300+ # "chainvars": ["x", "w"],
299301 "chainvars" : ["x" ],
300302 "var" : "p" ,
301303 "dxdz_name" : "sx" ,
@@ -528,17 +530,9 @@ def get_rate(symbol: sp.Symbol):
528530
529531 for component in chain (
530532 self .observables (),
531- self .events (),
532533 self ._algebraic_equations ,
533534 ):
534535 if rate_ofs := component .get_val ().find (rate_of_func ):
535- if isinstance (component , Event ):
536- # TODO froot(...) can currently not depend on `w`, so this substitution fails for non-zero rates
537- # see, e.g., sbml test case 01293
538- raise SBMLException (
539- "AMICI does currently not support rateOf(.) inside event trigger functions."
540- )
541-
542536 if isinstance (component , AlgebraicEquation ):
543537 # TODO IDACalcIC fails with
544538 # "The linesearch algorithm failed: step too small or too many backtracks."
@@ -1340,10 +1334,10 @@ def parse_events(self) -> None:
13401334 # add roots of heaviside functions
13411335 self .add_component (root )
13421336
1343- # Substitute 'w' expressions into root expressions, to avoid rewriting
1344- # 'root.cpp' and 'stau.cpp' headers to include 'w.h'.
1345- for event in self .events ():
1346- event .set_val (event .get_val ().subs (w_toposorted ))
1337+ # # Substitute 'w' expressions into root expressions, to avoid rewriting
1338+ # # 'root.cpp' and 'stau.cpp' headers to include 'w.h'.
1339+ # for event in self.events():
1340+ # event.set_val(event.get_val().subs(w_toposorted))
13471341
13481342 # re-order events - first those that require root tracking, then the others
13491343 constant_syms = set (self .sym ("k" )) | set (self .sym ("p" ))
@@ -1668,18 +1662,23 @@ def _compute_equation(self, name: str) -> None:
16681662 self ._eqs [name ] = smart_jacobian (self .eq ("root" ), time_symbol )
16691663
16701664 elif name == "drootdt_total" :
1665+ # root(t, x(t), w(t, x(t)))
1666+ # drootdt_total = drootdt + drootdx * dxdt + drootdw * dwdt_total
1667+ # dwdt_total = dwdt + dwdx * dxdt
16711668 self ._eqs [name ] = self .eq ("drootdt" )
1672- # backsubstitution of optimized right-hand side terms into RHS
1673- # calling subs() is costly. We can skip it if we don't have any
1674- # state-dependent roots.
1669+
1670+ xdot = self .eq ("xdot" )
16751671 if self .num_states_solver () and not smart_is_zero_matrix (
16761672 drootdx := self .eq ("drootdx" )
16771673 ):
1678- w_sorted = toposort_symbols (
1679- dict (zip (self .sym ("w" ), self .eq ("w" ), strict = True ))
1680- )
1681- tmp_xdot = smart_subs_dict (self .eq ("xdot" ), w_sorted )
1682- self ._eqs [name ] += smart_multiply (drootdx , tmp_xdot )
1674+ self ._eqs [name ] += smart_multiply (drootdx , xdot )
1675+
1676+ drootdw = self .eq ("drootdw" )
1677+ dwdt = self .eq ("dwdt" )
1678+ dwdx = self .eq ("dwdx" )
1679+ dwdt_total = dwdt + smart_multiply (dwdx , xdot )
1680+
1681+ self ._eqs [name ] += smart_multiply (drootdw , dwdt_total )
16831682
16841683 elif name == "deltax" :
16851684 # fill boluses for Heaviside functions, as empty state updates
@@ -1763,16 +1762,26 @@ def _compute_equation(self, name: str) -> None:
17631762 ]
17641763
17651764 elif name == "dtaudx" :
1765+ # TODO drootdx + drootdw * dwdx
17661766 self ._eqs [name ] = [
1767- self .eq ("drootdx" )[ie , :] / self .eq ("drootdt_total" )[ie ]
1767+ (
1768+ self .eq ("drootdx" )[ie , :]
1769+ + self .eq ("drootdw" )[ie , :] * self .eq ("dwdx" )
1770+ )
1771+ / self .eq ("drootdt_total" )[ie ]
17681772 if not self .eq ("drootdt_total" )[ie ].is_zero
17691773 else sp .zeros (* self .eq ("drootdx" )[ie , :].shape )
17701774 for ie in range (self .num_events ())
17711775 ]
17721776
17731777 elif name == "dtaudp" :
1778+ # TODO drootdp + drootdw * dwdp
17741779 self ._eqs [name ] = [
1775- self .eq ("drootdp" )[ie , :] / self .eq ("drootdt_total" )[ie ]
1780+ (
1781+ self .eq ("drootdp" )[ie , :]
1782+ + self .eq ("drootdw" )[ie , :] * self .eq ("dwdp" )
1783+ )
1784+ / self .eq ("drootdt_total" )[ie ]
17761785 if not self .eq ("drootdt_total" )[ie ].is_zero
17771786 else sp .zeros (* self .eq ("drootdp" )[ie , :].shape )
17781787 for ie in range (self .num_events ())
@@ -1922,6 +1931,9 @@ def _compute_equation(self, name: str) -> None:
19221931 smart_jacobian (self .eq ("w" )[self .num_cons_law () :, :], x )
19231932 )
19241933
1934+ elif name == "dwdt" :
1935+ self ._eqs [name ] = smart_jacobian (self .eq ("w" ), time_symbol )
1936+
19251937 elif name == "iroot" :
19261938 self ._eqs [name ] = sp .Matrix (
19271939 [
@@ -2110,7 +2122,9 @@ def _derivative(self, eq: str, var: str, name: str = None) -> None:
21102122 "attach this model."
21112123 )
21122124
2113- if name == "dydw" and not smart_is_zero_matrix (derivative ):
2125+ elif name in ("dydw" , "drootdw" ) and not smart_is_zero_matrix (
2126+ derivative
2127+ ):
21142128 dwdw = self .eq ("dwdw" )
21152129 # h(k) = d{eq}dw*dwdw^k* (k=1)
21162130 h = smart_multiply (derivative , dwdw )
@@ -2407,6 +2421,8 @@ def _expr_is_time_dependent(self, expr: sp.Expr) -> bool:
24072421 :returns:
24082422 Whether the expression is time-dependent.
24092423 """
2424+ # TODO: handle w-dependency
2425+
24102426 # `expr.free_symbols` will be different to `self._states.keys()`, so
24112427 # it's easier to compare as `str`.
24122428 expr_syms = {str (sym ) for sym in expr .free_symbols }
@@ -2515,13 +2531,9 @@ def _process_heavisides(
25152531 tmp_roots_old = self ._collect_heaviside_roots ((dxdt ,))
25162532 # substitute 'w' symbols in the root expression by their equations,
25172533 # because currently,
2518- # 1) root functions must not depend on 'w'
2519- # 2) the check for time-dependence currently assumes only state
2534+ # # 1) root functions must not depend on 'w'
2535+ # FIXME 2) the check for time-dependence currently assumes only state
25202536 # variables are implicitly time-dependent
2521- tmp_roots_old = [
2522- (a .subs (w_toposorted ), b .subs (w_toposorted ))
2523- for a , b in tmp_roots_old
2524- ]
25252537 for tmp_root_old , tmp_x0_old in unique_preserve_order (tmp_roots_old ):
25262538 # we want unique identifiers for the roots
25272539 tmp_root_new = self ._get_unique_root (tmp_root_old , roots )
0 commit comments