Skip to content

Commit a043ffe

Browse files
authored
FEAT: implement cached.simplify etc (#467)
* FEAT: implement cached simplify and trigsimp * FEAT: implement `cached.subs` * FIX: add back `cached.doit()` to API
1 parent a5aa68f commit a043ffe

File tree

4 files changed

+102
-8
lines changed

4 files changed

+102
-8
lines changed

docs/conf.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def _get_dataclasses(module):
107107
"P": "typing.ParamSpec",
108108
"ParameterValue": ("obj", "ampform.helicity.ParameterValue"),
109109
"Particle": "qrules.particle.Particle",
110+
"SympyObject": "typing.TypeVar",
110111
"ReactionInfo": "qrules.transition.ReactionInfo",
111112
"Slider": ("obj", "symplot.Slider"),
112113
"State": "qrules.transition.State",
@@ -323,10 +324,12 @@ def _get_dataclasses(module):
323324
nb_execution_timeout = -1
324325
nb_output_stderr = "remove"
325326
nitpick_ignore = [
327+
("py:class", "ampform.sympy._decorator.SymPyAssumptions"),
326328
("py:class", "ArraySum"),
329+
("py:class", "BufferedReader"),
327330
("py:class", "ExprClass"),
328331
("py:class", "MatrixMultiplication"),
329-
("py:class", "ampform.sympy._decorator.SymPyAssumptions"),
332+
("py:class", "SupportsWrite"),
330333
]
331334
nitpicky = True
332335
primary_domain = "py"

src/ampform/sympy/_cache.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
1-
"""Helper functions for :func:`.cached.doit` and related functions."""
1+
"""Helper functions for :func:`.cached.doit` and related functions.
2+
3+
These methods are private, but can be imported from this module:
4+
5+
.. code-block:: python
6+
7+
import ampform.sympy._cache
8+
"""
29

310
from __future__ import annotations
411

src/ampform/sympy/cached.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
"""Handy aliases for working with cached SymPy expressions."""
1+
"""Handy aliases for working with cached SymPy expressions.
2+
3+
.. autofunction:: doit
4+
"""
25

36
from __future__ import annotations
47

@@ -35,6 +38,37 @@ def doit(expr: SympyObject) -> SympyObject:
3538
return expr.doit()
3639

3740

41+
@cache
42+
@cache_to_disk
43+
def simplify(expr: sp.Expr, *args, **kwargs) -> sp.Expr:
44+
"""Perform :func:`~sympy.simplify.simplify.simplify` and cache the result to disk.
45+
46+
.. versionadded:: 0.15.7
47+
"""
48+
return sp.simplify(expr, *args, **kwargs)
49+
50+
51+
@cache
52+
@cache_to_disk
53+
def trigsimp(expr: sp.Expr, *args, **kwargs) -> sp.Expr:
54+
"""Perform :func:`~sympy.simplify.trigsimp.trigsimp` and cache the result to disk.
55+
56+
.. versionadded:: 0.15.7
57+
"""
58+
return sp.trigsimp(expr, *args, **kwargs)
59+
60+
61+
def subs(expr: sp.Expr, substitutions: Mapping[sp.Basic, sp.Basic]) -> sp.Expr:
62+
"""Call :meth:`~sympy.core.basic.Basic.subs` and cache the result to disk."""
63+
return _subs_impl(expr, frozendict(substitutions))
64+
65+
66+
@cache
67+
@cache_to_disk(function_name="subs", dependencies=["sympy"])
68+
def _subs_impl(expr: sp.Expr, substitutions: frozendict[sp.Basic, sp.Basic]) -> sp.Expr:
69+
return expr.xreplace(substitutions)
70+
71+
3872
def xreplace(expr: sp.Expr, substitutions: Mapping[sp.Basic, sp.Basic]) -> sp.Expr:
3973
"""Call :meth:`~sympy.core.basic.Basic.xreplace` and cache the result to disk."""
4074
return _xreplace_impl(expr, frozendict(substitutions))

tests/sympy/test_cached.py

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1+
# cspell:ignore pbarksigma
12
from __future__ import annotations
23

34
from typing import TYPE_CHECKING
45

56
import pytest
7+
import sympy as sp
68

79
from ampform.sympy import cached
810

911
if TYPE_CHECKING:
10-
import sympy as sp
11-
1212
from ampform.helicity import HelicityModel
1313

1414

@@ -23,19 +23,69 @@ def test_doit(amplitude_model: tuple[str, HelicityModel]):
2323
assert unfolded_expr_2 == expected_expr
2424

2525

26+
def test_simplify():
27+
a, b, c, d, x, y, z = sp.symbols("a b c d x y z")
28+
expr = (
29+
(a * x + b * y + c * z + d) ** 2
30+
- (a * x) ** 2
31+
- (b * y) ** 2
32+
- (c * z) ** 2
33+
- 2 * a * b * x * y
34+
- 2 * a * c * x * z
35+
- 2 * b * c * y * z
36+
- 2 * d * (a * x + b * y + c * z)
37+
)
38+
39+
simplified = expr.simplify()
40+
assert simplified != expr
41+
assert simplified == d**2
42+
43+
cached_simplified = cached.simplify(expr)
44+
assert simplified == cached_simplified
45+
46+
47+
@pytest.mark.parametrize("amplitude_idx", list(range(4)))
48+
def test_simplify_model(amplitude_model: tuple[str, HelicityModel], amplitude_idx: int):
49+
_, model = amplitude_model
50+
amplitudes = [model.amplitudes[k] for k in sorted(model.amplitudes, key=str)]
51+
amplitude_expr = amplitudes[amplitude_idx]
52+
53+
simplified = amplitude_expr.simplify()
54+
assert simplified != amplitude_expr
55+
56+
cached_simplified = cached.simplify(amplitude_expr)
57+
assert simplified == cached_simplified
58+
59+
60+
def test_trigsimp():
61+
x, y = sp.symbols("x y")
62+
expr = (sp.sin(x) * sp.cos(y) + sp.cos(x) * sp.sin(y)) ** 2 + (
63+
sp.cos(x) * sp.cos(y) - sp.sin(x) * sp.sin(y)
64+
) ** 2
65+
simplified = sp.trigsimp(expr)
66+
assert expr != simplified
67+
assert simplified == 1
68+
cached_simplified = cached.trigsimp(expr)
69+
assert simplified == cached_simplified
70+
71+
72+
@pytest.mark.parametrize("substitute", ["subs", "xreplace"])
2673
@pytest.mark.parametrize(
2774
"substitution_name", ["parameter_defaults", "kinematic_variables"]
2875
)
29-
def test_xreplace(amplitude_model: tuple[str, HelicityModel], substitution_name: str):
76+
def test_xreplace(
77+
amplitude_model: tuple[str, HelicityModel], substitute: str, substitution_name: str
78+
):
79+
cached_func = getattr(cached, substitute)
3080
_, model = amplitude_model
3181
full_expression = model.expression.doit()
3282
substitutions: dict[sp.Symbol, sp.Basic] = getattr(model, substitution_name)
3383
expected_expr = full_expression.xreplace(substitutions)
3484
assert expected_expr != full_expression
3585

36-
substituted_expr_1 = cached.xreplace(full_expression, substitutions)
86+
substituted_expr_1 = cached_func(full_expression, substitutions)
3787
assert substituted_expr_1 == expected_expr
38-
substituted_expr_2 = cached.xreplace(full_expression, substitutions)
88+
substituted_expr_2 = cached_func(full_expression, substitutions)
3989
assert substituted_expr_2 == expected_expr
4090

4191

0 commit comments

Comments
 (0)