Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 33 additions & 31 deletions python/sdist/amici/de_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,7 @@ def add_conservation_law(
try:
ix = next(
filter(
lambda is_s: is_s[1].get_id() == state,
lambda is_s: is_s[1].get_sym() == state,
enumerate(self._differential_states),
)
)[0]
Expand All @@ -657,7 +657,7 @@ def add_conservation_law(
f"Specified state {state} was not found in the model states."
)

state_id = self._differential_states[ix].get_id()
state_id = self._differential_states[ix].get_sym()

# \sum_{i≠j}(a_i * x_i)/a_j
target_expression = (
Expand Down Expand Up @@ -704,7 +704,7 @@ def add_spline(self, spline: AbstractSpline, spline_expr: sp.Expr) -> None:
self._splines.append(spline)
self.add_component(
Expression(
identifier=spline.sbml_id,
symbol=spline.sbml_id,
name=str(spline.sbml_id),
value=spline_expr,
)
Expand Down Expand Up @@ -1136,7 +1136,7 @@ def _generate_symbol(self, name: str) -> None:
components = sorted(
components,
key=lambda x: int(
str(strip_pysb(x.get_id())).replace(
str(strip_pysb(x.get_sym())).replace(
"observableParameter", ""
)
),
Expand All @@ -1145,13 +1145,13 @@ def _generate_symbol(self, name: str) -> None:
components = sorted(
components,
key=lambda x: int(
str(strip_pysb(x.get_id())).replace(
str(strip_pysb(x.get_sym())).replace(
"noiseParameter", ""
)
),
)
self._syms[name] = sp.Matrix(
[comp.get_id() for comp in components]
[comp.get_sym() for comp in components]
)
if name == "y":
self._syms["my"] = sp.Matrix(
Expand All @@ -1168,7 +1168,7 @@ def _generate_symbol(self, name: str) -> None:
elif name == "x":
self._syms[name] = sp.Matrix(
[
state.get_id()
state.get_sym()
for state in self.states()
if not state.has_conservation_law()
]
Expand Down Expand Up @@ -1214,8 +1214,8 @@ def _generate_symbol(self, name: str) -> None:
[
[
sp.Symbol(
f"s{strip_pysb(tcl.get_id())}__"
f"{strip_pysb(par.get_id())}",
f"s{strip_pysb(tcl.get_sym())}__"
f"{strip_pysb(par.get_sym())}",
real=True,
)
for par in self._parameters
Expand Down Expand Up @@ -1312,7 +1312,7 @@ def parse_events(self) -> None:
w_toposorted = toposort_symbols(
dict(
zip(
[expr.get_id() for expr in self._expressions],
[expr.get_sym() for expr in self._expressions],
[expr.get_val() for expr in self._expressions],
strict=True,
)
Expand Down Expand Up @@ -1393,9 +1393,11 @@ def get_appearance_counts(self, idxs: list[int]) -> list[int]:
)

return [
free_symbols_dt.count(str(self._differential_states[idx].get_id()))
free_symbols_dt.count(
str(self._differential_states[idx].get_sym())
)
+ free_symbols_expr.count(
str(self._differential_states[idx].get_id())
str(self._differential_states[idx].get_sym())
)
for idx in idxs
]
Expand Down Expand Up @@ -1528,7 +1530,7 @@ def _compute_equation(self, name: str) -> None:
elif name == "x_solver":
self._eqs[name] = sp.Matrix(
[
state.get_id()
state.get_sym()
for state in self.states()
if not state.has_conservation_law()
]
Expand Down Expand Up @@ -1704,7 +1706,7 @@ def _compute_equation(self, name: str) -> None:
event_observables = [
sp.zeros(self.num_eventobs(), 1) for _ in self._events
]
event_ids = [e.get_id() for e in self._events]
event_ids = [e.get_sym() for e in self._events]
z2event = [
event_ids.index(event_obs.get_event())
for event_obs in self._event_observables
Expand Down Expand Up @@ -2285,7 +2287,7 @@ def get_conservation_laws(self) -> list[tuple[sp.Symbol, sp.Expr]]:
list of state identifiers
"""
return [
(state.get_id(), state.get_x_rdata())
(state.get_sym(), state.get_x_rdata())
for state in self.states()
if state.has_conservation_law()
]
Expand Down Expand Up @@ -2338,7 +2340,7 @@ def state_has_fixed_parameter_initial_condition(self, ix: int) -> bool:
if not isinstance(ic, sp.Basic):
return False
return any(
fp in (c.get_id() for c in self._constants)
fp in (c.get_sym() for c in self._constants)
for fp in ic.free_symbols
)

Expand Down Expand Up @@ -2450,20 +2452,20 @@ def _get_unique_root(

for root in roots:
if sp.simplify(root_found - root.get_val()).is_zero:
return root.get_id()
return root.get_sym()

# create an event for a new root function
root_symstr = f"Heaviside_{len(roots)}"
roots.append(
Event(
identifier=sp.Symbol(root_symstr),
symbol=sp.Symbol(root_symstr),
name=root_symstr,
value=root_found,
assignments=None,
use_values_from_trigger_time=True,
)
)
return roots[-1].get_id()
return roots[-1].get_sym()

def _collect_heaviside_roots(
self,
Expand Down Expand Up @@ -2579,22 +2581,22 @@ def _process_hybridization(self, hybridization: dict) -> None:
https://petab-sciml.readthedocs.io/latest/format.html#problem-yaml-file
"""
added_expressions = False
orig_obs = tuple([s.get_id() for s in self._observables])
orig_obs = tuple([s.get_sym() for s in self._observables])
for net_id, net in hybridization.items():
if net["static"]:
continue # do not integrate into ODEs, handle in amici.jax.petab
inputs = [
comp
for comp in self._components
if str(comp.get_id()) in net["input_vars"]
if str(comp.get_sym()) in net["input_vars"]
]
# sort inputs by order in input_vars
inputs = sorted(
inputs,
key=lambda comp: net["input_vars"].index(str(comp.get_id())),
key=lambda comp: net["input_vars"].index(str(comp.get_sym())),
)
if len(inputs) != len(net["input_vars"]):
found_vars = {str(comp.get_id()) for comp in inputs}
found_vars = {str(comp.get_sym()) for comp in inputs}
missing_vars = set(net["input_vars"]) - found_vars
raise ValueError(
f"Could not find all input variables for neural network {net_id}. "
Expand All @@ -2616,9 +2618,9 @@ def _process_hybridization(self, hybridization: dict) -> None:
outputs = {
out_var: {"comp": comp, "ind": net["output_vars"][out_var]}
for comp in self._components
if (out_var := str(comp.get_id())) in net["output_vars"]
if (out_var := str(comp.get_sym())) in net["output_vars"]
# TODO: SYNTAX NEEDS to CHANGE
or (out_var := str(comp.get_id()) + "_dot")
or (out_var := str(comp.get_sym()) + "_dot")
in net["output_vars"]
}
if len(outputs.keys()) != len(net["output_vars"]):
Expand All @@ -2645,7 +2647,7 @@ def _process_hybridization(self, hybridization: dict) -> None:

# generate dummy Function
out_val = sp.Function(net_id)(
*[input.get_id() for input in inputs], parts["ind"]
*[input.get_sym() for input in inputs], parts["ind"]
)

# add to the model
Expand All @@ -2659,7 +2661,7 @@ def _process_hybridization(self, hybridization: dict) -> None:
else:
self.add_component(
Expression(
identifier=comp.get_id(),
symbol=comp.get_sym(),
name=net_id,
value=out_val,
)
Expand All @@ -2669,7 +2671,7 @@ def _process_hybridization(self, hybridization: dict) -> None:
observables = {
ob_var: {"comp": comp, "ind": net["observable_vars"][ob_var]}
for comp in self._components
if (ob_var := str(comp.get_id())) in net["observable_vars"]
if (ob_var := str(comp.get_sym())) in net["observable_vars"]
# # TODO: SYNTAX NEEDS to CHANGE
# or (ob_var := str(comp.get_id()) + "_dot")
# in net["observable_vars"]
Expand All @@ -2691,18 +2693,18 @@ def _process_hybridization(self, hybridization: dict) -> None:
f"{comp.get_name()} ({type(comp)}) is not an observable."
)
out_val = sp.Function(net_id)(
*[input.get_id() for input in inputs], parts["ind"]
*[input.get_sym() for input in inputs], parts["ind"]
)
# add to the model
self.add_component(
Observable(
identifier=comp.get_id(),
symbol=comp.get_sym(),
name=net_id,
value=out_val,
)
)

new_order = [orig_obs.index(s.get_id()) for s in self._observables]
new_order = [orig_obs.index(s.get_sym()) for s in self._observables]
self._observables = [self._observables[i] for i in new_order]

if added_expressions:
Expand Down
Loading
Loading