55import functools
66import operator
77from collections import abc
8- from functools import cache
8+ from collections .abc import Callable
9+ from functools import cache , wraps
910from itertools import product
10- from typing import TYPE_CHECKING , Protocol
11+ from typing import TYPE_CHECKING , Any , Protocol
1112from warnings import warn
1213
14+ import attrs
1315import sympy as sp
1416from ampform .kinematics .phasespace import compute_third_mandelstam
1517from ampform .sympy import PoolSum
3537from ampform_dpd .spin import create_spin_range
3638
3739if TYPE_CHECKING :
38- from collections .abc import Iterable
40+ from collections .abc import Callable , Iterable
3941 from typing import Any , Literal
4042
4143
@@ -181,12 +183,10 @@ def formulate_subsystem_amplitude( # noqa: PLR0914
181183 self .decay .final_state [3 ].spin ,
182184 )
183185 λR = sp .Symbol (R"\lambda_R" , rational = True )
184- terms = []
185- parameter_defaults : dict [sp .Basic , complex ] = {}
186+ amplitude_sum = DefinedExpression (0 )
186187 for chain in self .decay .get_subsystem (subsystem_id ).chains :
187188 formulate_dynamics = self .dynamics_choices .get_builder (chain .resonance .name )
188- dynamics = formulate_dynamics (chain )
189- parameter_defaults .update (dynamics .parameters )
189+ amplitude = formulate_dynamics (chain )
190190 resonance_spin = sp .Rational (chain .resonance .spin )
191191 resonance_helicities = create_spin_range (resonance_spin )
192192 for λR_val in resonance_helicities :
@@ -200,26 +200,25 @@ def formulate_subsystem_amplitude( # noqa: PLR0914
200200 )
201201 if isinstance (scaling_factors , tuple ):
202202 h_prod , h_dec = scaling_factors
203- parameter_defaults [h_prod ] = 1 + 0j
204- parameter_defaults [h_dec ] = 1
203+ amplitude . parameters [h_prod ] = 1 + 0j
204+ amplitude . parameters [h_dec ] = 1
205205 else :
206- parameter_defaults [scaling_factors ] = 1 + 0j
206+ amplitude . parameters [scaling_factors ] = 1 + 0j
207207 scaling_factors = _create_scaling_factors (
208208 chain ,
209209 (self .use_production_helicity_couplings , λR , λ [k ]),
210210 (self .use_decay_helicity_couplings , λ [i ], λ [j ]),
211211 one_scalar_per_chain = use_coefficients ,
212212 )
213- sub_amp_expr = (
213+ amplitude * = (
214214 sp .KroneckerDelta (λ [0 ], λR - λ [k ])
215215 * (- 1 ) ** (spin [k ] - λ [k ])
216- * dynamics .expression
217216 * Wigner .d (resonance_spin , λR , λ [i ] - λ [j ], θij )
218217 * _product (scaling_factors )
219218 * (- 1 ) ** (spin [j ] - λ [j ])
220219 )
221220 if not self .use_decay_helicity_couplings :
222- sub_amp_expr *= _formulate_clebsch_gordan_factors (
221+ amplitude *= _formulate_clebsch_gordan_factors (
223222 chain .decay_node ,
224223 helicities = {
225224 self .decay .final_state [i ]: λ [i ],
@@ -228,27 +227,25 @@ def formulate_subsystem_amplitude( # noqa: PLR0914
228227 )
229228 if not self .use_production_helicity_couplings :
230229 production_isobar = chain .decay
231- sub_amp_expr *= _formulate_clebsch_gordan_factors (
230+ amplitude *= _formulate_clebsch_gordan_factors (
232231 production_isobar ,
233232 helicities = {
234233 chain .resonance : λR ,
235234 self .decay .final_state [k ]: λ [k ],
236235 },
237236 )
238- sub_amp = PoolSum (
239- sub_amp_expr ,
240- ( λR , resonance_helicities ),
237+ amplitude_sum += attrs . evolve (
238+ amplitude ,
239+ expression = PoolSum ( amplitude . expression , ( λR , resonance_helicities ) ),
241240 )
242- terms .append (sub_amp )
243241 A = _generate_amplitude_index_bases ()
244242 amp_symbol = A [subsystem_id ][λ0 , λ1 , λ2 , λ3 ]
245- amp_expr = sp .Add (* terms )
246243 return AmplitudeModel (
247244 decay = self .decay ,
248245 intensity = sp .Abs (amp_symbol ) ** 2 ,
249- amplitudes = {amp_symbol : amp_expr },
250- variables = {θij : θij_expr },
251- parameter_defaults = parameter_defaults ,
246+ amplitudes = {amp_symbol : amplitude_sum . expression },
247+ variables = amplitude_sum . subexpressions | {θij : θij_expr },
248+ parameter_defaults = amplitude_sum . parameters ,
252249 )
253250
254251 def formulate_aligned_amplitude (
@@ -480,27 +477,43 @@ class DynamicsBuilder(Protocol):
480477 def __call__ (self , decay_chain : ThreeBodyDecayChain ) -> DefinedExpression : ...
481478
482479
480+ def _binary_operation (op : Callable [[Any , Any ], Any ]):
481+ def decorator (func ):
482+ @wraps (func )
483+ def wrapper (self : DefinedExpression , other ):
484+ if isinstance (other , DefinedExpression ):
485+ return DefinedExpression (
486+ expression = op (self .expression , other .expression ),
487+ parameters = self .parameters | other .parameters ,
488+ subexpressions = self .subexpressions | other .subexpressions ,
489+ )
490+ return DefinedExpression (
491+ expression = op (self .expression , other ),
492+ parameters = self .parameters ,
493+ subexpressions = self .subexpressions ,
494+ )
495+
496+ return wrapper
497+
498+ return decorator
499+
500+
483501@define
484502class DefinedExpression :
485- expression : sp .Expr = sp .S .One
503+ expression : sp .Expr = field ( converter = sp .sympify , default = sp . S .One ) # type:ignore[misc]
486504 parameters : dict [sp .Symbol , complex | float ] = field (factory = dict )
487-
488- def __mul__ (self , other : Any ) -> DefinedExpression :
489- if isinstance (other , DefinedExpression ):
490- return DefinedExpression (
491- expression = self .expression * other .expression ,
492- parameters = {** self .parameters , ** other .parameters },
493- )
494- if isinstance (other , abc .Sequence ) and len (other ) == 2 : # noqa: PLR2004
495- expression , definitions = other
496- return DefinedExpression (
497- expression = self .expression * expression ,
498- parameters = {** self .parameters , ** definitions },
499- )
500- return DefinedExpression (
501- expression = self .expression * other ,
502- parameters = self .parameters ,
503- )
505+ subexpressions : dict [sp .Symbol , sp .Expr ] = field (factory = dict )
506+
507+ @_binary_operation (operator .mul )
508+ def __mul__ (self , other ) -> DefinedExpression : ... # type:ignore[empty-body]
509+ @_binary_operation (operator .add )
510+ def __add__ (self , other ) -> DefinedExpression : ... # type:ignore[empty-body]
511+ @_binary_operation (operator .sub )
512+ def __sub__ (self , other ) -> DefinedExpression : ... # type:ignore[empty-body]
513+ @_binary_operation (operator .truediv )
514+ def __truediv__ (self , other ) -> DefinedExpression : ... # type:ignore[empty-body]
515+ @_binary_operation (operator .pow )
516+ def __pow__ (self , other ) -> DefinedExpression : ... # type:ignore[empty-body]
504517
505518
506519def create_mass_symbol_mapping (decay : ThreeBodyDecay ) -> dict [sp .Symbol , float ]:
0 commit comments