diff --git a/python/sdist/amici/importers/sbml/__init__.py b/python/sdist/amici/importers/sbml/__init__.py index ffc4fe45b0..199deefc00 100644 --- a/python/sdist/amici/importers/sbml/__init__.py +++ b/python/sdist/amici/importers/sbml/__init__.py @@ -64,7 +64,11 @@ toposort_symbols, ) from amici.logging import get_logger, log_execution_time, set_log_level -from amici.sympy_utils import smart_is_zero_matrix, smart_multiply +from amici.sympy_utils import ( + _monkeypatch_sympy, + smart_is_zero_matrix, + smart_multiply, +) SymbolicFormula = dict[sp.Symbol, sp.Expr] @@ -537,6 +541,7 @@ def sbml2jax( ) exporter.generate_model_code() + @_monkeypatch_sympy def _build_ode_model( self, fixed_parameters: Iterable[str] = None, diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index ebf0d164b7..d855372b23 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -29,8 +29,7 @@ from amici.jax.nn import generate_equinox from amici.logging import get_logger, log_execution_time, set_log_level from amici.sympy_utils import ( - _custom_pow_eval_derivative, - _monkeypatched, + _monkeypatch_sympy, ) #: python log manager @@ -168,17 +167,15 @@ def __init__( self._code_printer = AmiciJaxCodePrinter() + @_monkeypatch_sympy @log_execution_time("generating jax code", logger) def generate_model_code(self) -> None: """ Generates the jax code for the loaded model """ - with _monkeypatched( - sp.Pow, "_eval_derivative", _custom_pow_eval_derivative - ): - self._prepare_model_folder() - self._generate_jax_code() - self._generate_nn_code() + self._prepare_model_folder() + self._generate_jax_code() + self._generate_nn_code() def _prepare_model_folder(self) -> None: """ diff --git a/python/sdist/amici/sympy_utils.py b/python/sdist/amici/sympy_utils.py index 0f241090ff..bd3eecc4ac 100644 --- a/python/sdist/amici/sympy_utils.py +++ b/python/sdist/amici/sympy_utils.py @@ -4,6 +4,7 @@ import logging import os from collections.abc import Callable +from functools import wraps from itertools import starmap from typing import Any @@ -62,6 +63,21 @@ def _monkeypatched(obj: object, name: str, patch: Any): setattr(obj, name, pre_patched_value) +def _monkeypatch_sympy(func): + """ + Decorator that temporarily monkeypatches sympy.Pow._eval_derivative. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + with _monkeypatched( + sp.Pow, "_eval_derivative", _custom_pow_eval_derivative + ): + return func(*args, **kwargs) + + return wrapper + + @log_execution_time("running smart_jacobian", logger) def smart_jacobian( eq: sp.MutableDenseMatrix, sym_var: sp.MutableDenseMatrix