Skip to content

Commit aeb5f34

Browse files
authored
Refactor de_export.py, extract sympy_utils.py (#2307)
No changes in functionality. Related to #2306.
1 parent 8271da1 commit aeb5f34

File tree

7 files changed

+241
-218
lines changed

7 files changed

+241
-218
lines changed

python/sdist/amici/de_export.py

Lines changed: 11 additions & 194 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
C++ Export
33
----------
4-
This module provides all necessary functionality specify an DE model and
4+
This module provides all necessary functionality specify a DE model and
55
generate executable C++ simulation code. The user generally won't have to
66
directly call any function from this module as this will be done by
77
:py:func:`amici.pysb_import.pysb2amici`,
@@ -18,12 +18,11 @@
1818
import subprocess
1919
import sys
2020
from dataclasses import dataclass
21-
from itertools import chain, starmap
21+
from itertools import chain
2222
from pathlib import Path
2323
from string import Template
2424
from typing import (
2525
TYPE_CHECKING,
26-
Any,
2726
Callable,
2827
Literal,
2928
Optional,
@@ -59,8 +58,17 @@
5958
strip_pysb,
6059
toposort_symbols,
6160
unique_preserve_order,
61+
_default_simplify,
6262
)
6363
from .logging import get_logger, log_execution_time, set_log_level
64+
from .sympy_utils import (
65+
_custom_pow_eval_derivative,
66+
_monkeypatched,
67+
smart_jacobian,
68+
smart_multiply,
69+
smart_is_zero_matrix,
70+
_parallel_applyfunc,
71+
)
6472

6573
if TYPE_CHECKING:
6674
from . import sbml_import
@@ -509,109 +517,6 @@ def var_in_function_signature(name: str, varname: str, ode: bool) -> bool:
509517
}
510518

511519

