Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 24 additions & 15 deletions python/sdist/amici/de_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2588,27 +2588,19 @@ def has_event_assignments(self) -> bool:
"""
return any(event.updates_state for event in self._events)

def toposort_expressions(self) -> dict[sp.Symbol, sp.Expr]:
def toposort_expressions(
self, reorder: bool = True
) -> dict[sp.Symbol, sp.Expr]:
"""
Sort expressions in topological order.

:param reorder:
Whether to reorder the internal expression list (``True``) or
just return the toposorted expressions (``False``).

:return:
dict of expression symbols to expressions in topological order
"""
# ensure no symbols or equations that depend on `w` have been generated
# yet, otherwise the re-ordering might break dependencies
if (
generated := set(self._syms)
| set(self._eqs)
| set(self._sparsesyms)
| set(self._sparseeqs)
) - {"w", "p", "k", "x", "x_rdata"}:
raise AssertionError(
"This function must be called before computing any "
"derivatives. The following symbols/equations are already "
f"generated: {generated}"
)

# NOTE: elsewhere, conservations law expressions are expected to
# occur before any other w expressions, so we must maintain their
# position.
Expand All @@ -2627,6 +2619,23 @@ def toposort_expressions(self) -> dict[sp.Symbol, sp.Expr]:
for e in self.expressions()[: self.num_cons_law()]
} | w_toposorted

if not reorder:
return w_toposorted

# ensure no symbols or equations that depend on `w` have been generated
# yet, otherwise the re-ordering might break dependencies
if (
generated := set(self._syms)
| set(self._eqs)
| set(self._sparsesyms)
| set(self._sparseeqs)
) - {"w", "p", "k", "x", "x_rdata"}:
raise AssertionError(
"This function must be called before computing any "
"derivatives. The following symbols/equations are already "
f"generated: {generated}"
)

old_syms = tuple(e.get_sym() for e in self.expressions())
topo_expr_syms = tuple(w_toposorted)
new_order = [old_syms.index(s) for s in topo_expr_syms]
Expand Down
13 changes: 11 additions & 2 deletions python/sdist/amici/importers/sbml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3094,11 +3094,20 @@ def do_subs(expr, rate_ofs) -> sp.Expr:

# replace rateOf-instances in expressions which we will need for
# substitutions later
expressions_changed = False
for expr in de_model.expressions():
if rate_ofs := expr.get_val().find(rate_of_func):
expr.set_val(do_subs(expr.get_val(), rate_ofs))

w_toposorted = de_model.toposort_expressions()
expressions_changed = True

# get toposorted `w`, but only changed the order of expressions in the
# model if any expression changed.
# (in principle, we could always reorder the expressions, but this
# will require regenerating all test oracle for
# `test_pregenerated_models.py`)
w_toposorted = de_model.toposort_expressions(
reorder=expressions_changed
)

# replace rateOf-instances in x0
# indices of state variables whose x0 was modified
Expand Down
Loading