Skip to content

Commit 8fc33dc

Browse files
authored
Restore ordering of w (#3051)
With #3036, the ordering of `w` symbols and expressions has changed. This is not a problem per se, but would require recreating the test oracles for the models in `models/`. Therefore, don't do any unnecessary reordering.
1 parent 6039cec commit 8fc33dc

File tree

2 files changed

+35
-17
lines changed

2 files changed

+35
-17
lines changed

python/sdist/amici/de_model.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2588,27 +2588,19 @@ def has_event_assignments(self) -> bool:
25882588
"""
25892589
return any(event.updates_state for event in self._events)
25902590

2591-
def toposort_expressions(self) -> dict[sp.Symbol, sp.Expr]:
2591+
def toposort_expressions(
2592+
self, reorder: bool = True
2593+
) -> dict[sp.Symbol, sp.Expr]:
25922594
"""
25932595
Sort expressions in topological order.
25942596
2597+
:param reorder:
2598+
Whether to reorder the internal expression list (``True``) or
2599+
just return the toposorted expressions (``False``).
2600+
25952601
:return:
25962602
dict of expression symbols to expressions in topological order
25972603
"""
2598-
# ensure no symbols or equations that depend on `w` have been generated
2599-
# yet, otherwise the re-ordering might break dependencies
2600-
if (
2601-
generated := set(self._syms)
2602-
| set(self._eqs)
2603-
| set(self._sparsesyms)
2604-
| set(self._sparseeqs)
2605-
) - {"w", "p", "k", "x", "x_rdata"}:
2606-
raise AssertionError(
2607-
"This function must be called before computing any "
2608-
"derivatives. The following symbols/equations are already "
2609-
f"generated: {generated}"
2610-
)
2611-
26122604
# NOTE: elsewhere, conservations law expressions are expected to
26132605
# occur before any other w expressions, so we must maintain their
26142606
# position.
@@ -2627,6 +2619,23 @@ def toposort_expressions(self) -> dict[sp.Symbol, sp.Expr]:
26272619
for e in self.expressions()[: self.num_cons_law()]
26282620
} | w_toposorted
26292621

2622+
if not reorder:
2623+
return w_toposorted
2624+
2625+
# ensure no symbols or equations that depend on `w` have been generated
2626+
# yet, otherwise the re-ordering might break dependencies
2627+
if (
2628+
generated := set(self._syms)
2629+
| set(self._eqs)
2630+
| set(self._sparsesyms)
2631+
| set(self._sparseeqs)
2632+
) - {"w", "p", "k", "x", "x_rdata"}:
2633+
raise AssertionError(
2634+
"This function must be called before computing any "
2635+
"derivatives. The following symbols/equations are already "
2636+
f"generated: {generated}"
2637+
)
2638+
26302639
old_syms = tuple(e.get_sym() for e in self.expressions())
26312640
topo_expr_syms = tuple(w_toposorted)
26322641
new_order = [old_syms.index(s) for s in topo_expr_syms]

python/sdist/amici/importers/sbml/__init__.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3094,11 +3094,20 @@ def do_subs(expr, rate_ofs) -> sp.Expr:
30943094

30953095
# replace rateOf-instances in expressions which we will need for
30963096
# substitutions later
3097+
expressions_changed = False
30973098
for expr in de_model.expressions():
30983099
if rate_ofs := expr.get_val().find(rate_of_func):
30993100
expr.set_val(do_subs(expr.get_val(), rate_ofs))
3100-
3101-
w_toposorted = de_model.toposort_expressions()
3101+
expressions_changed = True
3102+
3103+
# get toposorted `w`, but only changed the order of expressions in the
3104+
# model if any expression changed.
3105+
# (in principle, we could always reorder the expressions, but this
3106+
# will require regenerating all test oracle for
3107+
# `test_pregenerated_models.py`)
3108+
w_toposorted = de_model.toposort_expressions(
3109+
reorder=expressions_changed
3110+
)
31023111

31033112
# replace rateOf-instances in x0
31043113
# indices of state variables whose x0 was modified

0 commit comments

Comments
 (0)