Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/sdist/amici/importers/sbml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
from amici.logging import get_logger, log_execution_time, set_log_level
from amici.sympy_utils import (
_monkeypatch_sympy,
_piecewise_to_minmax,
smart_is_zero_matrix,
smart_multiply,
)
Expand Down Expand Up @@ -2882,6 +2883,7 @@ def subs_locals(expr: sp.Basic) -> sp.Basic:
# piecewise to heavisides
if piecewise_to_heaviside:
try:
expr = expr.replace(sp.Piecewise, _piecewise_to_minmax)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also add small test with actual model to make sure that we are not reverting this substitution someplace else?

expr = expr.replace(
sp.Piecewise,
lambda *args: _parse_piecewise_to_heaviside(args),
Expand Down
16 changes: 16 additions & 0 deletions python/sdist/amici/sympy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,3 +216,19 @@ def _parallel_applyfunc(obj: sp.Matrix, func: Callable) -> sp.Matrix:
"to a module-level function or disable parallelization by "
"setting `AMICI_IMPORT_NPROCS=1`."
) from e


def _piecewise_to_minmax(
*expr_cond_pairs: tuple[tuple[sp.Basic, sp.Basic], ...],
) -> sp.Basic:
"""Replace min/max defined via Piecewise with plain Min/Max.

To be used in ``expr = expr.replace(sp.Piecewise, pw_to_minmax)``.
"""
if len(expr_cond_pairs) == 2 and expr_cond_pairs[-1][1] == sp.true:
(expr1, cond1), (expr2, cond2) = expr_cond_pairs
if cond1.args == (expr1, expr2) and cond1.func in (sp.Lt, sp.Le):
return sp.Min(expr1, expr2)
elif cond1.args == (expr1, expr2) and cond1.func in (sp.Gt, sp.Ge):
return sp.Max(expr1, expr2)
return sp.Piecewise(*expr_cond_pairs)
19 changes: 19 additions & 0 deletions python/tests/test_sbml_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import sympy as sp
from amici import import_model_module
from amici.gradient_check import check_derivatives
from amici.importers.antimony import antimony2sbml
from amici.importers.sbml import SbmlImporter, SymbolId
from amici.importers.utils import (
MeasurementChannel as MC,
Expand Down Expand Up @@ -1196,3 +1197,21 @@ def test_time_dependent_initial_assignment(compute_conservation_laws: bool):
symbol_with_assumptions("p0"),
amici_time_symbol * 1.0 + 3.0,
]


@skip_on_valgrind
def test_minmax_piecewise_is_converted_to_minmax():
"""Test that _piecewise_to_minmax is applied during SBML import."""
sbml_str = antimony2sbml("""
x' = piecewise(a, a > b, b)
y' = piecewise(a, a < b, b)
""")
sbml_importer = SbmlImporter(sbml_source=sbml_str, from_file=False)
de_model = sbml_importer._build_ode_model()
# no events should be created for min/max
assert not de_model.events()
assert len(de_model.sym("h")) == 0
# min/max are present in the equations
xdot = de_model.eq("xdot")
assert xdot.has(sp.Min)
assert xdot.has(sp.Max)
54 changes: 53 additions & 1 deletion python/tests/test_sympy_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
"""Tests related to the sympy_utils module."""

import sympy as sp
from amici.sympy_utils import _custom_pow_eval_derivative, _monkeypatched
from amici.sympy_utils import (
_custom_pow_eval_derivative,
_monkeypatched,
_piecewise_to_minmax,
)
from amici.testing import skip_on_valgrind


Expand All @@ -22,3 +26,51 @@ def test_monkeypatch():

# check that the monkeypatch is transient
assert (t**n).diff(t).subs(vals) is sp.nan


@skip_on_valgrind
def test_rewrite_piecewise_minmax():
"""Test rewriting of piecewise min/max to sympy Min/Max functions."""
x, y, z = sp.symbols("x y z")

assert sp.Piecewise((x, x < y), (y, True)).replace(
sp.Piecewise, _piecewise_to_minmax
) == sp.Min(x, y)
assert sp.Piecewise((x, x <= y), (y, True)).replace(
sp.Piecewise, _piecewise_to_minmax
) == sp.Min(x, y)
assert sp.Piecewise((x, x > y), (y, True)).replace(
sp.Piecewise, _piecewise_to_minmax
) == sp.Max(x, y)
assert sp.Piecewise((x, x >= y), (y, True)).replace(
sp.Piecewise, _piecewise_to_minmax
) == sp.Max(x, y)
assert sp.Piecewise((x, y > x), (y, True)).replace(
sp.Piecewise, _piecewise_to_minmax
) == sp.Min(x, y)
assert sp.Piecewise((x, y >= x), (y, True)).replace(
sp.Piecewise, _piecewise_to_minmax
) == sp.Min(x, y)
assert sp.Piecewise((x, y < x), (y, True)).replace(
sp.Piecewise, _piecewise_to_minmax
) == sp.Max(x, y)
assert sp.Piecewise((x, y <= x), (y, True)).replace(
sp.Piecewise, _piecewise_to_minmax
) == sp.Max(x, y)

# can't replace
assert sp.Piecewise((z, y <= x), (y, True)).replace(
sp.Piecewise, _piecewise_to_minmax
) == sp.Piecewise((z, y <= x), (y, True))

# replace recursively
expr = sp.Piecewise(
(sp.Piecewise((x, x < y), (y, True)), x < z),
(sp.Piecewise((y, y < z), (z, True)), True),
)
replaced = expr.replace(sp.Piecewise, _piecewise_to_minmax)
expected = sp.Piecewise(
(sp.Min(x, y), x < z),
(sp.Min(y, z), True),
)
assert replaced == expected
Loading