1+ # cspell:ignore pbarksigma
12from __future__ import annotations
23
34from typing import TYPE_CHECKING
45
56import pytest
7+ import sympy as sp
68
79from ampform .sympy import cached
810
911if TYPE_CHECKING :
10- import sympy as sp
11-
1212 from ampform .helicity import HelicityModel
1313
1414
@@ -23,19 +23,69 @@ def test_doit(amplitude_model: tuple[str, HelicityModel]):
2323 assert unfolded_expr_2 == expected_expr
2424
2525
26+ def test_simplify ():
27+ a , b , c , d , x , y , z = sp .symbols ("a b c d x y z" )
28+ expr = (
29+ (a * x + b * y + c * z + d ) ** 2
30+ - (a * x ) ** 2
31+ - (b * y ) ** 2
32+ - (c * z ) ** 2
33+ - 2 * a * b * x * y
34+ - 2 * a * c * x * z
35+ - 2 * b * c * y * z
36+ - 2 * d * (a * x + b * y + c * z )
37+ )
38+
39+ simplified = expr .simplify ()
40+ assert simplified != expr
41+ assert simplified == d ** 2
42+
43+ cached_simplified = cached .simplify (expr )
44+ assert simplified == cached_simplified
45+
46+
47+ @pytest .mark .parametrize ("amplitude_idx" , list (range (4 )))
48+ def test_simplify_model (amplitude_model : tuple [str , HelicityModel ], amplitude_idx : int ):
49+ _ , model = amplitude_model
50+ amplitudes = [model .amplitudes [k ] for k in sorted (model .amplitudes , key = str )]
51+ amplitude_expr = amplitudes [amplitude_idx ]
52+
53+ simplified = amplitude_expr .simplify ()
54+ assert simplified != amplitude_expr
55+
56+ cached_simplified = cached .simplify (amplitude_expr )
57+ assert simplified == cached_simplified
58+
59+
60+ def test_trigsimp ():
61+ x , y = sp .symbols ("x y" )
62+ expr = (sp .sin (x ) * sp .cos (y ) + sp .cos (x ) * sp .sin (y )) ** 2 + (
63+ sp .cos (x ) * sp .cos (y ) - sp .sin (x ) * sp .sin (y )
64+ ) ** 2
65+ simplified = sp .trigsimp (expr )
66+ assert expr != simplified
67+ assert simplified == 1
68+ cached_simplified = cached .trigsimp (expr )
69+ assert simplified == cached_simplified
70+
71+
72+ @pytest .mark .parametrize ("substitute" , ["subs" , "xreplace" ])
2673@pytest .mark .parametrize (
2774 "substitution_name" , ["parameter_defaults" , "kinematic_variables" ]
2875)
29- def test_xreplace (amplitude_model : tuple [str , HelicityModel ], substitution_name : str ):
76+ def test_xreplace (
77+ amplitude_model : tuple [str , HelicityModel ], substitute : str , substitution_name : str
78+ ):
79+ cached_func = getattr (cached , substitute )
3080 _ , model = amplitude_model
3181 full_expression = model .expression .doit ()
3282 substitutions : dict [sp .Symbol , sp .Basic ] = getattr (model , substitution_name )
3383 expected_expr = full_expression .xreplace (substitutions )
3484 assert expected_expr != full_expression
3585
36- substituted_expr_1 = cached . xreplace (full_expression , substitutions )
86+ substituted_expr_1 = cached_func (full_expression , substitutions )
3787 assert substituted_expr_1 == expected_expr
38- substituted_expr_2 = cached . xreplace (full_expression , substitutions )
88+ substituted_expr_2 = cached_func (full_expression , substitutions )
3989 assert substituted_expr_2 == expected_expr
4090
4191
0 commit comments