Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 61 additions & 21 deletions python/sdist/amici/importers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,7 @@ def smart_subs_dict(
subs: SymbolDef,
field: str | None = None,
reverse: bool = True,
flatten_first: bool | None = None,
) -> sp.Expr:
"""
Substitutes expressions completely flattening them out. Requires
Expand All @@ -418,6 +419,11 @@ def smart_subs_dict(
Whether ordering in subs should be reversed. Note that substitution
requires the reverse order of what is required for evaluation.

:param flatten_first:
Choice of algorithm: Flatten the substitution expressions first, then
substitute them simultaneously into `sym` (``True``), or substitute
them one by one into `sym` (``False``).

:return:
Substituted symbolic expression
"""
Expand All @@ -426,27 +432,61 @@ def smart_subs_dict(
else:
s = [(eid, expr[field]) for eid, expr in subs.items()]

if not reverse:
# counter-intuitive, but we need to reverse the order for reverse=False
s.reverse()

with sp.evaluate(False):
# The new expressions may themselves contain symbols to be substituted.
# We flatten them out first, so that the substitutions in `sym` can be
# performed simultaneously, which is usually more efficient than
# repeatedly substituting into `sym`.
# TODO(performance): This could probably be made more efficient by
# combining with toposort used to order `subs` in the first place.
# Some substitutions could be combined, and some terms not present in
# `sym` could be skipped.
for i in range(len(s) - 1):
for j in range(i + 1, len(s)):
if s[j][1].has(s[i][0]):
s[j] = s[j][0], s[j][1].xreplace({s[i][0]: s[i][1]})

s = dict(s)
sym = sym.xreplace(s)
return sym
# We have the choice to flatten the replacement expressions first or to
# substitute them one by one into `sym`. Flattening first is usually
# more efficient if `sym` is large (e.g., a matrix with many elements)
# and `subs` is cascading (i.e., substitutions depend on other
# substitutions). Otherwise, substituting one by one is usually more
# efficient, because flattening scales quadratically with the number of
# substitutions.
# The exact threshold is somewhat arbitrary and may need to be
# adjusted in the future.
if flatten_first is None:
flatten_first = (
isinstance(sym, sp.MatrixBase) and sym.rows * sym.cols > 20
)

if flatten_first:
if not reverse:
# counter-intuitive, but on this branch, we need to reverse the
# order for `reverse=False`
s.reverse()

with sp.evaluate(False):
# The new expressions may themselves contain symbols to be
# substituted. We flatten them out first, so that the
# substitutions in `sym` can be performed simultaneously,
# which can be more efficient than repeatedly substituting into
# `sym`.
# TODO(performance): This could probably be made more efficient by
# combining with toposort used to order `subs` in the first
# place.
# Some substitutions could be combined, and some terms not
# present in `sym` could be skipped.
# Furthermore, this would provide information on recursion depth,
# which might help decide which strategy is more efficient.
# For flat hierarchies, substituting one by one is most likely
# more efficient.
for i in range(len(s) - 1):
for j in range(i + 1, len(s)):
if s[j][1].has(s[i][0]):
s[j] = s[j][0], s[j][1].xreplace({s[i][0]: s[i][1]})

s = dict(s)
sym = sym.xreplace(s)
return sym

else:
if reverse:
s.reverse()

with sp.evaluate(False):
for old, new in s:
# note that substitution may change free symbols,
# so we have to do this recursively
if sym.has(old):
sym = sym.xreplace({old: new})
return sym


def smart_subs(element: sp.Expr, old: sp.Symbol, new: sp.Expr) -> sp.Expr:
Expand Down
11 changes: 8 additions & 3 deletions python/tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ def test_cmake_compilation(sbml_example_presimulation_module):


@skip_on_valgrind
def test_smart_subs_dict():
@pytest.mark.parametrize("flatten_first", [True, False])
def test_smart_subs_dict(flatten_first):
expr_str = "c + d"
subs_dict = {
"c": "a + b",
Expand All @@ -98,8 +99,12 @@ def test_smart_subs_dict():
expected_default = sp.sympify(expected_default_str)
expected_reverse = sp.sympify(expected_reverse_str)

result_default = smart_subs_dict(expr_sym, subs_sym)
result_reverse = smart_subs_dict(expr_sym, subs_sym, reverse=False)
result_default = smart_subs_dict(
expr_sym, subs_sym, flatten_first=flatten_first
)
result_reverse = smart_subs_dict(
expr_sym, subs_sym, reverse=False, flatten_first=flatten_first
)

assert sp.simplify(result_default - expected_default).is_zero
assert sp.simplify(result_reverse - expected_reverse).is_zero
Expand Down
8 changes: 4 additions & 4 deletions tests/performance/reference.yml
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# Reference wall times (seconds) with some buffer
create_sdist: 16
install_sdist: 150
petab_import: 2100
install_model: 120
petab_import: 720
install_model: 60
install_model_O0: 40
install_model_O1: 90
install_model_O2: 120
install_model_O1: 45
install_model_O2: 60
forward_simulation: 2
forward_sensitivities: 2
adjoint_sensitivities: 2.5
Expand Down
Loading