Skip to content

Commit 360d006

Browse files
committed
Remove strip_pysb
Replace the remaining usages of `strip_pysb` after AMICI-dev#3005. Some were useless anyways. Closes AMICI-dev#3004.
1 parent 2717020 commit 360d006

File tree

4 files changed

+20
-40
lines changed

4 files changed

+20
-40
lines changed

python/sdist/amici/de_export.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,6 @@
5959
)
6060
from .de_model import DEModel
6161
from .de_model_components import *
62-
from .import_utils import (
63-
strip_pysb,
64-
)
6562
from .logging import get_logger, log_execution_time, set_log_level
6663
from .sympy_utils import (
6764
_custom_pow_eval_derivative,
@@ -323,7 +320,7 @@ def _generate_c_code(self) -> None:
323320
CXX_MAIN_TEMPLATE_FILE, os.path.join(self.model_path, "main.cpp")
324321
)
325322

326-
def _get_index(self, name: str) -> dict[sp.Symbol, int]:
323+
def _get_index(self, name: str) -> dict[str, int]:
327324
"""
328325
Compute indices for a symbolic array.
329326
:param name:
@@ -339,10 +336,7 @@ def _get_index(self, name: str) -> dict[sp.Symbol, int]:
339336
else:
340337
raise ValueError(f"Unknown symbolic array: {name}")
341338

342-
return {
343-
strip_pysb(symbol).name: index
344-
for index, symbol in enumerate(symbols)
345-
}
339+
return {symbol.name: index for index, symbol in enumerate(symbols)}
346340

347341
def _write_index_files(self, name: str) -> None:
348342
"""
@@ -369,9 +363,9 @@ def _write_index_files(self, name: str) -> None:
369363

370364
lines = []
371365
for index, symbol in enumerate(symbols):
372-
symbol_name = strip_pysb(symbol)
373366
if symbol.is_zero:
374367
continue
368+
symbol_name = symbol.name
375369
if str(symbol_name) == "":
376370
raise ValueError(f'{name} contains a symbol called ""')
377371
lines.append(f"#define {symbol_name} {name}[{index}]")

python/sdist/amici/import_utils.py

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -860,7 +860,11 @@ def generate_measurement_symbol(observable_id: str | sp.Symbol):
860860
symbol for the corresponding measurement
861861
"""
862862
if not isinstance(observable_id, str):
863-
observable_id = strip_pysb(observable_id)
863+
observable_id = (
864+
observable_id.name
865+
if isinstance(observable_id, sp.Symbol)
866+
else observable_id
867+
)
864868
return symbol_with_assumptions(f"m{observable_id}")
865869

866870

@@ -875,7 +879,11 @@ def generate_regularization_symbol(observable_id: str | sp.Symbol):
875879
symbol for the corresponding regularization
876880
"""
877881
if not isinstance(observable_id, str):
878-
observable_id = strip_pysb(observable_id)
882+
observable_id = (
883+
observable_id.name
884+
if isinstance(observable_id, sp.Symbol)
885+
else observable_id
886+
)
879887
return symbol_with_assumptions(f"r{observable_id}")
880888

881889

@@ -913,26 +921,6 @@ def symbol_with_assumptions(name: str):
913921
return sp.Symbol(name, real=True)
914922

915923

916-
def strip_pysb(symbol: sp.Basic) -> sp.Basic:
917-
"""
918-
Strips pysb info from a :class:`pysb.Component` object
919-
920-
:param symbol:
921-
symbolic expression
922-
923-
:return:
924-
stripped expression
925-
"""
926-
# strip pysb type and transform into a flat sympy.Symbol.
927-
# this ensures that the pysb type specific __repr__ is used when converting
928-
# to string
929-
if pysb and isinstance(symbol, pysb.Component):
930-
return sp.Symbol(symbol.name, real=True)
931-
else:
932-
# in this case we will use sympy specific transform anyways
933-
return symbol
934-
935-
936924
def unique_preserve_order(seq: Sequence) -> list:
937925
"""Return a list of unique elements in Sequence, keeping only the first
938926
occurrence of each element

python/sdist/amici/jax/ode_export.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,6 @@
2424
from amici._codegen.template import apply_template
2525
from amici.de_export import is_valid_identifier
2626
from amici.de_model import DEModel
27-
from amici.import_utils import (
28-
strip_pysb,
29-
)
3027
from amici.jax.jaxcodeprinter import AmiciJaxCodePrinter, _jnp_array_str
3128
from amici.jax.model import JAXModel
3229
from amici.jax.nn import generate_equinox
@@ -45,7 +42,7 @@ def _jax_variable_assignments(
4542
) -> dict:
4643
return {
4744
f"{sym_name.upper()}_SYMS": "".join(
48-
str(strip_pysb(s)) + ", " for s in model.sym(sym_name)
45+
f"{s.name}, " for s in model.sym(sym_name)
4946
)
5047
if model.sym(sym_name)
5148
else "_"
@@ -63,7 +60,7 @@ def _jax_variable_equations(
6360
return {
6461
f"{eq_name.upper()}_EQ": "\n".join(
6562
code_printer._get_sym_lines(
66-
(str(strip_pysb(s)) for s in model.sym(eq_name)),
63+
(s.name for s in model.sym(eq_name)),
6764
model.eq(eq_name).subs(subs),
6865
indent,
6966
)
@@ -78,7 +75,7 @@ def _jax_return_variables(
7875
) -> dict:
7976
return {
8077
f"{eq_name.upper()}_RET": _jnp_array_str(
81-
strip_pysb(s) for s in model.sym(eq_name)
78+
s.name for s in model.sym(eq_name)
8279
)
8380
if model.sym(eq_name)
8481
else "jnp.array([])"
@@ -89,7 +86,7 @@ def _jax_return_variables(
8986
def _jax_variable_ids(model: DEModel, sym_names: tuple[str, ...]) -> dict:
9087
return {
9188
f"{sym_name.upper()}_IDS": "".join(
92-
f'"{strip_pysb(s)}", ' for s in model.sym(sym_name)
89+
f'"{s.name}", ' for s in model.sym(sym_name)
9390
)
9491
if model.sym(sym_name)
9592
else "tuple()"

python/sdist/amici/petab/pysb_import.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
from amici import MeasurementChannel
2424

25-
from ..import_utils import strip_pysb
2625
from ..logging import get_logger, log_execution_time, set_log_level
2726
from . import PREEQ_INDICATOR_ID
2827
from .import_helpers import (
@@ -78,7 +77,9 @@ def _add_observation_model(
7877

7978
# update forum
8079
if jax and changed_formula:
81-
obs_df.at[ir, col] = str(strip_pysb(sym))
80+
obs_df.at[ir, col] = (
81+
sym.name if isinstance(sym, sp.Symbol) else str(sym)
82+
)
8283

8384
# add observables and sigmas to pysb model
8485
for observable_id, observable_formula, noise_formula in zip(

0 commit comments

Comments
 (0)