diff --git a/python/sdist/amici/importers/sbml/__init__.py b/python/sdist/amici/importers/sbml/__init__.py index 37e2776f8a..d633425d7c 100644 --- a/python/sdist/amici/importers/sbml/__init__.py +++ b/python/sdist/amici/importers/sbml/__init__.py @@ -66,6 +66,7 @@ from amici.logging import get_logger, log_execution_time, set_log_level from amici.sympy_utils import ( _monkeypatch_sympy, + _piecewise_to_minmax, smart_is_zero_matrix, smart_multiply, ) @@ -2882,6 +2883,7 @@ def subs_locals(expr: sp.Basic) -> sp.Basic: # piecewise to heavisides if piecewise_to_heaviside: try: + expr = expr.replace(sp.Piecewise, _piecewise_to_minmax) expr = expr.replace( sp.Piecewise, lambda *args: _parse_piecewise_to_heaviside(args), diff --git a/python/sdist/amici/sympy_utils.py b/python/sdist/amici/sympy_utils.py index bd3eecc4ac..85a0e986a8 100644 --- a/python/sdist/amici/sympy_utils.py +++ b/python/sdist/amici/sympy_utils.py @@ -216,3 +216,19 @@ def _parallel_applyfunc(obj: sp.Matrix, func: Callable) -> sp.Matrix: "to a module-level function or disable parallelization by " "setting `AMICI_IMPORT_NPROCS=1`." ) from e + + +def _piecewise_to_minmax( + *expr_cond_pairs: tuple[tuple[sp.Basic, sp.Basic], ...], +) -> sp.Basic: + """Replace min/max defined via Piecewise with plain Min/Max. + + To be used in ``expr = expr.replace(sp.Piecewise, pw_to_minmax)``. + """ + if len(expr_cond_pairs) == 2 and expr_cond_pairs[-1][1] == sp.true: + (expr1, cond1), (expr2, cond2) = expr_cond_pairs + if cond1.args == (expr1, expr2) and cond1.func in (sp.Lt, sp.Le): + return sp.Min(expr1, expr2) + elif cond1.args == (expr1, expr2) and cond1.func in (sp.Gt, sp.Ge): + return sp.Max(expr1, expr2) + return sp.Piecewise(*expr_cond_pairs) diff --git a/python/tests/test_sbml_import.py b/python/tests/test_sbml_import.py index 574bfcf7b2..4d86496f07 100644 --- a/python/tests/test_sbml_import.py +++ b/python/tests/test_sbml_import.py @@ -13,6 +13,7 @@ import sympy as sp from amici import import_model_module from amici.gradient_check import check_derivatives +from amici.importers.antimony import antimony2sbml from amici.importers.sbml import SbmlImporter, SymbolId from amici.importers.utils import ( MeasurementChannel as MC, @@ -1196,3 +1197,21 @@ def test_time_dependent_initial_assignment(compute_conservation_laws: bool): symbol_with_assumptions("p0"), amici_time_symbol * 1.0 + 3.0, ] + + +@skip_on_valgrind +def test_minmax_piecewise_is_converted_to_minmax(): + """Test that _piecewise_to_minmax is applied during SBML import.""" + sbml_str = antimony2sbml(""" + x' = piecewise(a, a > b, b) + y' = piecewise(a, a < b, b) + """) + sbml_importer = SbmlImporter(sbml_source=sbml_str, from_file=False) + de_model = sbml_importer._build_ode_model() + # no events should be created for min/max + assert not de_model.events() + assert len(de_model.sym("h")) == 0 + # min/max are present in the equations + xdot = de_model.eq("xdot") + assert xdot.has(sp.Min) + assert xdot.has(sp.Max) diff --git a/python/tests/test_sympy_utils.py b/python/tests/test_sympy_utils.py index 3f21d0b928..a868cd8f8a 100644 --- a/python/tests/test_sympy_utils.py +++ b/python/tests/test_sympy_utils.py @@ -1,7 +1,11 @@ """Tests related to the sympy_utils module.""" import sympy as sp -from amici.sympy_utils import _custom_pow_eval_derivative, _monkeypatched +from amici.sympy_utils import ( + _custom_pow_eval_derivative, + _monkeypatched, + _piecewise_to_minmax, +) from amici.testing import skip_on_valgrind @@ -22,3 +26,51 @@ def test_monkeypatch(): # check that the monkeypatch is transient assert (t**n).diff(t).subs(vals) is sp.nan + + +@skip_on_valgrind +def test_rewrite_piecewise_minmax(): + """Test rewriting of piecewise min/max to sympy Min/Max functions.""" + x, y, z = sp.symbols("x y z") + + assert sp.Piecewise((x, x < y), (y, True)).replace( + sp.Piecewise, _piecewise_to_minmax + ) == sp.Min(x, y) + assert sp.Piecewise((x, x <= y), (y, True)).replace( + sp.Piecewise, _piecewise_to_minmax + ) == sp.Min(x, y) + assert sp.Piecewise((x, x > y), (y, True)).replace( + sp.Piecewise, _piecewise_to_minmax + ) == sp.Max(x, y) + assert sp.Piecewise((x, x >= y), (y, True)).replace( + sp.Piecewise, _piecewise_to_minmax + ) == sp.Max(x, y) + assert sp.Piecewise((x, y > x), (y, True)).replace( + sp.Piecewise, _piecewise_to_minmax + ) == sp.Min(x, y) + assert sp.Piecewise((x, y >= x), (y, True)).replace( + sp.Piecewise, _piecewise_to_minmax + ) == sp.Min(x, y) + assert sp.Piecewise((x, y < x), (y, True)).replace( + sp.Piecewise, _piecewise_to_minmax + ) == sp.Max(x, y) + assert sp.Piecewise((x, y <= x), (y, True)).replace( + sp.Piecewise, _piecewise_to_minmax + ) == sp.Max(x, y) + + # can't replace + assert sp.Piecewise((z, y <= x), (y, True)).replace( + sp.Piecewise, _piecewise_to_minmax + ) == sp.Piecewise((z, y <= x), (y, True)) + + # replace recursively + expr = sp.Piecewise( + (sp.Piecewise((x, x < y), (y, True)), x < z), + (sp.Piecewise((y, y < z), (z, True)), True), + ) + replaced = expr.replace(sp.Piecewise, _piecewise_to_minmax) + expected = sp.Piecewise( + (sp.Min(x, y), x < z), + (sp.Min(y, z), True), + ) + assert replaced == expected