Skip to content

Commit 6598f11

Browse files
committed
..
1 parent 5272301 commit 6598f11

File tree

2 files changed

+27
-34
lines changed

2 files changed

+27
-34
lines changed

python/sdist/amici/de_model.py

Lines changed: 16 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from typing import TYPE_CHECKING
1313

1414
import numpy as np
15+
import pysb
1516
import sympy as sp
1617
from sympy import ImmutableDenseMatrix, MutableDenseMatrix
1718

@@ -1295,36 +1296,22 @@ def parse_events(self) -> None:
12951296
and replaces the formulae of the found roots by identifiers of AMICI's
12961297
Heaviside function implementation in the right-hand side
12971298
"""
1298-
# toposorted w_sym -> w_expr for substitution of 'w' in trigger function
1299-
# do only once. `w` is not modified during this function.
1300-
w_toposorted = toposort_symbols(
1301-
dict(
1302-
zip(
1303-
[expr.get_sym() for expr in self._expressions],
1304-
[expr.get_val() for expr in self._expressions],
1305-
strict=True,
1306-
)
1307-
)
1308-
)
1309-
13101299
# Track all roots functions in the right-hand side
13111300
roots = copy.deepcopy(self._events)
13121301
for state in self._differential_states:
1313-
state.set_dt(
1314-
self._process_heavisides(state.get_dt(), roots, w_toposorted)
1315-
)
1302+
state.set_dt(self._process_heavisides(state.get_dt(), roots))
13161303

13171304
for expr in self._expressions:
1318-
expr.set_val(
1319-
self._process_heavisides(expr.get_val(), roots, w_toposorted)
1320-
)
1305+
expr.set_val(self._process_heavisides(expr.get_val(), roots))
1306+
if isinstance((pysb_component := expr.get_sym()), pysb.Expression):
1307+
# Make sure pysb Expression definitions stay in sync.
1308+
# Remove with https://github.com/AMICI-dev/AMICI/issues/3035
1309+
pysb_component.expr = expr.get_val()
13211310

13221311
# remove all possible Heavisides from roots, which may arise from
13231312
# the substitution of `'w'` in `_collect_heaviside_roots`
13241313
for root in roots:
1325-
root.set_val(
1326-
self._process_heavisides(root.get_val(), roots, w_toposorted)
1327-
)
1314+
root.set_val(self._process_heavisides(root.get_val(), roots))
13281315

13291316
# Now add the found roots to the model components
13301317
for root in roots:
@@ -2461,9 +2448,6 @@ def _get_unique_root(
24612448
unique identifier for root, or ``None`` if the root is not
24622449
time-dependent
24632450
"""
2464-
if not self._expr_is_time_dependent(root_found):
2465-
return None
2466-
24672451
for root in roots:
24682452
if (difference := (root_found - root.get_val())).is_zero or (
24692453
self._simplify and self._simplify(difference).is_zero
@@ -2481,6 +2465,12 @@ def _get_unique_root(
24812465
use_values_from_trigger_time=True,
24822466
)
24832467
)
2468+
2469+
if not self._expr_is_time_dependent(root_found):
2470+
# Not time-dependent. Return None, but we still need to create
2471+
# the event above
2472+
return None
2473+
24842474
return roots[-1].get_sym()
24852475

24862476
def _collect_heaviside_roots(
@@ -2511,7 +2501,6 @@ def _process_heavisides(
25112501
self,
25122502
dxdt: sp.Expr,
25132503
roots: list[Event],
2514-
w_toposorted: dict[sp.Symbol, sp.Expr],
25152504
) -> sp.Expr:
25162505
"""
25172506
Parses the RHS of a state variable, checks for Heaviside functions,
@@ -2523,8 +2512,6 @@ def _process_heavisides(
25232512
right-hand side of state variable
25242513
:param roots:
25252514
list of known root functions with identifier
2526-
:param w_toposorted:
2527-
`w` symbols->expressions sorted in topological order
25282515
:returns:
25292516
dxdt with Heaviside functions replaced by amici helper variables
25302517
"""
@@ -2533,9 +2520,9 @@ def _process_heavisides(
25332520
heavisides = []
25342521
# run through the expression tree and get the roots
25352522
tmp_roots_old = self._collect_heaviside_roots((dxdt,))
2536-
# substitute 'w' symbols in the root expression by their equations,
2523+
# TODO remove: substitute 'w' symbols in the root expression by their equations,
25372524
# because currently,
2538-
# # 1) root functions must not depend on 'w'
2525+
# 1) root functions must not depend on 'w'
25392526
# FIXME 2) the check for time-dependence currently assumes only state
25402527
# variables are implicitly time-dependent
25412528
for tmp_root_old, tmp_x0_old in unique_preserve_order(tmp_roots_old):

python/sdist/amici/importers/pysb/__init__.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -669,12 +669,18 @@ def _add_expression(
669669
name == str(channel.sigma)
670670
for channel in observation_model.values()
671671
):
672-
component = SigmaY
672+
component_type = SigmaY
673673
else:
674-
component = Expression
675-
ode_model.add_component(
676-
component(sym, name, _parse_special_functions(expr))
677-
)
674+
component_type = Expression
675+
676+
component = component_type(sym, name, _parse_special_functions(expr))
677+
678+
if isinstance(sym, pysb.Expression):
679+
# Make sure pysb Expression definitions stay in sync.
680+
# Remove with https://github.com/AMICI-dev/AMICI/issues/3035
681+
expr.expr = component.get_val()
682+
683+
ode_model.add_component(component)
678684

679685
if name in observation_model:
680686
noise_dist = observation_model[name].noise_distribution

0 commit comments

Comments
 (0)