diff --git a/python/sdist/amici/de_export.py b/python/sdist/amici/de_export.py index 05498c30fc..4cc90f2396 100644 --- a/python/sdist/amici/de_export.py +++ b/python/sdist/amici/de_export.py @@ -59,9 +59,6 @@ ) from .de_model import DEModel from .de_model_components import * -from .import_utils import ( - strip_pysb, -) from .logging import get_logger, log_execution_time, set_log_level from .sympy_utils import ( _custom_pow_eval_derivative, @@ -323,7 +320,7 @@ def _generate_c_code(self) -> None: CXX_MAIN_TEMPLATE_FILE, os.path.join(self.model_path, "main.cpp") ) - def _get_index(self, name: str) -> dict[sp.Symbol, int]: + def _get_index(self, name: str) -> dict[str, int]: """ Compute indices for a symbolic array. :param name: @@ -339,10 +336,7 @@ def _get_index(self, name: str) -> dict[sp.Symbol, int]: else: raise ValueError(f"Unknown symbolic array: {name}") - return { - strip_pysb(symbol).name: index - for index, symbol in enumerate(symbols) - } + return {symbol.name: index for index, symbol in enumerate(symbols)} def _write_index_files(self, name: str) -> None: """ @@ -369,9 +363,9 @@ def _write_index_files(self, name: str) -> None: lines = [] for index, symbol in enumerate(symbols): - symbol_name = strip_pysb(symbol) if symbol.is_zero: continue + symbol_name = symbol.name if str(symbol_name) == "": raise ValueError(f'{name} contains a symbol called ""') lines.append(f"#define {symbol_name} {name}[{index}]") diff --git a/python/sdist/amici/import_utils.py b/python/sdist/amici/import_utils.py index 5e5e6d399c..d6c19bfcaa 100644 --- a/python/sdist/amici/import_utils.py +++ b/python/sdist/amici/import_utils.py @@ -860,7 +860,11 @@ def generate_measurement_symbol(observable_id: str | sp.Symbol): symbol for the corresponding measurement """ if not isinstance(observable_id, str): - observable_id = strip_pysb(observable_id) + observable_id = ( + observable_id.name + if isinstance(observable_id, sp.Symbol) + else observable_id + ) return symbol_with_assumptions(f"m{observable_id}") @@ -875,7 +879,11 @@ def generate_regularization_symbol(observable_id: str | sp.Symbol): symbol for the corresponding regularization """ if not isinstance(observable_id, str): - observable_id = strip_pysb(observable_id) + observable_id = ( + observable_id.name + if isinstance(observable_id, sp.Symbol) + else observable_id + ) return symbol_with_assumptions(f"r{observable_id}") @@ -913,26 +921,6 @@ def symbol_with_assumptions(name: str): return sp.Symbol(name, real=True) -def strip_pysb(symbol: sp.Basic) -> sp.Basic: - """ - Strips pysb info from a :class:`pysb.Component` object - - :param symbol: - symbolic expression - - :return: - stripped expression - """ - # strip pysb type and transform into a flat sympy.Symbol. - # this ensures that the pysb type specific __repr__ is used when converting - # to string - if pysb and isinstance(symbol, pysb.Component): - return sp.Symbol(symbol.name, real=True) - else: - # in this case we will use sympy specific transform anyways - return symbol - - def unique_preserve_order(seq: Sequence) -> list: """Return a list of unique elements in Sequence, keeping only the first occurrence of each element diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index fa8fa259d6..f1533a0aab 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -24,9 +24,6 @@ from amici._codegen.template import apply_template from amici.de_export import is_valid_identifier from amici.de_model import DEModel -from amici.import_utils import ( - strip_pysb, -) from amici.jax.jaxcodeprinter import AmiciJaxCodePrinter, _jnp_array_str from amici.jax.model import JAXModel from amici.jax.nn import generate_equinox @@ -45,7 +42,7 @@ def _jax_variable_assignments( ) -> dict: return { f"{sym_name.upper()}_SYMS": "".join( - str(strip_pysb(s)) + ", " for s in model.sym(sym_name) + f"{s.name}, " for s in model.sym(sym_name) ) if model.sym(sym_name) else "_" @@ -63,7 +60,7 @@ def _jax_variable_equations( return { f"{eq_name.upper()}_EQ": "\n".join( code_printer._get_sym_lines( - (str(strip_pysb(s)) for s in model.sym(eq_name)), + (s.name for s in model.sym(eq_name)), model.eq(eq_name).subs(subs), indent, ) @@ -78,7 +75,7 @@ def _jax_return_variables( ) -> dict: return { f"{eq_name.upper()}_RET": _jnp_array_str( - strip_pysb(s) for s in model.sym(eq_name) + s.name for s in model.sym(eq_name) ) if model.sym(eq_name) else "jnp.array([])" @@ -89,7 +86,7 @@ def _jax_return_variables( def _jax_variable_ids(model: DEModel, sym_names: tuple[str, ...]) -> dict: return { f"{sym_name.upper()}_IDS": "".join( - f'"{strip_pysb(s)}", ' for s in model.sym(sym_name) + f'"{s.name}", ' for s in model.sym(sym_name) ) if model.sym(sym_name) else "tuple()" diff --git a/python/sdist/amici/petab/pysb_import.py b/python/sdist/amici/petab/pysb_import.py index 6ae3fcb15c..806198439b 100644 --- a/python/sdist/amici/petab/pysb_import.py +++ b/python/sdist/amici/petab/pysb_import.py @@ -22,7 +22,6 @@ from amici import MeasurementChannel -from ..import_utils import strip_pysb from ..logging import get_logger, log_execution_time, set_log_level from . import PREEQ_INDICATOR_ID from .import_helpers import ( @@ -78,7 +77,9 @@ def _add_observation_model( # update forum if jax and changed_formula: - obs_df.at[ir, col] = str(strip_pysb(sym)) + obs_df.at[ir, col] = ( + sym.name if isinstance(sym, sp.Symbol) else str(sym) + ) # add observables and sigmas to pysb model for observable_id, observable_formula, noise_formula in zip(