Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion python/sdist/amici/importers/sbml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,11 @@
toposort_symbols,
)
from amici.logging import get_logger, log_execution_time, set_log_level
from amici.sympy_utils import smart_is_zero_matrix, smart_multiply
from amici.sympy_utils import (
_monkeypatch_sympy,
smart_is_zero_matrix,
smart_multiply,
)

SymbolicFormula = dict[sp.Symbol, sp.Expr]

Expand Down Expand Up @@ -537,6 +541,7 @@ def sbml2jax(
)
exporter.generate_model_code()

@_monkeypatch_sympy
def _build_ode_model(
self,
fixed_parameters: Iterable[str] = None,
Expand Down
13 changes: 5 additions & 8 deletions python/sdist/amici/jax/ode_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@
from amici.jax.nn import generate_equinox
from amici.logging import get_logger, log_execution_time, set_log_level
from amici.sympy_utils import (
_custom_pow_eval_derivative,
_monkeypatched,
_monkeypatch_sympy,
)

#: python log manager
Expand Down Expand Up @@ -168,17 +167,15 @@ def __init__(

self._code_printer = AmiciJaxCodePrinter()

@_monkeypatch_sympy
@log_execution_time("generating jax code", logger)
def generate_model_code(self) -> None:
"""
Generates the jax code for the loaded model
"""
with _monkeypatched(
sp.Pow, "_eval_derivative", _custom_pow_eval_derivative
):
self._prepare_model_folder()
self._generate_jax_code()
self._generate_nn_code()
self._prepare_model_folder()
self._generate_jax_code()
self._generate_nn_code()

def _prepare_model_folder(self) -> None:
"""
Expand Down
16 changes: 16 additions & 0 deletions python/sdist/amici/sympy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import os
from collections.abc import Callable
from functools import wraps
from itertools import starmap
from typing import Any

Expand Down Expand Up @@ -62,6 +63,21 @@ def _monkeypatched(obj: object, name: str, patch: Any):
setattr(obj, name, pre_patched_value)


def _monkeypatch_sympy(func):
"""
Decorator that temporarily monkeypatches sympy.Pow._eval_derivative.
"""

@wraps(func)
def wrapper(*args, **kwargs):
with _monkeypatched(
sp.Pow, "_eval_derivative", _custom_pow_eval_derivative
):
return func(*args, **kwargs)

return wrapper


@log_execution_time("running smart_jacobian", logger)
def smart_jacobian(
eq: sp.MutableDenseMatrix, sym_var: sp.MutableDenseMatrix
Expand Down
Loading