512-
@log_execution_time("running smart_jacobian", logger)
513-
def smart_jacobian(
514-
eq: sp.MutableDenseMatrix, sym_var: sp.MutableDenseMatrix
515-
) -> sp.MutableSparseMatrix:
516-
"""
517-
Wrapper around symbolic jacobian with some additional checks that reduce
518-
computation time for large matrices
519-
520-
:param eq:
521-
equation
522-
:param sym_var:
523-
differentiation variable
524-
:return:
525-
jacobian of eq wrt sym_var
526-
"""
527-
nrow = eq.shape[0]
528-
ncol = sym_var.shape[0]
529-
if (
530-
not min(eq.shape)
531-
or not min(sym_var.shape)
532-
or smart_is_zero_matrix(eq)
533-
or smart_is_zero_matrix(sym_var)
534-
):
535-
return sp.MutableSparseMatrix(nrow, ncol, dict())
536-
537-
# preprocess sparsity pattern
538-
elements = (
539-
(i, j, a, b)
540-
for i, a in enumerate(eq)
541-
for j, b in enumerate(sym_var)
542-
if a.has(b)
543-
)
544-
545-
if (n_procs := int(os.environ.get("AMICI_IMPORT_NPROCS", 1))) == 1:
546-
# serial
547-
return sp.MutableSparseMatrix(
548-
nrow, ncol, dict(starmap(_jacobian_element, elements))
549-
)
550-
551-
# parallel
552-
from multiprocessing import get_context
553-
554-
# "spawn" should avoid potential deadlocks occurring with fork
555-
# see e.g. https://stackoverflow.com/a/66113051
556-
ctx = get_context("spawn")
557-
with ctx.Pool(n_procs) as p:
558-
mapped = p.starmap(_jacobian_element, elements)
559-
return sp.MutableSparseMatrix(nrow, ncol, dict(mapped))
560-
561-
562-
@log_execution_time("running smart_multiply", logger)
563-
def smart_multiply(
564-
x: Union[sp.MutableDenseMatrix, sp.MutableSparseMatrix],
565-
y: sp.MutableDenseMatrix,
566-
) -> Union[sp.MutableDenseMatrix, sp.MutableSparseMatrix]:
567-
"""
568-
Wrapper around symbolic multiplication with some additional checks that
569-
reduce computation time for large matrices
570-
571-
:param x:
572-
educt 1
573-
:param y:
574-
educt 2
575-
:return:
576-
product
577-
"""
578-
if (
579-
not x.shape[0]
580-
or not y.shape[1]
581-
or smart_is_zero_matrix(x)
582-
or smart_is_zero_matrix(y)
583-
):
584-
return sp.zeros(x.shape[0], y.shape[1])
585-
return x.multiply(y)
586-
587-
588-
def smart_is_zero_matrix(
589-
x: Union[sp.MutableDenseMatrix, sp.MutableSparseMatrix],
590-
) -> bool:
591-
"""A faster implementation of sympy's is_zero_matrix
592-
593-
Avoids repeated indexer type checks and double iteration to distinguish
594-
False/None. Found to be about 100x faster for large matrices.
595-
596-
:param x: Matrix to check
597-
"""
598-
599-
if isinstance(x, sp.MutableDenseMatrix):
600-
return all(xx.is_zero is True for xx in x.flat())
601-
602-
if isinstance(x, list):
603-
return all(smart_is_zero_matrix(xx) for xx in x)
604-
605-
return x.nnz() == 0
606-
607-
608-
def _default_simplify(x):
609-
"""Default simplification applied in DEModel"""
610-
# We need this as a free function instead of a lambda to have it picklable
611-
# for parallel simplification
612-
return sp.powsimp(x, deep=True)
613-
614-
615520
class DEModel:
616521
"""
617522
Defines a Differential Equation as set of ModelQuantities.
@@ -4304,94 +4209,6 @@ def is_valid_identifier(x: str) -> bool:
43044209
return IDENTIFIER_PATTERN.match(x) is not None
43054210

43064211

4307-
@contextlib.contextmanager
4308-
def _monkeypatched(obj: object, name: str, patch: Any):
4309-
"""
4310-
Temporarily monkeypatches an object.
4311-
4312-
:param obj:
4313-
object to be patched
4314-
4315-
:param name:
4316-
name of the attribute to be patched
4317-
4318-
:param patch:
4319-
patched value
4320-
"""
4321-
pre_patched_value = getattr(obj, name)
4322-
setattr(obj, name, patch)
4323-
try:
4324-
yield object
4325-
finally:
4326-
setattr(obj, name, pre_patched_value)
4327-
4328-
4329-
def _custom_pow_eval_derivative(self, s):
4330-
"""
4331-
Custom Pow derivative that removes a removable singularity for
4332-
``self.base == 0`` and ``self.base.diff(s) == 0``. This function is
4333-
intended to be monkeypatched into :py:method:`sympy.Pow._eval_derivative`.
4334-
4335-
:param self:
4336-
sp.Pow class
4337-
4338-
:param s:
4339-
variable with respect to which the derivative will be computed
4340-
"""
4341-
dbase = self.base.diff(s)
4342-
dexp = self.exp.diff(s)
4343-
part1 = sp.Pow(self.base, self.exp - 1) * self.exp * dbase
4344-
part2 = self * dexp * sp.log(self.base)
4345-
if self.base.is_nonzero or dbase.is_nonzero or part2.is_zero:
4346-
# first piece never applies or is zero anyways
4347-
return part1 + part2
4348-
4349-
return part1 + sp.Piecewise(
4350-
(self.base, sp.And(sp.Eq(self.base, 0), sp.Eq(dbase, 0))),
4351-
(part2, True),
4352-
)
4353-
4354-
4355-
def _jacobian_element(i, j, eq_i, sym_var_j):
4356-
"""Compute a single element of a jacobian"""
4357-
return (i, j), eq_i.diff(sym_var_j)
4358-
4359-
4360-
def _parallel_applyfunc(obj: sp.Matrix, func: Callable) -> sp.Matrix:
4361-
"""Parallel implementation of sympy's Matrix.applyfunc"""
4362-
if (n_procs := int(os.environ.get("AMICI_IMPORT_NPROCS", 1))) == 1:
4363-
# serial
4364-
return obj.applyfunc(func)
4365-
4366-
# parallel
4367-
from multiprocessing import get_context
4368-
from pickle import PicklingError
4369-
4370-
from sympy.matrices.dense import DenseMatrix
4371-
4372-
# "spawn" should avoid potential deadlocks occurring with fork
4373-
# see e.g. https://stackoverflow.com/a/66113051
4374-
ctx = get_context("spawn")
4375-
with ctx.Pool(n_procs) as p:
4376-
try:
4377-
if isinstance(obj, DenseMatrix):
4378-
return obj._new(obj.rows, obj.cols, p.map(func, obj))
4379-
elif isinstance(obj, sp.SparseMatrix):
4380-
dok = obj.todok()
4381-
mapped = p.map(func, dok.values())
4382-
dok = {k: v for k, v in zip(dok.keys(), mapped) if v != 0}
4383-
return obj._new(obj.rows, obj.cols, dok)
4384-
else:
4385-
raise ValueError(f"Unsupported matrix type {type(obj)}")
4386-
except PicklingError as e:
4387-
raise ValueError(
4388-
f"Couldn't pickle {func}. This is likely because the argument "
4389-
"was not a module-level function. Either rewrite the argument "
4390-
"to a module-level function or disable parallelization by "
4391-
"setting `AMICI_IMPORT_NPROCS=1`."
4392-
) from e
4393-
4394-
43954212
def _write_gitignore(dest_dir: Path) -> None:
43964213
"""Write .gitignore file.
43974214

python/sdist/amici/import_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -748,3 +748,10 @@ def unique_preserve_order(seq: Sequence) -> list:
748748

749749
sbml_time_symbol = symbol_with_assumptions("time")
750750
amici_time_symbol = symbol_with_assumptions("t")
751+
752+
753+
def _default_simplify(x):
754+
"""Default simplification applied in DEModel"""
755+
# We need this as a free function instead of a lambda to have it picklable
756+
# for parallel simplification
757+
return sp.powsimp(x, deep=True)

python/sdist/amici/pysb_import.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,14 @@
3434
Observable,
3535
Parameter,
3636
SigmaY,
37-
_default_simplify,
3837
)
3938
from .import_utils import (
4039
_get_str_symbol_identifiers,
4140
_parse_special_functions,
4241
generate_measurement_symbol,
4342
noise_distribution_to_cost_function,
4443
noise_distribution_to_observable_transformation,
44+
_default_simplify,
4545
)
4646
from .logging import get_logger, log_execution_time, set_log_level
4747

python/sdist/amici/sbml_import.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,8 @@
3131
from .de_export import (
3232
DEExporter,
3333
DEModel,
34-
_default_simplify,
35-
smart_is_zero_matrix,
3634
)
35+
from .sympy_utils import smart_is_zero_matrix
3736
from .import_utils import (
3837
RESERVED_SYMBOLS,
3938
_check_unsupported_functions,
@@ -50,6 +49,7 @@
5049
smart_subs_dict,
5150
symbol_with_assumptions,
5251
toposort_symbols,
52+
_default_simplify,
5353
)
5454
from .logging import get_logger, log_execution_time, set_log_level
5555
from .sbml_utils import SBMLException, _parse_logical_operators

0 commit comments

Comments
 (0)