diff --git a/python/sdist/amici/importers/utils.py b/python/sdist/amici/importers/utils.py index a4348783a8..2ad7e32322 100644 --- a/python/sdist/amici/importers/utils.py +++ b/python/sdist/amici/importers/utils.py @@ -422,20 +422,30 @@ def smart_subs_dict( Substituted symbolic expression """ if field is None: - s = [(eid, expr) for eid, expr in subs.items()] + s = list(subs.items()) else: s = [(eid, expr[field]) for eid, expr in subs.items()] - if reverse: + if not reverse: + # counter-intuitive, but we need to reverse the order for reverse=False s.reverse() with sp.evaluate(False): - for old, new in s: - # note that substitution may change free symbols, so we have to do - # this recursively - if sym.has(old): - sym = sym.xreplace({old: new}) - + # The new expressions may themselves contain symbols to be substituted. + # We flatten them out first, so that the substitutions in `sym` can be + # performed simultaneously, which is usually more efficient than + # repeatedly substituting into `sym`. + # TODO(performance): This could probably be made more efficient by + # combining with toposort used to order `subs` in the first place. + # Some substitutions could be combined, and some terms not present in + # `sym` could be skipped. + for i in range(len(s) - 1): + for j in range(i + 1, len(s)): + if s[j][1].has(s[i][0]): + s[j] = s[j][0], s[j][1].xreplace({s[i][0]: s[i][1]}) + + s = dict(s) + sym = sym.xreplace(s) return sym @@ -450,7 +460,7 @@ def smart_subs(element: sp.Expr, old: sp.Symbol, new: sp.Expr) -> sp.Expr: to be substituted :param new: - subsitution value + substitution value :return: substituted expression