Skip to content

Commit 861a135

Browse files
committed
Add ModelQuantity.get_sym, make ModelQuantity.get_id return str
Distinguish between `id: str` and `symbol: sp.Symbol`. The old behavior (`def get_id(self) -> sp.Symbol`) was confusing. Fix inconsistent type for `AlgebraicEquation` symbol/id. Closes #2940.
1 parent 3ada40c commit 861a135

File tree

7 files changed

+142
-132
lines changed

7 files changed

+142
-132
lines changed

python/sdist/amici/de_model.py

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,7 @@ def add_conservation_law(
648648
try:
649649
ix = next(
650650
filter(
651-
lambda is_s: is_s[1].get_id() == state,
651+
lambda is_s: is_s[1].get_sym() == state,
652652
enumerate(self._differential_states),
653653
)
654654
)[0]
@@ -657,7 +657,7 @@ def add_conservation_law(
657657
f"Specified state {state} was not found in the model states."
658658
)
659659

660-
state_id = self._differential_states[ix].get_id()
660+
state_id = self._differential_states[ix].get_sym()
661661

662662
# \sum_{i≠j}(a_i * x_i)/a_j
663663
target_expression = (
@@ -704,7 +704,7 @@ def add_spline(self, spline: AbstractSpline, spline_expr: sp.Expr) -> None:
704704
self._splines.append(spline)
705705
self.add_component(
706706
Expression(
707-
identifier=spline.sbml_id,
707+
symbol=spline.sbml_id,
708708
name=str(spline.sbml_id),
709709
value=spline_expr,
710710
)
@@ -1136,7 +1136,7 @@ def _generate_symbol(self, name: str) -> None:
11361136
components = sorted(
11371137
components,
11381138
key=lambda x: int(
1139-
str(strip_pysb(x.get_id())).replace(
1139+
str(strip_pysb(x.get_sym())).replace(
11401140
"observableParameter", ""
11411141
)
11421142
),
@@ -1145,13 +1145,13 @@ def _generate_symbol(self, name: str) -> None:
11451145
components = sorted(
11461146
components,
11471147
key=lambda x: int(
1148-
str(strip_pysb(x.get_id())).replace(
1148+
str(strip_pysb(x.get_sym())).replace(
11491149
"noiseParameter", ""
11501150
)
11511151
),
11521152
)
11531153
self._syms[name] = sp.Matrix(
1154-
[comp.get_id() for comp in components]
1154+
[comp.get_sym() for comp in components]
11551155
)
11561156
if name == "y":
11571157
self._syms["my"] = sp.Matrix(
@@ -1168,7 +1168,7 @@ def _generate_symbol(self, name: str) -> None:
11681168
elif name == "x":
11691169
self._syms[name] = sp.Matrix(
11701170
[
1171-
state.get_id()
1171+
state.get_sym()
11721172
for state in self.states()
11731173
if not state.has_conservation_law()
11741174
]
@@ -1214,8 +1214,8 @@ def _generate_symbol(self, name: str) -> None:
12141214
[
12151215
[
12161216
sp.Symbol(
1217-
f"s{strip_pysb(tcl.get_id())}__"
1218-
f"{strip_pysb(par.get_id())}",
1217+
f"s{strip_pysb(tcl.get_sym())}__"
1218+
f"{strip_pysb(par.get_sym())}",
12191219
real=True,
12201220
)
12211221
for par in self._parameters
@@ -1312,7 +1312,7 @@ def parse_events(self) -> None:
13121312
w_toposorted = toposort_symbols(
13131313
dict(
13141314
zip(
1315-
[expr.get_id() for expr in self._expressions],
1315+
[expr.get_sym() for expr in self._expressions],
13161316
[expr.get_val() for expr in self._expressions],
13171317
strict=True,
13181318
)
@@ -1393,9 +1393,11 @@ def get_appearance_counts(self, idxs: list[int]) -> list[int]:
13931393
)
13941394

13951395
return [
1396-
free_symbols_dt.count(str(self._differential_states[idx].get_id()))
1396+
free_symbols_dt.count(
1397+
str(self._differential_states[idx].get_sym())
1398+
)
13971399
+ free_symbols_expr.count(
1398-
str(self._differential_states[idx].get_id())
1400+
str(self._differential_states[idx].get_sym())
13991401
)
14001402
for idx in idxs
14011403
]
@@ -1528,7 +1530,7 @@ def _compute_equation(self, name: str) -> None:
15281530
elif name == "x_solver":
15291531
self._eqs[name] = sp.Matrix(
15301532
[
1531-
state.get_id()
1533+
state.get_sym()
15321534
for state in self.states()
15331535
if not state.has_conservation_law()
15341536
]
@@ -1704,7 +1706,7 @@ def _compute_equation(self, name: str) -> None:
17041706
event_observables = [
17051707
sp.zeros(self.num_eventobs(), 1) for _ in self._events
17061708
]
1707-
event_ids = [e.get_id() for e in self._events]
1709+
event_ids = [e.get_sym() for e in self._events]
17081710
z2event = [
17091711
event_ids.index(event_obs.get_event())
17101712
for event_obs in self._event_observables
@@ -2285,7 +2287,7 @@ def get_conservation_laws(self) -> list[tuple[sp.Symbol, sp.Expr]]:
22852287
list of state identifiers
22862288
"""
22872289
return [
2288-
(state.get_id(), state.get_x_rdata())
2290+
(state.get_sym(), state.get_x_rdata())
22892291
for state in self.states()
22902292
if state.has_conservation_law()
22912293
]
@@ -2338,7 +2340,7 @@ def state_has_fixed_parameter_initial_condition(self, ix: int) -> bool:
23382340
if not isinstance(ic, sp.Basic):
23392341
return False
23402342
return any(
2341-
fp in (c.get_id() for c in self._constants)
2343+
fp in (c.get_sym() for c in self._constants)
23422344
for fp in ic.free_symbols
23432345
)
23442346

@@ -2450,20 +2452,20 @@ def _get_unique_root(
24502452

24512453
for root in roots:
24522454
if sp.simplify(root_found - root.get_val()).is_zero:
2453-
return root.get_id()
2455+
return root.get_sym()
24542456

24552457
# create an event for a new root function
24562458
root_symstr = f"Heaviside_{len(roots)}"
24572459
roots.append(
24582460
Event(
2459-
identifier=sp.Symbol(root_symstr),
2461+
symbol=sp.Symbol(root_symstr),
24602462
name=root_symstr,
24612463
value=root_found,
24622464
assignments=None,
24632465
use_values_from_trigger_time=True,
24642466
)
24652467
)
2466-
return roots[-1].get_id()
2468+
return roots[-1].get_sym()
24672469

24682470
def _collect_heaviside_roots(
24692471
self,
@@ -2579,22 +2581,22 @@ def _process_hybridization(self, hybridization: dict) -> None:
25792581
https://petab-sciml.readthedocs.io/latest/format.html#problem-yaml-file
25802582
"""
25812583
added_expressions = False
2582-
orig_obs = tuple([s.get_id() for s in self._observables])
2584+
orig_obs = tuple([s.get_sym() for s in self._observables])
25832585
for net_id, net in hybridization.items():
25842586
if net["static"]:
25852587
continue # do not integrate into ODEs, handle in amici.jax.petab
25862588
inputs = [
25872589
comp
25882590
for comp in self._components
2589-
if str(comp.get_id()) in net["input_vars"]
2591+
if str(comp.get_sym()) in net["input_vars"]
25902592
]
25912593
# sort inputs by order in input_vars
25922594
inputs = sorted(
25932595
inputs,
2594-
key=lambda comp: net["input_vars"].index(str(comp.get_id())),
2596+
key=lambda comp: net["input_vars"].index(str(comp.get_sym())),
25952597
)
25962598
if len(inputs) != len(net["input_vars"]):
2597-
found_vars = {str(comp.get_id()) for comp in inputs}
2599+
found_vars = {str(comp.get_sym()) for comp in inputs}
25982600
missing_vars = set(net["input_vars"]) - found_vars
25992601
raise ValueError(
26002602
f"Could not find all input variables for neural network {net_id}. "
@@ -2616,9 +2618,9 @@ def _process_hybridization(self, hybridization: dict) -> None:
26162618
outputs = {
26172619
out_var: {"comp": comp, "ind": net["output_vars"][out_var]}
26182620
for comp in self._components
2619-
if (out_var := str(comp.get_id())) in net["output_vars"]
2621+
if (out_var := str(comp.get_sym())) in net["output_vars"]
26202622
# TODO: SYNTAX NEEDS to CHANGE
2621-
or (out_var := str(comp.get_id()) + "_dot")
2623+
or (out_var := str(comp.get_sym()) + "_dot")
26222624
in net["output_vars"]
26232625
}
26242626
if len(outputs.keys()) != len(net["output_vars"]):
@@ -2645,7 +2647,7 @@ def _process_hybridization(self, hybridization: dict) -> None:
26452647

26462648
# generate dummy Function
26472649
out_val = sp.Function(net_id)(
2648-
*[input.get_id() for input in inputs], parts["ind"]
2650+
*[input.get_sym() for input in inputs], parts["ind"]
26492651
)
26502652

26512653
# add to the model
@@ -2659,7 +2661,7 @@ def _process_hybridization(self, hybridization: dict) -> None:
26592661
else:
26602662
self.add_component(
26612663
Expression(
2662-
identifier=comp.get_id(),
2664+
symbol=comp.get_sym(),
26632665
name=net_id,
26642666
value=out_val,
26652667
)
@@ -2669,7 +2671,7 @@ def _process_hybridization(self, hybridization: dict) -> None:
26692671
observables = {
26702672
ob_var: {"comp": comp, "ind": net["observable_vars"][ob_var]}
26712673
for comp in self._components
2672-
if (ob_var := str(comp.get_id())) in net["observable_vars"]
2674+
if (ob_var := str(comp.get_sym())) in net["observable_vars"]
26732675
# # TODO: SYNTAX NEEDS to CHANGE
26742676
# or (ob_var := str(comp.get_id()) + "_dot")
26752677
# in net["observable_vars"]
@@ -2691,18 +2693,18 @@ def _process_hybridization(self, hybridization: dict) -> None:
26912693
f"{comp.get_name()} ({type(comp)}) is not an observable."
26922694
)
26932695
out_val = sp.Function(net_id)(
2694-
*[input.get_id() for input in inputs], parts["ind"]
2696+
*[input.get_sym() for input in inputs], parts["ind"]
26952697
)
26962698
# add to the model
26972699
self.add_component(
26982700
Observable(
2699-
identifier=comp.get_id(),
2701+
symbol=comp.get_sym(),
27002702
name=net_id,
27012703
value=out_val,
27022704
)
27032705
)
27042706

2705-
new_order = [orig_obs.index(s.get_id()) for s in self._observables]
2707+
new_order = [orig_obs.index(s.get_sym()) for s in self._observables]
27062708
self._observables = [self._observables[i] for i in new_order]
27072709

27082710
if added_expressions:

0 commit comments

Comments
 (0)