diff --git a/python/sdist/amici/importers/utils.py b/python/sdist/amici/importers/utils.py index 2ad7e32322..ff05260d81 100644 --- a/python/sdist/amici/importers/utils.py +++ b/python/sdist/amici/importers/utils.py @@ -400,6 +400,7 @@ def smart_subs_dict( subs: SymbolDef, field: str | None = None, reverse: bool = True, + flatten_first: bool | None = None, ) -> sp.Expr: """ Substitutes expressions completely flattening them out. Requires @@ -418,6 +419,13 @@ def smart_subs_dict( Whether ordering in subs should be reversed. Note that substitution requires the reverse order of what is required for evaluation. + :param flatten_first: + Choice of algorithm: Flatten the substitution expressions first, then + substitute them simultaneously into `sym` (``True``), or substitute + them one by one into `sym` (``False``). + If ``None`` (default), flattening is used only if `sym` is a matrix + with more than 20 elements. + :return: Substituted symbolic expression """ @@ -426,27 +434,61 @@ def smart_subs_dict( else: s = [(eid, expr[field]) for eid, expr in subs.items()] - if not reverse: - # counter-intuitive, but we need to reverse the order for reverse=False - s.reverse() - - with sp.evaluate(False): - # The new expressions may themselves contain symbols to be substituted. - # We flatten them out first, so that the substitutions in `sym` can be - # performed simultaneously, which is usually more efficient than - # repeatedly substituting into `sym`. - # TODO(performance): This could probably be made more efficient by - # combining with toposort used to order `subs` in the first place. - # Some substitutions could be combined, and some terms not present in - # `sym` could be skipped. - for i in range(len(s) - 1): - for j in range(i + 1, len(s)): - if s[j][1].has(s[i][0]): - s[j] = s[j][0], s[j][1].xreplace({s[i][0]: s[i][1]}) - - s = dict(s) - sym = sym.xreplace(s) - return sym + # We have the choice to flatten the replacement expressions first or to + # substitute them one by one into `sym`. 
Flattening first is usually + # more efficient if `sym` is large (e.g., a matrix with many elements) + # and `subs` is cascading (i.e., substitutions depend on other + # substitutions). Otherwise, substituting one by one is usually more + # efficient, because flattening scales quadratically with the number of + # substitutions. + # The exact threshold is somewhat arbitrary and may need to be + # adjusted in the future. + if flatten_first is None: + flatten_first = ( + isinstance(sym, sp.MatrixBase) and sym.rows * sym.cols > 20 + ) + + if flatten_first: + if not reverse: + # counter-intuitive, but on this branch, we need to reverse the + # order for `reverse=False` + s.reverse() + + with sp.evaluate(False): + # The new expressions may themselves contain symbols to be + # substituted. We flatten them out first, so that the + # substitutions in `sym` can be performed simultaneously, + # which can be more efficient than repeatedly substituting into + # `sym`. + # TODO(performance): This could probably be made more efficient by + # combining with toposort used to order `subs` in the first + # place. + # Some substitutions could be combined, and some terms not + # present in `sym` could be skipped. + # Furthermore, this would provide information on recursion depth, + # which might help decide which strategy is more efficient. + # For flat hierarchies, substituting one by one is most likely + # more efficient. 
+ for i in range(len(s) - 1): + for j in range(i + 1, len(s)): + if s[j][1].has(s[i][0]): + s[j] = s[j][0], s[j][1].xreplace({s[i][0]: s[i][1]}) + + s = dict(s) + sym = sym.xreplace(s) + return sym + + else: + if reverse: + s.reverse() + + with sp.evaluate(False): + for old, new in s: + # note that substitution may change free symbols, + # so we have to do this recursively + if sym.has(old): + sym = sym.xreplace({old: new}) + return sym def smart_subs(element: sp.Expr, old: sp.Symbol, new: sp.Expr) -> sp.Expr: diff --git a/python/tests/test_misc.py b/python/tests/test_misc.py index 9ca0b82bfa..83aa6a5d4b 100644 --- a/python/tests/test_misc.py +++ b/python/tests/test_misc.py @@ -84,7 +84,8 @@ def test_cmake_compilation(sbml_example_presimulation_module): @skip_on_valgrind -def test_smart_subs_dict(): +@pytest.mark.parametrize("flatten_first", [True, False]) +def test_smart_subs_dict(flatten_first): expr_str = "c + d" subs_dict = { "c": "a + b", @@ -98,8 +99,12 @@ def test_smart_subs_dict(): expected_default = sp.sympify(expected_default_str) expected_reverse = sp.sympify(expected_reverse_str) - result_default = smart_subs_dict(expr_sym, subs_sym) - result_reverse = smart_subs_dict(expr_sym, subs_sym, reverse=False) + result_default = smart_subs_dict( + expr_sym, subs_sym, flatten_first=flatten_first + ) + result_reverse = smart_subs_dict( + expr_sym, subs_sym, reverse=False, flatten_first=flatten_first + ) assert sp.simplify(result_default - expected_default).is_zero assert sp.simplify(result_reverse - expected_reverse).is_zero diff --git a/tests/performance/reference.yml b/tests/performance/reference.yml index f72387ae5e..0f2a090615 100644 --- a/tests/performance/reference.yml +++ b/tests/performance/reference.yml @@ -1,11 +1,11 @@ # Reference wall times (seconds) with some buffer create_sdist: 16 install_sdist: 150 -petab_import: 2100 -install_model: 120 +petab_import: 720 +install_model: 60 install_model_O0: 40 -install_model_O1: 90 -install_model_O2: 120 
+install_model_O1: 45 +install_model_O2: 60 forward_simulation: 2 forward_sensitivities: 2 adjoint_sensitivities: 2.5