Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 3 additions & 9 deletions python/sdist/amici/de_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,6 @@
)
from .de_model import DEModel
from .de_model_components import *
from .import_utils import (
strip_pysb,
)
from .logging import get_logger, log_execution_time, set_log_level
from .sympy_utils import (
_custom_pow_eval_derivative,
Expand Down Expand Up @@ -323,7 +320,7 @@ def _generate_c_code(self) -> None:
CXX_MAIN_TEMPLATE_FILE, os.path.join(self.model_path, "main.cpp")
)

def _get_index(self, name: str) -> dict[sp.Symbol, int]:
def _get_index(self, name: str) -> dict[str, int]:
"""
Compute indices for a symbolic array.
:param name:
Expand All @@ -339,10 +336,7 @@ def _get_index(self, name: str) -> dict[sp.Symbol, int]:
else:
raise ValueError(f"Unknown symbolic array: {name}")

return {
strip_pysb(symbol).name: index
for index, symbol in enumerate(symbols)
}
return {symbol.name: index for index, symbol in enumerate(symbols)}

def _write_index_files(self, name: str) -> None:
"""
Expand All @@ -369,9 +363,9 @@ def _write_index_files(self, name: str) -> None:

lines = []
for index, symbol in enumerate(symbols):
symbol_name = strip_pysb(symbol)
if symbol.is_zero:
continue
symbol_name = symbol.name
if str(symbol_name) == "":
raise ValueError(f'{name} contains a symbol called ""')
lines.append(f"#define {symbol_name} {name}[{index}]")
Expand Down
32 changes: 10 additions & 22 deletions python/sdist/amici/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,7 +860,11 @@ def generate_measurement_symbol(observable_id: str | sp.Symbol):
symbol for the corresponding measurement
"""
if not isinstance(observable_id, str):
observable_id = strip_pysb(observable_id)
observable_id = (
observable_id.name
if isinstance(observable_id, sp.Symbol)
else observable_id
)
return symbol_with_assumptions(f"m{observable_id}")


Expand All @@ -875,7 +879,11 @@ def generate_regularization_symbol(observable_id: str | sp.Symbol):
symbol for the corresponding regularization
"""
if not isinstance(observable_id, str):
observable_id = strip_pysb(observable_id)
observable_id = (
observable_id.name
if isinstance(observable_id, sp.Symbol)
else observable_id
)
return symbol_with_assumptions(f"r{observable_id}")


Expand Down Expand Up @@ -913,26 +921,6 @@ def symbol_with_assumptions(name: str):
return sp.Symbol(name, real=True)


def strip_pysb(symbol: sp.Basic) -> sp.Basic:
"""
Strips pysb info from a :class:`pysb.Component` object

:param symbol:
symbolic expression

:return:
stripped expression
"""
# strip pysb type and transform into a flat sympy.Symbol.
# this ensures that the pysb type specific __repr__ is used when converting
# to string
if pysb and isinstance(symbol, pysb.Component):
return sp.Symbol(symbol.name, real=True)
else:
# in this case we will use sympy specific transform anyways
return symbol


def unique_preserve_order(seq: Sequence) -> list:
"""Return a list of unique elements in Sequence, keeping only the first
occurrence of each element
Expand Down
11 changes: 4 additions & 7 deletions python/sdist/amici/jax/ode_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,6 @@
from amici._codegen.template import apply_template
from amici.de_export import is_valid_identifier
from amici.de_model import DEModel
from amici.import_utils import (
strip_pysb,
)
from amici.jax.jaxcodeprinter import AmiciJaxCodePrinter, _jnp_array_str
from amici.jax.model import JAXModel
from amici.jax.nn import generate_equinox
Expand All @@ -45,7 +42,7 @@ def _jax_variable_assignments(
) -> dict:
return {
f"{sym_name.upper()}_SYMS": "".join(
str(strip_pysb(s)) + ", " for s in model.sym(sym_name)
f"{s.name}, " for s in model.sym(sym_name)
)
if model.sym(sym_name)
else "_"
Expand All @@ -63,7 +60,7 @@ def _jax_variable_equations(
return {
f"{eq_name.upper()}_EQ": "\n".join(
code_printer._get_sym_lines(
(str(strip_pysb(s)) for s in model.sym(eq_name)),
(s.name for s in model.sym(eq_name)),
model.eq(eq_name).subs(subs),
indent,
)
Expand All @@ -78,7 +75,7 @@ def _jax_return_variables(
) -> dict:
return {
f"{eq_name.upper()}_RET": _jnp_array_str(
strip_pysb(s) for s in model.sym(eq_name)
s.name for s in model.sym(eq_name)
)
if model.sym(eq_name)
else "jnp.array([])"
Expand All @@ -89,7 +86,7 @@ def _jax_return_variables(
def _jax_variable_ids(model: DEModel, sym_names: tuple[str, ...]) -> dict:
return {
f"{sym_name.upper()}_IDS": "".join(
f'"{strip_pysb(s)}", ' for s in model.sym(sym_name)
f'"{s.name}", ' for s in model.sym(sym_name)
)
if model.sym(sym_name)
else "tuple()"
Expand Down
5 changes: 3 additions & 2 deletions python/sdist/amici/petab/pysb_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from amici import MeasurementChannel

from ..import_utils import strip_pysb
from ..logging import get_logger, log_execution_time, set_log_level
from . import PREEQ_INDICATOR_ID
from .import_helpers import (
Expand Down Expand Up @@ -78,7 +77,9 @@ def _add_observation_model(

# update forum
if jax and changed_formula:
obs_df.at[ir, col] = str(strip_pysb(sym))
obs_df.at[ir, col] = (
sym.name if isinstance(sym, sp.Symbol) else str(sym)
)

# add observables and sigmas to pysb model
for observable_id, observable_formula, noise_formula in zip(
Expand Down
Loading