@@ -2588,27 +2588,19 @@ def has_event_assignments(self) -> bool:
25882588 """
25892589 return any (event .updates_state for event in self ._events )
25902590
2591- def toposort_expressions (self ) -> dict [sp .Symbol , sp .Expr ]:
2591+ def toposort_expressions (
2592+ self , reorder : bool = True
2593+ ) -> dict [sp .Symbol , sp .Expr ]:
25922594 """
25932595 Sort expressions in topological order.
25942596
2597+ :param reorder:
2598+ Whether to reorder the internal expression list (``True``) or
2599+ just return the toposorted expressions (``False``).
2600+
25952601 :return:
25962602 dict of expression symbols to expressions in topological order
25972603 """
2598- # ensure no symbols or equations that depend on `w` have been generated
2599- # yet, otherwise the re-ordering might break dependencies
2600- if (
2601- generated := set (self ._syms )
2602- | set (self ._eqs )
2603- | set (self ._sparsesyms )
2604- | set (self ._sparseeqs )
2605- ) - {"w" , "p" , "k" , "x" , "x_rdata" }:
2606- raise AssertionError (
2607- "This function must be called before computing any "
2608- "derivatives. The following symbols/equations are already "
2609- f"generated: { generated } "
2610- )
2611-
26122604 # NOTE: elsewhere, conservations law expressions are expected to
26132605 # occur before any other w expressions, so we must maintain their
26142606 # position.
@@ -2627,6 +2619,23 @@ def toposort_expressions(self) -> dict[sp.Symbol, sp.Expr]:
26272619 for e in self .expressions ()[: self .num_cons_law ()]
26282620 } | w_toposorted
26292621
2622+ if not reorder :
2623+ return w_toposorted
2624+
2625+ # ensure no symbols or equations that depend on `w` have been generated
2626+ # yet, otherwise the re-ordering might break dependencies
2627+ if (
2628+ generated := set (self ._syms )
2629+ | set (self ._eqs )
2630+ | set (self ._sparsesyms )
2631+ | set (self ._sparseeqs )
2632+ ) - {"w" , "p" , "k" , "x" , "x_rdata" }:
2633+ raise AssertionError (
2634+ "This function must be called before computing any "
2635+ "derivatives. The following symbols/equations are already "
2636+ f"generated: { generated } "
2637+ )
2638+
26302639 old_syms = tuple (e .get_sym () for e in self .expressions ())
26312640 topo_expr_syms = tuple (w_toposorted )
26322641 new_order = [old_syms .index (s ) for s in topo_expr_syms ]
0 commit comments